refactor: 修改训练脚本
This commit is contained in:
parent
c74fbf84b7
commit
2331713fde
|
|
@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||
|
||||
from typing import List, Optional
|
||||
from functools import partial
|
||||
from khaosz.data import DatasetLoader
|
||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||
from khaosz.trainer import Trainer, SchedulerFactory
|
||||
from khaosz.data import DatasetLoader
|
||||
from khaosz.parallel import get_rank
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
|
|
@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
|
|||
return args
|
||||
|
||||
def ddp_wrap(model: nn.Module):
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
local_rank = get_rank()
|
||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||
ddp_model = DDP(
|
||||
model,
|
||||
|
|
|
|||
Loading…
Reference in New Issue