From c98b175cd5ac6752cef234cf5eee85de52888522 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 7 Dec 2025 21:23:05 +0800
Subject: [PATCH] refactor(trainer): optimize trainer structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/__init__.py               |   2 -
 khaosz/api.py                    |   5 +-
 khaosz/config/__init__.py        |   4 +-
 khaosz/config/param_config.py    | 147 +------------------------------
 khaosz/config/schedule_config.py |  12 ++-
 khaosz/config/train_config.py    |  49 +++++++----
 khaosz/trainer/checkpoint.py     |  65 ++++++++++++++
 khaosz/trainer/schedule.py       |   2 +-
 khaosz/trainer/strategy.py       |  43 +--------
 khaosz/trainer/train_callback.py | 124 +++++++++++++-------------
 khaosz/trainer/train_context.py  |  87 +++++++++---------
 khaosz/trainer/trainer.py        |  46 +++++-----
 tests/test_callbacks.py          |  32 +++----
 tests/test_early_stopping.py     |  23 +++--
 tests/test_module.py             |   7 --
 tests/test_train_config.py       |  23 +++--
 tests/test_train_strategy.py     |  19 ++--
 tools/train.py                   |  48 ++++------
 18 files changed, 314 insertions(+), 424 deletions(-)
 create mode 100644 khaosz/trainer/checkpoint.py

diff --git a/khaosz/__init__.py b/khaosz/__init__.py
index f8eb153..35312d0 100644
--- a/khaosz/__init__.py
+++ b/khaosz/__init__.py
@@ -4,7 +4,6 @@ __author__ = "ViperEkura"
 from khaosz.api import Khaosz
 from khaosz.config import (
     ModelConfig,
-    ParameterLoader,
     TrainConfig,
 )
 from khaosz.model.transformer import Transformer
@@ -42,7 +41,6 @@ __all__ = [
     "PriorityTextSplitter",
 
     "ModelConfig",
-    "ParameterLoader",
     "TrainConfig",
 
     "DatasetLoader",
diff --git a/khaosz/api.py b/khaosz/api.py
index 96fc988..5ce2349 100644
--- a/khaosz/api.py
+++ b/khaosz/api.py
@@ -9,12 +9,13 @@ from khaosz.inference.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
-from khaosz.config.param_config import ParameterLoader
+from khaosz.config.param_config import ModelParameter
 
 
 class Khaosz:
     def __init__(self, model_dir: str):
-        self.parameter = ParameterLoader.load(model_dir)
+        self.parameter = ModelParameter()
+        self.parameter.load(model_dir)
 
     def to(self, *args, **kwargs):
         self.parameter.to(*args, **kwargs)
diff --git a/khaosz/config/__init__.py b/khaosz/config/__init__.py
index baabd79..1caac27 100644
--- a/khaosz/config/__init__.py
+++ b/khaosz/config/__init__.py
@@ -1,5 +1,5 @@
 from khaosz.config.model_config import ModelConfig
-from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
+from khaosz.config.param_config import BaseModelIO, ModelParameter
 from khaosz.config.schedule_config import ScheduleConfig, CosineScheduleConfig, SGDRScheduleConfig
 from khaosz.config.train_config import TrainConfig
 
@@ -7,8 +7,6 @@ from khaosz.config.train_config import TrainConfig
 __all__ = [
     "BaseModelIO",
     "ModelParameter",
-    "Checkpoint",
-    "ParameterLoader",
 
     "ModelConfig",
     "TrainConfig",
diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py
index 136d8bd..47120e5 100644
--- a/khaosz/config/param_config.py
+++ b/khaosz/config/param_config.py
@@ -1,11 +1,8 @@
-import pickle as pkl
-import matplotlib.pyplot as plt
-import safetensors.torch as st
 import torch.nn as nn
-import torch.optim as optim
+import safetensors.torch as st
 
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Self, Union
+from typing import Optional, Self, Union
 from pathlib import Path
 
 from khaosz.data.tokenizer import BpeTokenizer
@@ -63,7 +60,7 @@ class BaseModelIO:
 
         return self
 
-    
def to(self, *args, **kwargs) -> Self: + def to(self, *args, **kwargs) -> "BaseModelIO": """Move model to device.""" if self.model is not None: self.model.to(*args, **kwargs) @@ -77,142 +74,6 @@ class ModelParameter(BaseModelIO): def save(self, save_dir: Union[str, Path]): self.save_components(save_dir) - def load(self, load_dir: Union[str, Path]) -> Self: + def load(self, load_dir: Union[str, Path]) -> "ModelParameter": return self.load_components(load_dir) - -@dataclass -class Checkpoint(BaseModelIO): - """Extended model parameters with training state.""" - - optimizer_state: Dict[str, Any] = field( - default=None, - metadata={"help": "Optimizer state."} - ) - scheduler_state: Dict[str, Any] = field( - default=None, - metadata={"help": "Sampler state."} - ) - loss_list: List[float] = field( - default_factory=list, - metadata={"help": "List of training losses."} - ) - epoch: int = field( - default=0, - metadata={"help": "Current epoch."} - ) - batch_iter: int = field( - default=0, - metadata={"help": "Current iteration."} - ) - - def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]: - dir_path = Path(directory) - return { - "loss_plot": dir_path / "loss_plot.png", - "training_state": dir_path / "training_state.pkl" - } - - def to_dict(self) -> Dict[str, Any]: - return { - "optimizer_state": self.optimizer_state, - "scheduler_state": self.scheduler_state, - "epoch": self.epoch, - "batch_iter": self.batch_iter, - "loss_list": self.loss_list, - } - - def from_dict(self, data: Dict[str, Any]) -> Self: - self.optimizer_state = data["optimizer_state"] - self.scheduler_state = data["scheduler_state"] - self.epoch = data["epoch"] - self.batch_iter = data["batch_iter"] - self.loss_list = data["loss_list"] - - def save_training_state(self, save_dir: Union[str, Path]): - paths = self._get_training_paths(save_dir) - - # Save loss plot - self._plot_loss(str(paths["loss_plot"])) - - # Save training state - with open(str(paths["training_state"]), "wb") as f: - pkl.dump(self.to_dict(), f) - - def load_training_state(self, load_dir: Union[str, Path]) -> Self: - paths = self._get_training_paths(load_dir) - - # Load training state - with open(str(paths["training_state"]), "rb") as f: - train_state = pkl.load(f) - - self.from_dict(train_state) - - return self - - def _plot_loss(self, save_path: str): - """Plot and save loss curve.""" - if not self.loss_list: - return - - batch_iter = len(self.loss_list) - - plt.figure(figsize=(10, 6)) - plt.plot(self.loss_list) - plt.title(f"Training Loss - Iteration {batch_iter}") - plt.xlabel("Batch") - plt.ylabel("Loss") - plt.grid(True) - plt.savefig(save_path, dpi=30, bbox_inches="tight") - plt.close() - - def save(self, save_dir: Union[str, Path]): - """Save complete checkpoint.""" - self.save_components(save_dir) - self.save_training_state(save_dir) - - def load(self, load_dir: Union[str, Path]) -> Self: - """Load complete checkpoint.""" - self.load_components(load_dir) - self.load_training_state(load_dir) - return self - - -class ParameterLoader: - """Factory class for loading model parameters or checkpoints.""" - - @staticmethod - def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]: - """Load either ModelParameter or Checkpoint based on directory contents.""" - load_dir = Path(load_dir) - - # Check for training-specific files - loss_file = load_dir / "loss.pkl" - has_training_data = loss_file.exists() - - # Create appropriate instance - if has_training_data: - checkpoint = Checkpoint() - checkpoint.load(str(load_dir)) - 
return checkpoint - else: - params = ModelParameter() - params.load(str(load_dir)) - return params - - @staticmethod - def create_checkpoint( - model: nn.Module, - tokenizer: BpeTokenizer, - config: ModelConfig, - loss_list: Optional[list[float]] = None, - optimizer: Optional[optim.Optimizer] = None, - ) -> Checkpoint: - """Convenience method to create a training checkpoint.""" - return Checkpoint( - model=model, - tokenizer=tokenizer, - config=config, - loss_list=loss_list or [], - optimizer_state=optimizer - ) \ No newline at end of file diff --git a/khaosz/config/schedule_config.py b/khaosz/config/schedule_config.py index 3724ab8..10c9d39 100644 --- a/khaosz/config/schedule_config.py +++ b/khaosz/config/schedule_config.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Dict +from typing import Any, Dict from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -39,7 +39,10 @@ class CosineScheduleConfig(ScheduleConfig): default=None, metadata={"help": "Total training steps for cosine schedule."} ) - schedule_type: Literal["cosine"] = "cosine" + + def __post_init__(self) -> None: + self.schedule_type = "cosine" + self.validate() def get_kwargs(self) -> Dict[str, Any]: if self.total_steps is None: @@ -68,7 +71,10 @@ class SGDRScheduleConfig(ScheduleConfig): default=2, metadata={"help": "Multiplier for cycle length growth."} ) - schedule_type: Literal["sgdr"] = "sgdr" + + def __post_init__(self) -> None: + self.schedule_type = "sgdr" + self.validate() def get_kwargs(self) -> Dict[str, Any]: return { diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py index 49b0254..f089ed4 100644 --- a/khaosz/config/train_config.py +++ b/khaosz/config/train_config.py @@ -1,15 +1,20 @@ -from dataclasses import dataclass, field -from typing import Optional, TYPE_CHECKING +from torch import nn from torch.utils.data import Dataset from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler -if TYPE_CHECKING: - from khaosz.trainer.strategy import BaseStrategy +from dataclasses import dataclass, field +from typing import Optional @dataclass class TrainConfig: - strategy: "BaseStrategy" = field( + # basic setting + model: nn.Module = field( + default=None, + metadata={"help": "Model for training."} + ) + strategy: str = field( default=None, metadata={"help": "Training strategy."} ) @@ -21,9 +26,9 @@ class TrainConfig: default=None, metadata={"help": "Optimizer for training."} ) - checkpoint_dir: str = field( - default="./checkpoint", - metadata={"help": "Checkpoint directory."} + scheduler: LRScheduler = field( + default=None, + metadata={"help": "Scheduler for training."} ) n_epoch: int = field( default=1, @@ -33,6 +38,16 @@ class TrainConfig: default=4, metadata={"help": "Batch size for training."} ) + accumulation_steps: int = field( + default=1, + metadata={"help": "Number of iterations between steps."} + ) + max_grad_norm: float = field( + default=1.0, + metadata={"help": "Maximum gradient norm."} + ) + + # checkpoint setting start_epoch: int = field( default=0, metadata={"help": "Start epoch for training."} @@ -41,18 +56,14 @@ class TrainConfig: default=0, metadata={"help": "Start batch iteration for training."} ) + checkpoint_dir: str = field( + default="./checkpoint", + metadata={"help": "Checkpoint directory."} + ) checkpoint_interval: int = field( default=5000, metadata={"help": "Number of iterations between checkpoints."} ) - accumulation_steps: int = field( - default=1, - metadata={"help": "Number of iterations between steps."} - ) 
-    max_grad_norm: float = field(
-        default=1.0,
-        metadata={"help": "Maximum gradient norm."}
-    )
 
     # dataloader setting
     random_seed: int = field(
@@ -76,4 +87,10 @@
     nprocs: int = field(
         default=1,
         metadata={"help": "Number of processes for distributed training."}
+    )
+    
+    # others
+    kwargs: dict = field(
+        default_factory=dict,
+        metadata={"help": "Extra keyword arguments forwarded to the training strategy."}
     )
\ No newline at end of file
diff --git a/khaosz/trainer/checkpoint.py b/khaosz/trainer/checkpoint.py
new file mode 100644
index 0000000..6cd107a
--- /dev/null
+++ b/khaosz/trainer/checkpoint.py
@@ -0,0 +1,65 @@
+import os
+import pickle as pkl
+import matplotlib.pyplot as plt
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from typing import Any, Dict, Optional
+
+
+class Checkpoint:
+    def __init__(
+        self, 
+        optimizer_state: Dict[str, Any], 
+        scheduler_state: Dict[str, Any],
+        epoch: int = 0, 
+        iteration: int = 0,
+        metrics: Optional[Dict[str, list]] = None,
+    ):
+        self.optimizer_state = optimizer_state
+        self.scheduler_state = scheduler_state
+        self.epoch, self.iteration = epoch, iteration
+        self.metrics = metrics
+    
+    def save(self, save_dir: str, save_metric_plot=True) -> None:
+        os.makedirs(save_dir, exist_ok=True)
+        
+        train_state = {
+            "epoch": self.epoch,
+            "iteration": self.iteration,
+            "metrics": self.metrics,
+            "optimizer_state": self.optimizer_state,
+            "scheduler_state": self.scheduler_state,
+        }
+        
+        with open(os.path.join(save_dir, "train_state.pkl"), "wb") as f:
+            pkl.dump(train_state, f)
+        
+        if save_metric_plot and self.metrics:
+            self._plot_metrics(save_dir)
+
+    def load(self, save_dir: str) -> "Checkpoint":
+        if not os.path.exists(save_dir):
+            raise FileNotFoundError(f"Checkpoint directory {save_dir} does not exist.")
+        
+        with open(os.path.join(save_dir, "train_state.pkl"), "rb") as f:
+            train_state = pkl.load(f)
+            self.epoch = train_state["epoch"]
+            self.iteration = train_state["iteration"]
+            self.metrics = train_state["metrics"]
+            self.optimizer_state = train_state["optimizer_state"]
+            self.scheduler_state = train_state["scheduler_state"]
+        
+        return self
+        
+    def _plot_metrics(self, save_dir: str):
+        for metric_name, metric_value in self.metrics.items():
+            plt.figure(figsize=(10, 6))
+            plt.plot(metric_value, label=metric_name)
+            plt.xlabel('Step')
+            plt.ylabel('Value')
+            plt.legend()
+            plt.grid(True, alpha=0.3)
+            
+            plt.savefig(os.path.join(save_dir, f'{metric_name}.png'), dpi=150, bbox_inches='tight')
+            plt.close()
\ No newline at end of file
diff --git a/khaosz/trainer/schedule.py b/khaosz/trainer/schedule.py
index 2309b04..1b9c094 100644
--- a/khaosz/trainer/schedule.py
+++ b/khaosz/trainer/schedule.py
@@ -151,7 +151,7 @@ class SchedulerFactory:
     """
 
     @staticmethod
-    def load_scheduler(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
+    def load(optimizer, scedule_config: ScheduleConfig) -> BaseScheduler:
         kwargs = scedule_config.get_kwargs()
         schedule_type = kwargs.pop("schedule_type")
 
diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py
index 9d79d84..21e57ac 100644
--- a/khaosz/trainer/strategy.py
+++ b/khaosz/trainer/strategy.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from torch import Tensor
-from typing import Any, Tuple, Callable, Dict, Union
+from typing import Any, Callable, Dict, Union
 from abc import ABC, abstractmethod
 
 
@@ -41,7 +41,7 @@ class BaseStrategy(ABC):
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
         raise NotImplementedError
 
-    def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
+    def __call__(self, batch: Dict[str, 
Tensor]) -> Tensor: return self.compute_loss(batch) @@ -94,7 +94,7 @@ class DpoStrategy(BaseStrategy): self.pad_token_id = pad_token_id self.beta = beta - def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor: + def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: batch = move_to_device(batch, self.device) good_ids, bad_ids = batch["chosen"], batch["rejected"] good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"] @@ -115,41 +115,6 @@ class DpoStrategy(BaseStrategy): return dpo_loss -class PpoStrategy(BaseStrategy): - def __init__(self, model, pad_token_id, epsilon): - super().__init__(model) - ref_model = copy.deepcopy(self.model) - ref_model.requires_grad_(False) - ref_model.eval() - - self.ref_model = ref_model - self.pad_token_id = pad_token_id - self.epsilon = epsilon - - def ppo_clip_loss_masked( - self, - log_probs: Tensor, - old_log_probs: Tensor, - advantages: Tensor, - values: Tensor, - returns: Tensor, - mask: Tensor, - clip_eps: float=0.2, - ): - ratio = torch.exp(log_probs - old_log_probs) - surr1 = ratio * advantages - surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages - policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean() - - value_loss = F.mse_loss(values.masked_select(mask), - returns.masked_select(mask)) - - entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean() - entropy_loss = -entropy - return policy_loss, value_loss, entropy_loss - - - class StrategyFactory: def load(model, train_type, device, **kwargs): @@ -157,7 +122,7 @@ class StrategyFactory: "seq": lambda: SeqStrategy(model, device), "sft": lambda: SftStrategy(model, device), "dpo": lambda: DpoStrategy( - model, + model, device, kwargs.get("pad_token_id"), kwargs.get("dpo_beta") diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py index 4805e37..90ef9f6 100644 --- a/khaosz/trainer/train_callback.py +++ b/khaosz/trainer/train_callback.py @@ -5,10 +5,9 @@ import time from pathlib import Path from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import LambdaLR +from torch.optim.lr_scheduler import LRScheduler from typing import List, Optional, Protocol, TYPE_CHECKING -from khaosz.config import ScheduleConfig from khaosz.trainer.metric_util import ( grad_max, grad_min, @@ -17,9 +16,9 @@ from khaosz.trainer.metric_util import ( grad_std, grad_nan_num ) +from khaosz.trainer.checkpoint import Checkpoint if TYPE_CHECKING: - from khaosz.trainer.trainer import Trainer from khaosz.trainer.train_context import TrainContext @@ -28,31 +27,31 @@ class TrainCallback(Protocol): Callback interface for trainer. """ - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_train_begin(self, context: 'TrainContext'): """ Called at the beginning of training. """ - def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): + def on_train_end(self, context: 'TrainContext'): """ Called at the end of training. """ - def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_begin(self, context: 'TrainContext'): """ Called at the beginning of each epoch. """ - def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_end(self, context: 'TrainContext'): """ Called at the end of each epoch. """ - def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_begin(self, context: 'TrainContext'): """ Called at the beginning of each step. 
""" - def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_end(self, context: 'TrainContext'): """ Called at the end of each step.""" - def on_batch_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_batch_begin(self, context: 'TrainContext'): """ Called at the beginning of each batch. """ - def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): + def on_batch_end(self, context: 'TrainContext'): """ Called at the end of each batch. """ - def on_error(self, trainer: 'Trainer', context: 'TrainContext'): + def on_error(self, context: 'TrainContext'): """ Called when an error occurs during training. """ @@ -63,29 +62,27 @@ class GradientClippingCallback(TrainCallback): def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - def on_step_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_begin(self, context: 'TrainContext'): _ = context - clip_grad_norm_(trainer.parameter.model.parameters(), self.max_grad_norm) + clip_grad_norm_(context.model.parameters(), self.max_grad_norm) class SchedulerCallback(TrainCallback): """ Scheduler callback for trainer. """ - def __init__(self, schedule_config: ScheduleConfig): - self.schedule_config = schedule_config - self.scheduler: Optional[LambdaLR] = None + def __init__(self, scheduler: LRScheduler): + self.scheduler: LRScheduler = scheduler - def on_train_begin(self, trainer: 'Trainer', context: 'TrainContext'): - - for group in trainer.train_config.optimizer.param_groups: + def on_train_begin(self, context: 'TrainContext'): + for group in context.optimizer.param_groups: if "initial_lr" not in group: group["initial_lr"] = group["lr"] self.scheduler = context.scheduler - def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): - _ = trainer, context + def on_batch_end(self, context: 'TrainContext'): + _ = context if self.scheduler: self.scheduler.step() @@ -94,54 +91,59 @@ class CheckpointCallback(TrainCallback): """ Checkpoint callback for trainer. 
""" - def __init__(self, checkpoint_interval: int): - self.checkpoint_interval = checkpoint_interval + def __init__(self, interval: int, save_dir: str): + self.interval = interval + self.save_dir = save_dir + self.checkpoint = None self.last_ckpt_iter = 0 - def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'): - save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}") - context.checkpoint.optimizer_state = context.optimizer.state_dict() - context.checkpoint.scheduler_state = context.scheduler.state_dict() - context.checkpoint.epoch = context.epoch - context.checkpoint.batch_iter = context.batch_iter - context.checkpoint.save(save_path) - self.last_ckpt_iter = context.batch_iter + def _save_checkpoint(self, context: 'TrainContext'): + save_path = os.path.join(self.save_dir, f"epoch_{context.epoch}iter_{context.iteration}") + self.checkpoint = Checkpoint( + context.optimizer.state_dict(), + context.scheduler.state_dict(), + context.epoch, + context.iteration + ) + self.checkpoint.save(save_path) + self.last_ckpt_iter = context.iteration - def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): - context.checkpoint.loss_list.append(context.loss) - - if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval: - self._save_checkpoint(trainer, context) - - def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): - if context.batch_iter != self.last_ckpt_iter: - self._save_checkpoint(trainer, context) + def on_batch_end(self, context: 'TrainContext'): + if context.iteration - self.last_ckpt_iter >= self.interval: + self._save_checkpoint(context) + + def on_train_end(self, context: 'TrainContext'): + if context.iteration != self.last_ckpt_iter: + self._save_checkpoint(context) + + def on_error(self, context: 'TrainContext'): + self._save_checkpoint(context) class ProgressBarCallback(TrainCallback): """ Progress bar callback for trainer. """ - def __init__(self): + def __init__(self, num_epoch: int): + self.num_epoch = num_epoch self.progress_bar: tqdm = None - def on_epoch_begin(self, trainer: 'Trainer', context: 'TrainContext'): + def on_epoch_begin(self, context: 'TrainContext'): self.progress_bar = tqdm( context.dataloader, - desc=f"Epoch {context.epoch+1}/{trainer.train_config.n_epoch}", + desc=f"Epoch {context.epoch+1}/{self.num_epoch}", dynamic_ncols=True ) - def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): - _ = trainer + def on_batch_end(self, context: 'TrainContext'): self.progress_bar.set_postfix({ "loss": f"{context.loss:.4f}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}" }) self.progress_bar.update(1) - def on_epoch_end(self, trainer: 'Trainer', context: 'TrainContext'): - _ = trainer, context + def on_epoch_end(self, context: 'TrainContext'): + _ = context if self.progress_bar: self.progress_bar.close() @@ -177,13 +179,13 @@ class StepMonitorCallback(TrainCallback): self.log_dir.mkdir(parents=True, exist_ok=True) - def _handle_info(self, trainer: 'Trainer', context: 'TrainContext'): + def _handle_info(self, context: 'TrainContext'): """ Logs training information to console and file. 
""" log_data = { "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "epoch": context.epoch, - "iter": context.batch_iter, + "iter": context.iteration, "metrics": self.metrics, } @@ -193,34 +195,34 @@ class StepMonitorCallback(TrainCallback): elif metric == 'lr': log_data[metric] = context.optimizer.param_groups[-1]['lr'] elif metric == 'grad_norm': - log_data[metric] = grad_norm(trainer.parameter.model) + log_data[metric] = grad_norm(context.model) elif metric == 'grad_std': - log_data[metric] = grad_std(trainer.parameter.model) + log_data[metric] = grad_std(context.model) elif metric == 'grad_max': - log_data[metric] = grad_max(trainer.parameter.model) + log_data[metric] = grad_max(context.model) elif metric == 'grad_min': - log_data[metric] = grad_min(trainer.parameter.model) + log_data[metric] = grad_min(context.model) elif metric == 'grad_mean': - log_data[metric] = grad_mean(trainer.parameter.model) + log_data[metric] = grad_mean(context.model) elif metric == 'grad_nan_num': - log_data[metric] = grad_nan_num(trainer.parameter.model) + log_data[metric] = grad_nan_num(context.model) else: raise ValueError(f"Invalid metric: {metric}") return log_data - def _handle_log(self, trainer: 'Trainer', context: 'TrainContext'): + def _handle_log(self, context: 'TrainContext'): """ Logs training information to console and file. """ - log_data = self._handle_info(trainer, context) + log_data = self._handle_info(context) try: - log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json" + log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.iteration}.json" with open(log_file, 'a') as f: json.dump(log_data, f, indent=4) except Exception: raise - def on_step_end(self, trainer: 'Trainer', context: 'TrainContext'): + def on_step_end(self, context: 'TrainContext'): if self.step_num % self.log_interval == 0: - self._handle_log(trainer, context) + self._handle_log(context) self.step_num += 1 \ No newline at end of file diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index db36b5b..aba4a54 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -1,61 +1,60 @@ -from dataclasses import dataclass, field, fields -from typing import Optional, Self, TYPE_CHECKING +import torch.nn as nn from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader -from khaosz.config import Checkpoint -from khaosz.data import ResumableDistributedSampler -from khaosz.trainer.schedule import BaseScheduler, SchedulerFactory -from khaosz.parallel.utils import get_world_size, get_rank -if TYPE_CHECKING: - from khaosz.trainer.trainer import Trainer +from khaosz.data import ResumableDistributedSampler +from khaosz.trainer.checkpoint import Checkpoint +from khaosz.trainer.strategy import StrategyFactory, BaseStrategy +from khaosz.config.train_config import TrainConfig +from khaosz.parallel.utils import get_current_device, get_world_size, get_rank + +from dataclasses import dataclass, field +from typing import Optional, Self @dataclass class TrainContext: + model: nn.Module = field(default=None) + strategy: BaseStrategy = field(default=None) dataloader: DataLoader = field(default=None) optimizer: Optimizer = field(default=None) - scheduler: BaseScheduler = field(default=None) + scheduler: LRScheduler = field(default=None) checkpoint: Checkpoint = field(default=None) + epoch: int = field(default=0) - batch_iter: int = field(default=0) + iteration: int = field(default=0) loss: 
float = field(default=0.0) wolrd_size: int = field(default=1) rank: int = field(default=0) - - def asdict(self) -> dict: - return {field.name: getattr(self, field.name) - for field in fields(self)} class TrainContextBuilder: - def __init__(self, trainer: 'Trainer'): - self.trainer = trainer + def __init__(self, config: TrainConfig): + self.config = config self._context: TrainContext = None def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: self._context = TrainContext() if checkpoint is None: checkpoint = Checkpoint( - model=self.trainer.parameter.model, - tokenizer=self.trainer.parameter.tokenizer, - config=self.trainer.parameter.config, + optimizer_state=self.config.optimizer.state_dict(), + scheduler_state=self.config.scheduler.state_dict(), ) else: # resume from the assigned checkpoint or assigned iteration - self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch) - self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch) + self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) + self._context.iteration = max(checkpoint.iteration, self.config.start_batch) self._context.checkpoint = checkpoint return self def with_optimizer(self) -> Self: + optimizer = self.config.optimizer if self._context is None: raise RuntimeError("Must call with_checkpoint() before with_optimizer()") - optimizer = self.trainer.train_config.optimizer - if self._context.checkpoint and self._context.checkpoint.optimizer_state: optimizer.load_state_dict(self._context.checkpoint.optimizer_state) @@ -67,13 +66,7 @@ class TrainContextBuilder: return self def with_scheduler(self) -> Self: - if not hasattr(self._context, 'optimizer') or self._context.optimizer is None: - raise RuntimeError("Must call with_optimizer() before with_scheduler()") - - optimizer = self.trainer.train_config.optimizer - schedule_config = self.trainer.schedule_config - scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config) - + scheduler = self.config.scheduler if self._context.checkpoint and self._context.checkpoint.scheduler_state: scheduler.load_state_dict(self._context.checkpoint.scheduler_state) @@ -85,29 +78,41 @@ class TrainContextBuilder: return self def with_dataloader(self) -> Self: - # fix: change batch level batch_iter to sample level offset - sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size + # fix: change batch level iteration to sample level offset + config = self.config + sampler_offset = self._context.iteration * config.batch_size resumeable_sampler = ResumableDistributedSampler( - data_source=self.trainer.train_config.dataset, + data_source=config.dataset, start_epoch=self._context.epoch, start_iter=sampler_offset, - seed=self.trainer.train_config.random_seed + seed=config.random_seed ) dataloader = DataLoader( - self.trainer.train_config.dataset, - batch_size=self.trainer.train_config.batch_size, + config.dataset, + batch_size=config.batch_size, sampler=resumeable_sampler, - num_workers=self.trainer.train_config.num_workers, - pin_memory=self.trainer.train_config.pin_memory, - prefetch_factor=self.trainer.train_config.prefetch_factor + num_workers=config.num_workers, + pin_memory=config.pin_memory, + prefetch_factor=config.prefetch_factor ) self._context.dataloader = dataloader return self + def with_strategy(self) -> Self: + device = get_current_device() + self._context.strategy = StrategyFactory.load( + model=self.config.model, + train_type=self.config.strategy, + device=device, + 
**self.config.kwargs + ) + return self + def build(self) -> TrainContext: - - if self.trainer.train_config.nprocs > 1: + self._context.model = self.config.model + + if self.config.nprocs > 1: self._context.wolrd_size = get_world_size() self._context.rank = get_rank() diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index be28e8c..c6a61c9 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,11 +1,6 @@ import logging from typing import Optional, List -from khaosz.config import ( - ModelParameter, - Checkpoint, - ScheduleConfig, - TrainConfig -) +from khaosz.config import TrainConfig from khaosz.trainer.train_callback import ( TrainCallback, ProgressBarCallback, @@ -13,7 +8,7 @@ from khaosz.trainer.train_callback import ( GradientClippingCallback, SchedulerCallback ) -from khaosz.trainer.train_context import TrainContext, TrainContextBuilder +from khaosz.trainer.train_context import TrainContext, TrainContextBuilder, Checkpoint logger = logging.getLogger(__name__) @@ -21,66 +16,65 @@ logger = logging.getLogger(__name__) class Trainer: def __init__( self, - parameter: ModelParameter, train_config: TrainConfig, - schedule_config: ScheduleConfig, callbacks: Optional[List[TrainCallback]] = None ): - self.parameter = parameter self.train_config = train_config - self.schedule_config = schedule_config self.callbacks = callbacks or self._get_default_callbacks() def _get_default_callbacks(self) -> List[TrainCallback]: + train_config = self.train_config return [ - ProgressBarCallback(), - CheckpointCallback(self.train_config.checkpoint_interval), - GradientClippingCallback(self.train_config.max_grad_norm), - SchedulerCallback(self.schedule_config), + ProgressBarCallback(train_config.n_epoch), + CheckpointCallback(train_config.checkpoint_interval, train_config.checkpoint_dir), + GradientClippingCallback(train_config.max_grad_norm), + SchedulerCallback(train_config.scheduler), ] - + def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: - return (TrainContextBuilder(self) + return (TrainContextBuilder(self.train_config) .with_checkpoint(checkpoint) .with_optimizer() .with_scheduler() .with_dataloader() + .with_strategy() .build()) def _call_callbacks(self, method_name: str, context: TrainContext): for callback in self.callbacks: method = getattr(callback, method_name, None) if method: - method(self, context) + method(context) def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint: context = self._build_context(checkpoint) self._call_callbacks('on_train_begin', context) try: - self.parameter.model.train() + context.model.train() # 1.epoch for epoch in range(context.epoch, self.train_config.n_epoch): context.epoch = epoch self._call_callbacks('on_epoch_begin', context) for batch in context.dataloader: - if context.batch_iter % self.train_config.accumulation_steps == 0: + if context.iteration % self.train_config.accumulation_steps == 0: # 2. step self._call_callbacks('on_step_begin', context) - self.train_config.optimizer.step() - self.train_config.optimizer.zero_grad() + context.optimizer.step() + context.optimizer.zero_grad() self._call_callbacks('on_step_end', context) # 3. 
batch self._call_callbacks('on_batch_begin', context) - loss = self.train_config.strategy(batch) + loss = context.strategy(batch) context.loss = loss.item() - context.batch_iter += 1 + context.iteration += 1 # to make the loss normalized by accumulation steps - normalized_loss = loss / self.train_config.accumulation_steps - normalized_loss.backward() + stand_batch = self.train_config.accumulation_steps * self.train_config.nprocs + stand_loss = loss / stand_batch + stand_loss.backward() self._call_callbacks('on_batch_end', context) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index c93ccd5..d4c1581 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -5,10 +5,20 @@ from khaosz.trainer import * def test_callback_integration(base_test_env, random_dataset): """Test that all callbacks are properly integrated""" + schedule_config = CosineScheduleConfig( + warmup_steps=10, + total_steps=20 + ) + optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) + scheduler = SchedulerFactory.load(optimizer, schedule_config) + train_config = TrainConfig( + model=base_test_env["model"], + strategy='seq', dataset=random_dataset, optimizer=optimizer, + scheduler=scheduler, checkpoint_dir=base_test_env["test_dir"], n_epoch=1, batch_size=2, @@ -18,36 +28,26 @@ def test_callback_integration(base_test_env, random_dataset): random_seed=42 ) - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) + # Create custom callbacks to track calls callback_calls = [] class TrackingCallback(TrainCallback): - def on_train_begin(self, trainer, context): + def on_train_begin(self, context): callback_calls.append('on_train_begin') - def on_batch_end(self, trainer, context): + def on_batch_end(self, context): callback_calls.append('on_batch_end') - def on_epoch_end(self, trainer, context): + def on_epoch_end(self, context): callback_calls.append('on_epoch_end') - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) - model_parameter = ModelParameter( - base_test_env["model"], - base_test_env["tokenizer"], - base_test_env["transformer_config"] - ) + trainer = Trainer( - model_parameter, train_config, - schedule_config, - callbacks=[TrackingCallback(), ProgressBarCallback()] + callbacks=[TrackingCallback()] ) trainer.train() diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index 782e74a..92f5d94 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -7,35 +7,34 @@ from khaosz.trainer import * def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior""" + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) + scheduler = SchedulerFactory.load(optimizer, schedule_config) + train_config = TrainConfig( + strategy="seq", + scheduler=scheduler, + model=base_test_env["model"], dataset=early_stopping_dataset, optimizer=optimizer, checkpoint_dir=base_test_env["test_dir"], n_epoch=2, batch_size=2, - checkpoint_interval=2, + checkpoint_interval=1, accumulation_steps=2, random_seed=np.random.randint(1e4), ) - - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) - model_parameter = ModelParameter( - base_test_env["model"], - base_test_env["tokenizer"], - base_test_env["transformer_config"] - ) - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - trainer = 
Trainer(model_parameter, train_config, schedule_config) + + trainer = Trainer(train_config) # Should handle early stopping gracefully checkpoint = None try: checkpoint = trainer.train() - assert len(checkpoint.loss_list) == 2 + assert checkpoint.iteration == 2 except Exception: # Handle any exceptions pass checkpoint = trainer.train(checkpoint) - assert len(checkpoint.loss_list) == 10 \ No newline at end of file + assert checkpoint.iteration == 10 \ No newline at end of file diff --git a/tests/test_module.py b/tests/test_module.py index 710ce00..6ecb9ab 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -51,13 +51,6 @@ def test_env(request: pytest.FixtureRequest): shutil.rmtree(test_dir) -# parameter loader -def test_parameter_loader(test_env): - loaded_param = ParameterLoader.load(test_env["test_dir"]) - assert loaded_param.model is not None - assert loaded_param.tokenizer is not None - assert loaded_param.config == test_env["transformer_config"] - def test_model_parameter(test_env): save_dir = os.path.join(test_env["test_dir"], "save") model_param = ModelParameter(test_env["model"],test_env["tokenizer"] , test_env["transformer_config"]) diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 3cd2112..38d1778 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -31,10 +31,18 @@ def test_gradient_accumulation(base_test_env, random_dataset): accumulation_steps_list = [1, 2, 4] for accumulation_steps in accumulation_steps_list: + schedule_config = CosineScheduleConfig( + warmup_steps=10, + total_steps=20 + ) optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) + scheduler = SchedulerFactory.load(optimizer, schedule_config) train_config = TrainConfig( - dataset=random_dataset, + strategy="seq", + model=base_test_env["model"], optimizer=optimizer, + scheduler=scheduler, + dataset=random_dataset, checkpoint_dir=base_test_env["test_dir"], n_epoch=1, batch_size=2, @@ -44,18 +52,7 @@ def test_gradient_accumulation(base_test_env, random_dataset): random_seed=42 ) - schedule_config = CosineScheduleConfig( - warmup_steps=10, - total_steps=20 - ) - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) - model_parameter = ModelParameter( - base_test_env["model"], - base_test_env["tokenizer"], - base_test_env["transformer_config"] - ) - - trainer = Trainer(model_parameter, train_config, schedule_config) + trainer = Trainer(train_config) trainer.train() assert train_config.accumulation_steps == accumulation_steps diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py index d97c254..7f96294 100644 --- a/tests/test_train_strategy.py +++ b/tests/test_train_strategy.py @@ -35,7 +35,7 @@ def test_schedule_factory_random_configs(): config.validate() # Create scheduler using factory - scheduler = SchedulerFactory.load_scheduler(optimizer, config) + scheduler = SchedulerFactory.load(optimizer, config) # Verify scheduler type if isinstance(config, CosineScheduleConfig): @@ -83,7 +83,7 @@ def test_schedule_factory_edge_cases(): for config in edge_cases: config.validate() - scheduler = SchedulerFactory.load_scheduler(optimizer, config) + scheduler = SchedulerFactory.load(optimizer, config) assert scheduler is not None # Test multiple steps @@ -97,16 +97,17 @@ def test_schedule_factory_invalid_configs(): # Test invalid configurations that should raise errors invalid_configs = [ # Negative warmup steps - CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1), + 
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1}, # Total steps less than warmup steps - CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1), + {"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1}, # Invalid min_rate - CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1), - CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1), + {"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1}, + {"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1}, ] - for config in invalid_configs: + for kwargs in invalid_configs: with pytest.raises(ValueError): + config = CosineScheduleConfig(**kwargs) config.validate() @@ -117,7 +118,7 @@ def test_schedule_factory_state_persistence(): optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1) - scheduler = SchedulerFactory.load_scheduler(optimizer, config) + scheduler = SchedulerFactory.load(optimizer, config) # Take a few steps for _ in range(5): @@ -127,7 +128,7 @@ def test_schedule_factory_state_persistence(): state_dict = scheduler.state_dict() # Create new scheduler and load state - new_scheduler = SchedulerFactory.load_scheduler(optimizer, config) + new_scheduler = SchedulerFactory.load(optimizer, config) new_scheduler.load_state_dict(state_dict) # Verify states match diff --git a/tools/train.py b/tools/train.py index 8704092..697b7e9 100644 --- a/tools/train.py +++ b/tools/train.py @@ -3,8 +3,8 @@ import argparse import torch from torch.optim import AdamW -from khaosz.config import ParameterLoader, Checkpoint, TrainConfig, CosineScheduleConfig -from khaosz.trainer import Trainer, StrategyFactory +from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig +from khaosz.trainer import Trainer, SchedulerFactory from khaosz.data import DatasetLoader @@ -36,7 +36,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.") - parser.add_argument("--resume_from_checkpoint", action="store_true", help="Train from checkpoint or not.") args = parser.parse_args() @@ -66,16 +65,12 @@ def train( pin_memory: bool, window_size: int, stride: int, - resume_from_checkpoint: bool ): assert train_type in ["seq", "sft", "dpo"] assert os.path.exists(param_path) - parameter = ParameterLoader.load(param_path) - checkpoint = None - - if isinstance(parameter, Checkpoint) and resume_from_checkpoint: - checkpoint = parameter + parameter = ModelParameter() + parameter.load(param_path) if window_size is None: window_size = parameter.config.m_len @@ -90,13 +85,6 @@ def train( "eos_token_id": parameter.tokenizer.eos_id, "pad_token_id": parameter.tokenizer.pad_id, } - - strategy = StrategyFactory.load( - model, - train_type, - device, - **kwargs - ) dataset = DatasetLoader.load( train_type=train_type, @@ -111,16 +99,25 @@ def train( {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr} ] - optim = AdamW( + optimizer = AdamW( param_groups, betas=(adamw_beta1, adamw_beta2), weight_decay=adamw_weight_decay ) + schedule_config = CosineScheduleConfig( + warmup_steps=warmup_steps, + total_steps=len(dataset) * n_epoch // batch_size, + ) + + scheduler = SchedulerFactory.load(optimizer, schedule_config) + 
train_config = TrainConfig( - strategy=strategy, + model=model, + strategy=train_type, dataset=dataset, - optimizer=optim, + optimizer=optimizer, + scheduler=scheduler, checkpoint_dir=checkpoint_dir, n_epoch=n_epoch, batch_size=batch_size, @@ -134,17 +131,8 @@ def train( pin_memory=pin_memory ) - schedule_config = CosineScheduleConfig( - warmup_steps=warmup_steps, - total_steps=len(dataset) * n_epoch // batch_size, - ) - - trainer = Trainer( - parameter=parameter, - train_config=train_config, - schedule_config=schedule_config, - ) - trainer.train(checkpoint) + trainer = Trainer(train_config) + trainer.train() if __name__ == "__main__":
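
Usage note (added in review, not part of the diff): after this refactor the Trainer
is driven entirely by TrainConfig -- the model, optimizer, scheduler, and strategy
name all live there, and ParameterLoader/ScheduleConfig are no longer passed to the
Trainer. A minimal sketch of the new call pattern, assuming an existing model
directory and dataset; the path and hyperparameter values below are illustrative
placeholders, not values taken from this patch:

    from torch.optim import AdamW

    from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
    from khaosz.trainer import Trainer, SchedulerFactory

    # Load model/tokenizer/config from a saved directory (path is hypothetical).
    parameter = ModelParameter()
    parameter.load("path/to/model_dir")

    dataset = ...  # any Dataset whose batches match the chosen strategy

    optimizer = AdamW(parameter.model.parameters(), lr=1e-4)
    scheduler = SchedulerFactory.load(
        optimizer,
        CosineScheduleConfig(warmup_steps=100, total_steps=1000),
    )

    train_config = TrainConfig(
        model=parameter.model,
        strategy="seq",            # "seq" | "sft" | "dpo"
        dataset=dataset,
        optimizer=optimizer,
        scheduler=scheduler,
        checkpoint_dir="./checkpoint",
        n_epoch=1,
        batch_size=4,
        kwargs={},                 # strategy kwargs, e.g. pad_token_id / dpo_beta for "dpo"
    )

    trainer = Trainer(train_config)
    checkpoint = trainer.train()   # annotated to return the final Checkpoint
    trainer.train(checkpoint)      # optionally pass a Checkpoint back in to resume

Note also that callbacks now receive only the TrainContext (no Trainer argument), so
custom callbacks written against the old two-argument signature must change their
methods to e.g. `def on_batch_end(self, context)`, as tests/test_callbacks.py shows.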