feat(trainer): support distributed training configuration and optimize checkpoint loading

ViperEkura 2025-12-19 19:34:39 +08:00
parent eab7a51bb6
commit 573f041c51
8 changed files with 67 additions and 27 deletions

View File

@@ -4,7 +4,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Callable, Optional
 
 @dataclass
@@ -88,6 +88,22 @@ class TrainConfig:
         default=1,
         metadata={"help": "Number of processes for distributed training."}
     )
+    backend: str = field(
+        default="nccl",
+        metadata={"help": "Distributed training backend."}
+    )
+    master_addr: str = field(
+        default="localhost",
+        metadata={"help": "Master address for distributed training."}
+    )
+    master_port: str = field(
+        default="29500",
+        metadata={"help": "Master port for distributed training."}
+    )
+    parallel_fn: Optional[Callable] = field(
+        default=None,
+        metadata={"help": "Parallel function for training."}
+    )
     # others
     extra_kwargs: dict = field(
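
Note: for reference, a minimal sketch of setting the new distributed fields (any
required TrainConfig arguments not shown in this diff are elided):

    from khaosz.config import TrainConfig

    config = TrainConfig(
        nprocs=4,                 # number of worker processes
        backend="nccl",           # torch.distributed backend (default)
        master_addr="localhost",  # rendezvous address
        master_port="29500",      # rendezvous port (a string, per the field default)
        parallel_fn=None,         # optional callable applied to the model per process
    )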

View File

@@ -1,4 +1,5 @@
 from khaosz.trainer.trainer import Trainer
+from khaosz.trainer.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory
 from khaosz.trainer.schedule import SchedulerFactory
@@ -15,6 +16,9 @@ __all__ = [
     # trainer
     "Trainer",
+    # checkpoint
+    "Checkpoint",
     # factory
     "StrategyFactory",
     "SchedulerFactory",

View File

@@ -38,19 +38,23 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics()
 
-    def load(self, save_dir: str) -> "Checkpoint":
-        if not os.path.exists(save_dir):
-            raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
-        with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
+    @classmethod
+    def load(cls, save_dir: str) -> "Checkpoint":
+        checkpoint_path = os.path.join(save_dir, "train_state.pkl")
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
+        with open(checkpoint_path, "rb") as f:
             train_state = pkl.load(f)
-        self.epoch = train_state["epoch"]
-        self.iteration = train_state["iteration"]
-        self.metrics = train_state["metrics"]
-        self.optimizer_state = train_state["optimizer_state"]
-        self.scheduler_state = train_state["scheduler_state"]
-        return self
+        return cls(
+            optimizer_state=train_state["optimizer_state"],
+            scheduler_state=train_state["scheduler_state"],
+            epoch=train_state["epoch"],
+            iteration=train_state["iteration"],
+            metrics=train_state["metrics"]
+        )
 
     def _plot_metrics(self):
         for metric_name, metric_value in self.metrics.items():
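
Note: load is now a classmethod that builds a fresh Checkpoint instead of mutating
an existing instance, so callers need no pre-constructed object. A sketch (the path
is illustrative; the directory name follows the epoch_{e}_iter_{i} pattern fixed below):

    from khaosz.trainer import Checkpoint

    checkpoint = Checkpoint.load("checkpoints/epoch_1_iter_10")
    print(checkpoint.epoch, checkpoint.iteration)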

View File

@@ -99,7 +99,7 @@ class CheckpointCallback(TrainCallback):
     @only_on_rank(0)
     def _save_checkpoint(self, context: 'TrainContext'):
-        save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
+        save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
         context.checkpoint = Checkpoint(
             context.optimizer.state_dict(),
             context.scheduler.state_dict(),

View File

@@ -89,5 +89,12 @@ class TrainContextBuilder:
         )
         return self
 
+    def with_parallel_fn(self) -> Self:
+        fn = self.config.parallel_fn
+        if fn is not None:
+            self._context.model = fn(self._context.model)
+        return self
+
     def build(self) -> TrainContext:
         return self._context
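
Note: with_parallel_fn lets any model wrapper (FSDP, DDP, torch.compile, ...) be
injected through the config rather than hard-coded in the trainer. The builder
chain as Trainer uses it below (the builder construction itself is assumed, not
part of this diff):

    context = (TrainContextBuilder(config)  # assumed constructor signature
        .with_checkpoint(checkpoint)
        .with_dataloader()
        .with_strategy()
        .with_parallel_fn()  # applies config.parallel_fn to context.model, if set
        .build())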

View File

@@ -10,6 +10,7 @@ from khaosz.trainer.train_callback import (
 )
 from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
 from khaosz.trainer.checkpoint import Checkpoint
+from khaosz.parallel.setup import spawn_parallel_fn
 
 logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class Trainer:
             .with_checkpoint(checkpoint)
             .with_dataloader()
             .with_strategy()
+            .with_parallel_fn()
             .build())
 
     def _call_callbacks(self, method_name: str, context: TrainContext):
@@ -44,8 +46,19 @@ class Trainer:
             method = getattr(callback, method_name, None)
             if method:
                 method(context)
 
+    def train(self, checkpoint: Optional[Checkpoint] = None):
+        config = self.train_config
+        spawn_parallel_fn(
+            self._train_impl,
+            backend=config.backend,
+            world_size=config.nprocs,
+            master_addr=config.master_addr,
+            master_port=config.master_port,
+            checkpoint=checkpoint
+        )
+
-    def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
+    def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
         context = self._build_context(checkpoint)
         self._call_callbacks('on_train_begin', context)
@@ -84,5 +97,4 @@ class Trainer:
             self._call_callbacks('on_error', context)
             raise
         finally:
             self._call_callbacks('on_train_end', context)
-        return context.checkpoint
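
Note: train() now only spawns config.nprocs worker processes (each running
_train_impl) and no longer returns the final Checkpoint; callers read the result
back from disk instead. A minimal driver sketch (the checkpoint path is illustrative):

    trainer = Trainer(train_config)
    trainer.train()  # spawns config.nprocs processes, each running _train_impl

    # The final state is read from the checkpoint directory, not a return value.
    checkpoint = Checkpoint.load("checkpoints/epoch_1_iter_10")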

View File

@@ -1,3 +1,4 @@
+import os
 import torch
 import numpy as np
 from khaosz.config import *
@@ -31,10 +32,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     checkpoint = None
     try:
         checkpoint = trainer.train()
-        assert checkpoint.iteration == 2
     except Exception:
         # Handle any exceptions
         pass
 
-    checkpoint = trainer.train(checkpoint)
+    load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
+    checkpoint = Checkpoint.load(load_dir)
+    trainer.train(checkpoint)
+
+    load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
+    checkpoint = Checkpoint.load(load_dir)
     assert checkpoint.iteration == 10

View File

@@ -8,7 +8,6 @@ from torch.optim import AdamW
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
-from khaosz.parallel import get_current_device, spawn_parallel_fn
 
 def parse_args() -> argparse.Namespace:
@@ -96,8 +95,6 @@ def train(
     window_size = parameter.config.m_len
     model = parameter.model
-    current_device = get_current_device()
-    model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
 
     kwargs = {
         "dpo_beta": dpo_beta,
@@ -150,6 +147,7 @@ def train(
         pin_memory=pin_memory,
         nprocs=nprocs,
         extra_kwargs=kwargs,
+        parallel_fn=fsdp_wrap
     )
 
     trainer = Trainer(train_config)
@@ -158,10 +156,4 @@ def train(
 if __name__ == "__main__":
     args = parse_args()
-    spawn_parallel_fn(
-        train,
-        world_size=args.nprocs,
-        backend="nccl",
-        **vars(args)
-    )
+    train(**vars(args))
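
Note: fsdp_wrap is defined elsewhere in this script and does not appear in the diff.
Since the explicit get_current_device()/to-device call above was removed, a
parallel_fn of roughly this shape would be one plausible implementation (a sketch,
not the repository's actual code):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def fsdp_wrap(model: nn.Module) -> nn.Module:
        # Place the model on this rank's GPU, then shard it with FSDP,
        # absorbing the device placement deleted from train() above.
        device = torch.device("cuda", torch.cuda.current_device())
        model = model.to(device=device, dtype=torch.bfloat16)
        return FSDP(model)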