From 2331713fde5bee370ae389252cb83c7a96a4e04a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 5 Mar 2026 14:40:26 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/train.py b/tools/train.py index 64dddb7..387649a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from typing import List, Optional from functools import partial +from khaosz.data import DatasetLoader from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig from khaosz.trainer import Trainer, SchedulerFactory -from khaosz.data import DatasetLoader +from khaosz.parallel import get_rank def parse_args() -> argparse.Namespace: @@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace: return args def ddp_wrap(model: nn.Module): - local_rank = int(os.environ.get("LOCAL_RANK", 0)) + local_rank = get_rank() model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) ddp_model = DDP( model,