From 573f041c510067f475e7d5367f4c76b81ffd4813 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 19 Dec 2025 19:34:39 +0800
Subject: [PATCH] feat(trainer): support distributed training configuration
 and optimized checkpoint loading

Move worker-process management into the trainer: TrainConfig now
carries the distributed settings (backend, master_addr, master_port)
plus an optional parallel_fn that wraps the model inside each worker
process, so entry points such as tools/train.py no longer need to call
spawn_parallel_fn themselves. Checkpoint.load becomes a classmethod
that builds a fresh Checkpoint from train_state.pkl instead of
mutating an existing instance, Checkpoint is re-exported from
khaosz.trainer, and checkpoint directories are now named
epoch_{epoch}_iter_{iteration} (a missing underscore before "iter" is
fixed). Since workers run in separate processes, train() no longer
returns a Checkpoint; callers resume by loading a saved checkpoint
explicitly.
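Illustrative call pattern, as a sketch only: the model, dataset and
optimizer wiring are elided, the remaining TrainConfig fields keep
their usual values, and my_wrap_fn is a hypothetical placeholder for a
per-worker wrapper such as tools/train.py's fsdp_wrap:

    from khaosz.config import TrainConfig
    from khaosz.trainer import Trainer, Checkpoint

    def my_wrap_fn(model):
        # Hypothetical placeholder: real code would return the model
        # wrapped for parallel training (e.g. FSDP/DDP).
        return model

    config = TrainConfig(
        # ... model/optimizer/dataset fields as before ...
        nprocs=2,                 # world size for the spawned workers
        backend="nccl",           # torch.distributed backend
        master_addr="localhost",  # rendezvous address
        master_port="29500",      # rendezvous port
        parallel_fn=my_wrap_fn,   # applied to the model in each worker
    )

    trainer = Trainer(config)
    trainer.train()  # spawns nprocs worker processes internally

    # train() no longer returns a Checkpoint; to resume, load one that
    # CheckpointCallback saved under an epoch_{e}_iter_{i} directory:
    checkpoint = Checkpoint.load("checkpoints/epoch_0_iter_2")
    trainer.train(checkpoint)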
---
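Note: parallel_fn receives the built model and must return the wrapped
module; TrainContextBuilder.with_parallel_fn applies it once inside
each spawned worker. A minimal sketch of such a wrapper, assuming
torch's DistributedDataParallel and assuming the trainer's spawn
machinery has already initialized the process group and moved the
model to the worker's device:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def ddp_wrap(model: nn.Module) -> nn.Module:
        # One replica per worker process; DDP synchronizes gradients
        # across ranks during backward.
        return DDP(model)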
 khaosz/config/train_config.py    | 18 +++++++++++++++++-
 khaosz/trainer/__init__.py       |  4 ++++
 khaosz/trainer/checkpoint.py     | 24 ++++++++++++++----------
 khaosz/trainer/train_callback.py |  2 +-
 khaosz/trainer/train_context.py  |  7 +++++++
 khaosz/trainer/trainer.py        | 18 +++++++++++++++---
 tests/test_early_stopping.py     |  9 +++++++--
 tools/train.py                   | 12 ++----------
 8 files changed, 67 insertions(+), 27 deletions(-)

diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 21a6ae4..434172a 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -4,7 +4,7 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Callable, Optional
 
 
 @dataclass
@@ -88,6 +88,22 @@ class TrainConfig:
         default=1,
         metadata={"help": "Number of processes for distributed training."}
     )
+    backend: str = field(
+        default="nccl",
+        metadata={"help": "Distributed training backend."}
+    )
+    master_addr: str = field(
+        default="localhost",
+        metadata={"help": "Master address for distributed training."}
+    )
+    master_port: str = field(
+        default="29500",
+        metadata={"help": "Master port for distributed training."}
+    )
+    parallel_fn: Optional[Callable] = field(
+        default=None,
+        metadata={"help": "Optional callable that wraps the model inside each worker process (e.g. FSDP/DDP)."}
+    )
 
     # others
     extra_kwargs: dict = field(
diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py
index d856750..2e92aa4 100644
--- a/khaosz/trainer/__init__.py
+++ b/khaosz/trainer/__init__.py
@@ -1,4 +1,5 @@
 from khaosz.trainer.trainer import Trainer
+from khaosz.trainer.checkpoint import Checkpoint
 from khaosz.trainer.strategy import StrategyFactory
 from khaosz.trainer.schedule import SchedulerFactory
 
@@ -15,6 +16,9 @@ __all__ = [
     # trainer
     "Trainer",
 
+    # checkpoint
+    "Checkpoint",
+
     # factory
     "StrategyFactory",
     "SchedulerFactory",
diff --git a/khaosz/trainer/checkpoint.py b/khaosz/trainer/checkpoint.py
index 6cd107a..a4ff101 100644
--- a/khaosz/trainer/checkpoint.py
+++ b/khaosz/trainer/checkpoint.py
@@ -38,19 +38,23 @@ class Checkpoint:
         if save_metric_plot and self.metrics:
             self._plot_metrics()
 
-    def load(self, save_dir: str) -> "Checkpoint":
-        if not os.path.exists(save_dir):
-            raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
+    @classmethod
+    def load(cls, save_dir: str) -> "Checkpoint":
+        checkpoint_path = os.path.join(save_dir, "train_state.pkl")
 
-        with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")
+
+        with open(checkpoint_path, "rb") as f:
             train_state = pkl.load(f)
 
-        self.epoch = train_state["epoch"]
-        self.iteration = train_state["iteration"]
-        self.metrics = train_state["metrics"]
-        self.optimizer_state = train_state["optimizer_state"]
-        self.scheduler_state = train_state["scheduler_state"]
-        return self
+        return cls(
+            optimizer_state=train_state["optimizer_state"],
+            scheduler_state=train_state["scheduler_state"],
+            epoch=train_state["epoch"],
+            iteration=train_state["iteration"],
+            metrics=train_state["metrics"]
+        )
 
     def _plot_metrics(self):
         for metric_name, metric_value in self.metrics.items():
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index a06bf1d..a44f0e2 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -99,7 +99,7 @@ class CheckpointCallback(TrainCallback):
 
     @only_on_rank(0)
     def _save_checkpoint(self, context: 'TrainContext'):
-        save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}")
+        save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}")
         context.checkpoint = Checkpoint(
             context.optimizer.state_dict(),
             context.scheduler.state_dict(),
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index d95e3f5..40c5e10 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -89,5 +89,12 @@ class TrainContextBuilder:
         )
         return self
 
+    def with_parallel_fn(self) -> Self:
+        fn = self.config.parallel_fn
+        if fn is not None:
+            self._context.model = fn(self._context.model)
+
+        return self
+
     def build(self) -> TrainContext:
         return self._context
\ No newline at end of file
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 2ca75e8..e409845 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -10,6 +10,7 @@ from khaosz.trainer.train_callback import (
 )
 from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
 from khaosz.trainer.checkpoint import Checkpoint
+from khaosz.parallel.setup import spawn_parallel_fn
 
 logger = logging.getLogger(__name__)
 
@@ -37,6 +38,7 @@ class Trainer:
             .with_checkpoint(checkpoint)
             .with_dataloader()
             .with_strategy()
+            .with_parallel_fn()
             .build())
 
     def _call_callbacks(self, method_name: str, context: TrainContext):
@@ -44,8 +46,19 @@
             method = getattr(callback, method_name, None)
             if method:
                 method(context)
+
+    def train(self, checkpoint: Optional[Checkpoint] = None):
+        config = self.train_config
+        spawn_parallel_fn(
+            self._train_impl,
+            backend=config.backend,
+            world_size=config.nprocs,
+            master_addr=config.master_addr,
+            master_port=config.master_port,
+            checkpoint=checkpoint
+        )
 
-    def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
+    def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> None:
         context = self._build_context(checkpoint)
 
         self._call_callbacks('on_train_begin', context)
@@ -84,5 +97,4 @@
             self._call_callbacks('on_error', context)
             raise
         finally:
-            self._call_callbacks('on_train_end', context)
-            return context.checkpoint
\ No newline at end of file
+            self._call_callbacks('on_train_end', context)
\ No newline at end of file
diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py
index 92f5d94..a7db7f0 100644
--- a/tests/test_early_stopping.py
+++ b/tests/test_early_stopping.py
@@ -1,3 +1,4 @@
+import os
 import torch
 import numpy as np
 from khaosz.config import *
@@ -31,10 +32,14 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     checkpoint = None
     try:
         checkpoint = trainer.train()
-        assert checkpoint.iteration == 2
     except Exception:
         # Handle any exceptions
         pass
 
-    checkpoint = trainer.train(checkpoint)
+    load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
+    checkpoint = Checkpoint.load(load_dir)
+    trainer.train(checkpoint)
+
+    load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
+    checkpoint = Checkpoint.load(load_dir)
     assert checkpoint.iteration == 10
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
index e8cf86b..358b12f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -8,7 +8,6 @@ from torch.optim import AdamW
 from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
 from khaosz.trainer import Trainer, SchedulerFactory
 from khaosz.data import DatasetLoader
-from khaosz.parallel import get_current_device, spawn_parallel_fn
 
 
 def parse_args() -> argparse.Namespace:
@@ -96,8 +95,6 @@ def train(
     window_size = parameter.config.m_len
     model = parameter.model
 
-    current_device = get_current_device()
-    model = fsdp_wrap(model.to(device=current_device, dtype=torch.bfloat16))
 
     kwargs = {
         "dpo_beta": dpo_beta,
@@ -150,6 +147,7 @@
         pin_memory=pin_memory,
         nprocs=nprocs,
         extra_kwargs=kwargs,
+        parallel_fn=fsdp_wrap
     )
 
     trainer = Trainer(train_config)
@@ -158,10 +156,4 @@
 
 if __name__ == "__main__":
     args = parse_args()
-
-    spawn_parallel_fn(
-        train,
-        world_size=args.nprocs,
-        backend="nccl",
-        **vars(args)
-    )
\ No newline at end of file
+    train(**vars(args))
\ No newline at end of file