diff --git a/tools/train.py b/tools/train.py
index 64dddb7..387649a 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -7,9 +7,10 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from typing import List, Optional
 from functools import partial
+from khaosz.data import DatasetLoader
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
-from khaosz.data import DatasetLoader
+from khaosz.parallel import get_rank
 
 
 def parse_args() -> argparse.Namespace:
@@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
     return args
 
 def ddp_wrap(model: nn.Module):
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    local_rank = get_rank()
     model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
     ddp_model = DDP(
         model,