refactor: 修改训练脚本
This commit is contained in:
parent
c74fbf84b7
commit
2331713fde
|
|
@@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from khaosz.data import DatasetLoader
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, SchedulerFactory
|
from khaosz.trainer import Trainer, SchedulerFactory
|
||||||
from khaosz.data import DatasetLoader
|
from khaosz.parallel import get_rank
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
@@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
local_rank = get_rank()
|
||||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||||
ddp_model = DDP(
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue