refactor: 修改训练脚本

This commit is contained in:
ViperEkura 2026-03-05 14:40:26 +08:00
parent c74fbf84b7
commit 2331713fde
1 changed files with 3 additions and 2 deletions

View File

@@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from typing import List, Optional
 from functools import partial
+from khaosz.data import DatasetLoader
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
-from khaosz.data import DatasetLoader
+from khaosz.parallel import get_rank

 def parse_args() -> argparse.Namespace:
@@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
     return args

 def ddp_wrap(model: nn.Module):
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    local_rank = get_rank()
     model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
     ddp_model = DDP(
         model,