feat(trainer): support distributed training configuration and optimize checkpoint loading

ViperEkura 2025-12-19 19:34:39 +08:00
parent eab7a51bb6
commit 573f041c51
8 changed files with 67 additions and 27 deletions

View File

@@ -4,7 +4,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from dataclasses import dataclass, field
from typing import Optional
from typing import Callable, Optional
@dataclass
@@ -88,6 +88,22 @@ class TrainConfig:
default=1,
metadata={"help": "Number of processes for distributed training."}
)
backend: str = field(
default="nccl",
metadata={"help": "Distributed training backend."}
)
master_addr: str = field(
default="localhost",
metadata={"help": "Master address for distributed training."}
)
master_port: str = field(
default="29500",
metadata={"help": "Master port for distributed training."}
)
parallel_fn: Optional[Callable] = field(
default=None,
metadata={"help": "Parallel function for training."}
)
# others
extra_kwargs: dict = field(
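
Note: the three connection fields (backend, master_addr, master_port) are the usual torch.distributed settings. A minimal sketch of how a worker process might consume them; init_distributed is a hypothetical helper, not part of this commit:

import os
import torch.distributed as dist

def init_distributed(rank: int, world_size: int, backend: str = "nccl",
                     master_addr: str = "localhost", master_port: str = "29500") -> None:
    # Standard torch.distributed bootstrap driven by the new TrainConfig fields.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)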

View File

@@ -1,4 +1,5 @@
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.trainer.strategy import StrategyFactory
from khaosz.trainer.schedule import SchedulerFactory
@@ -15,6 +16,9 @@ __all__ = [
# trainer
"Trainer",
# checkpoint
"Checkpoint",
# factory
"StrategyFactory",
"SchedulerFactory",

View File

@@ -38,19 +38,23 @@ class Checkpoint:
if save_metric_plot and self.metrics:
self._plot_metrics()
def load(self, save_dir: str) -> "Checkpoint":
if not os.path.exists(save_dir):
raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
checkpoint_path = os.path.join(save_dir, "train_state.pkl")
with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
with open(checkpoint_path, "rb") as f:
train_state = pkl.load(f)
self.epoch = train_state["epoch"]
self.iteration = train_state["iteration"]
self.metrics = train_state["metrics"]
self.optimizer_state = train_state["optimizer_state"]
self.scheduler_state = train_state["scheduler_state"]
return self
return cls(
optimizer_state=train_state["optimizer_state"],
scheduler_state=train_state["scheduler_state"],
epoch=train_state["epoch"],
iteration=train_state["iteration"],
metrics=train_state["metrics"]
)
def _plot_metrics(self):
for metric_name, metric_value in self.metrics.items():
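
Note: load is now a classmethod that returns a fresh Checkpoint instead of mutating an existing instance. A hypothetical resume snippet (the directory name follows the f"epoch_{epoch}_iter_{iteration}" pattern used by CheckpointCallback below):

from khaosz.trainer import Checkpoint

checkpoint = Checkpoint.load("checkpoints/epoch_0_iter_2")  # hypothetical path
print(checkpoint.epoch, checkpoint.iteration)               # restored counters
# The loaded object can then be passed to Trainer.train(checkpoint) to resume.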

View File

@@ -99,7 +99,7 @@ class CheckpointCallback(TrainCallback):
@only_on_rank(0)
def _save_checkpoint(self, context: 'TrainContext'):
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
context.checkpoint = Checkpoint(
context.optimizer.state_dict(),
context.scheduler.state_dict(),

View File

@@ -89,5 +89,12 @@ class TrainContextBuilder:
)
return self
def with_parallel_fn(self) -> Self:
fn = self.config.parallel_fn
if fn is not None:
self._context.model = fn(self._context.model)
return self
def build(self) -> TrainContext:
return self._context
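
Note: with_parallel_fn simply applies config.parallel_fn to the built model when one is set, so any callable taking and returning an nn.Module can be plugged in. A sketch using plain DistributedDataParallel (the training script in this commit passes an fsdp_wrap helper instead):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_wrap(model: torch.nn.Module) -> torch.nn.Module:
    # Assumes the process group is already initialized and the model sits on its target device.
    return DDP(model)

# e.g. TrainConfig(..., parallel_fn=ddp_wrap)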

View File

@@ -10,6 +10,7 @@ from khaosz.trainer.train_callback import (
)
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
from khaosz.trainer.checkpoint import Checkpoint
from khaosz.parallel.setup import spawn_parallel_fn
logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class Trainer:
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.with_parallel_fn()
.build())
def _call_callbacks(self, method_name: str, context: TrainContext):
@@ -45,7 +47,18 @@ class Trainer:
if method:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
checkpoint=checkpoint
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
self._call_callbacks('on_train_begin', context)
@@ -85,4 +98,3 @@ class Trainer:
raise
finally:
self._call_callbacks('on_train_end', context)
return context.checkpoint
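
Note: train() now only drives spawn_parallel_fn and no longer returns a Checkpoint; results are reloaded from disk via Checkpoint.load (see the updated test below). The helper's implementation is not shown in this diff; a plausible torch.multiprocessing-based shape, offered purely as an assumption:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, fn, backend, world_size, master_addr, master_port, kwargs):
    # One training process: bring up the process group, run fn, then tear down.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    try:
        fn(**kwargs)
    finally:
        dist.destroy_process_group()

def spawn_parallel_fn(fn, backend, world_size, master_addr, master_port, **kwargs):
    mp.spawn(_worker, args=(fn, backend, world_size, master_addr, master_port, kwargs),
             nprocs=world_size, join=True)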

View File

@@ -1,3 +1,4 @@
import os
import torch
import numpy as np
from khaosz.config import *
@@ -31,10 +32,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
checkpoint = None
try:
checkpoint = trainer.train()
assert checkpoint.iteration == 2
except Exception:
# Handle any exceptions
pass
checkpoint = trainer.train(checkpoint)
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10

View File

@@ -8,7 +8,6 @@ from torch.optim import AdamW
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
from khaosz.trainer import Trainer, SchedulerFactory
from khaosz.data import DatasetLoader
from khaosz.parallel import get_current_device, spawn_parallel_fn
def parse_args() -> argparse.Namespace:
@@ -96,8 +95,6 @@ def train(
window_size = parameter.config.m_len
model = parameter.model
current_device = get_current_device()
model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
kwargs = {
"dpo_beta": dpo_beta,
@@ -150,6 +147,7 @@ def train(
pin_memory=pin_memory,
nprocs=nprocs,
extra_kwargs=kwargs,
parallel_fn=fsdp_wrap
)
trainer = Trainer(train_config)
@@ -158,10 +156,4 @@ def train(
if __name__ == "__main__":
args = parse_args()
spawn_parallel_fn(
train,
world_size=args.nprocs,
backend="nccl",
**vars(args)
)
train(**vars(args))
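
Note: the script now hands model wrapping to the trainer via parallel_fn=fsdp_wrap and calls train() directly instead of spawning processes itself. The fsdp_wrap body is not part of this diff; a minimal sketch of what such a wrapper could look like:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def fsdp_wrap(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical stand-in for the script's fsdp_wrap; the real helper may configure
    # sharding strategy, auto-wrap policy, or mixed precision. Requires an initialized
    # process group, which Trainer.train now sets up via spawn_parallel_fn.
    return FSDP(model)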