From f25a2492917f5e89901ed1b08aff7949488e267d Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 28 Sep 2025 14:00:21 +0800
Subject: [PATCH] feat(khaosz): optimize model parameter saving and loading
 logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/core/parameter.py  |  45 +++++---------
 khaosz/trainer/trainer.py | 124 ++++++++++++--------------------------
 2 files changed, 54 insertions(+), 115 deletions(-)

diff --git a/khaosz/core/parameter.py b/khaosz/core/parameter.py
index f752792..75c8fba 100644
--- a/khaosz/core/parameter.py
+++ b/khaosz/core/parameter.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 
 from dataclasses import dataclass, field
-from typing import Optional, Self, Union
+from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 from khaosz.core.tokenizer import BpeTokenizer
 
@@ -84,11 +84,9 @@ class ModelParameter(BaseModelIO):
     )
 
     def save(self, save_dir: Union[str, Path]):
-        """Save model parameters."""
         self.save_components(save_dir)
 
     def load(self, load_dir: Union[str, Path]) -> Self:
-        """Load model parameters."""
         return self.load_components(load_dir)
 
 
@@ -108,26 +106,16 @@ class Checkpoint(BaseModelIO):
         default_factory=TransformerConfig,
         metadata={"help": "Transformer model configuration."}
     )
-    loss_list: list[float] = field(
-        default_factory=list,
-        metadata={"help": "List of training losses."}
-    )
-    current_iter: int = field(
-        default=0,
-        metadata={"help": "Current training iteration."}
-    )
-    optimizer: Optional[optim.Optimizer] = field(
+    optim_state: Optional[Dict[str, Any]] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-
-    def __post_init__(self):
-        # Ensure current_iter matches loss list length if not explicitly set
-        if self.current_iter == 0 and self.loss_list:
-            self.current_iter = len(self.loss_list)
+    loss_list: List[float] = field(
+        default_factory=list,
+        metadata={"help": "List of training losses."}
+    )
 
     def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
-        """Get file paths for training-specific files."""
         paths = self._get_file_paths(directory)
         paths.update({
             "loss_list": paths["model"].parent / "loss.pkl",
@@ -137,7 +125,6 @@ class Checkpoint(BaseModelIO):
         return paths
 
     def save_training_state(self, save_dir: Union[str, Path]):
-        """Save training-specific state."""
         paths = self._get_training_paths(save_dir)
 
         # Save loss plot
@@ -148,25 +135,21 @@ class Checkpoint(BaseModelIO):
             pkl.dump(self.loss_list, f)
 
         # Save optimizer state
-        if self.optimizer is not None:
-            with open(str(paths["optimizer"]), "wb") as f:
-                pkl.dump(self.optimizer.state_dict(), f)
+        with open(str(paths["optimizer"]), "wb") as f:
+            pkl.dump(self.optim_state, f)
 
     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
-        """Load training-specific state."""
         paths = self._get_training_paths(load_dir)
 
         # Load loss list
         if paths["loss_list"].exists():
             with open(str(paths["loss_list"]), "rb") as f:
                 self.loss_list = pkl.load(f)
-            self.current_iter = len(self.loss_list)
 
         # Load optimizer state
-        if paths["optimizer"].exists() and self.optimizer is not None:
+        if paths["optimizer"].exists():
             with open(str(paths["optimizer"]), "rb") as f:
-                optim_state = pkl.load(f)
-                self.optimizer.load_state_dict(optim_state)
+                self.optim_state = pkl.load(f)
 
         return self
 
@@ -174,10 +157,12 @@ class Checkpoint(BaseModelIO):
         """Plot and save loss curve."""
curve.""" if not self.loss_list: return + + current_iter = len(self.loss_list) plt.figure(figsize=(10, 6)) plt.plot(self.loss_list) - plt.title(f"Training Loss - Iteration {self.current_iter}") + plt.title(f"Training Loss - Iteration {current_iter}") plt.xlabel("Batch") plt.ylabel("Loss") plt.grid(True) @@ -224,7 +209,7 @@ class ParameterLoader: tokenizer: BpeTokenizer, config: TransformerConfig, loss_list: Optional[list[float]] = None, - optimizer: Optional[optim.Optimizer] = None + optimizer: Optional[optim.Optimizer] = None, ) -> Checkpoint: """Convenience method to create a training checkpoint.""" return Checkpoint( @@ -232,7 +217,7 @@ class ParameterLoader: tokenizer=tokenizer, config=config, loss_list=loss_list or [], - optimizer=optimizer + optimizer_state=optimizer ) diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 85293eb..285f5c0 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,8 +1,7 @@ import os import torch -import logging -from typing import Tuple +from typing import Optional from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader, RandomSampler @@ -15,82 +14,57 @@ from khaosz.trainer.strategy import SchedulerFactory, StrategyFactory, TrainConf class Trainer: def __init__( self, - parameter: ModelParameter, - log_path: str="./train_log.log" + parameter: ModelParameter ): - logger = logging.getLogger() - logger.setLevel(level = logging.INFO) - handler = logging.FileHandler(log_path) - handler.setLevel(logging.INFO) - handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s')) - logger.addHandler(handler) - logger.info("initializing trainer ...") - - self.logger = logger - self.model = parameter.model - self.tokenizer = parameter.tokenizer - self.config = parameter.config + self.checkpoint = Checkpoint( + model=parameter.model, + tokenizer=parameter.tokenizer, + config=parameter.config, + ) def save_checkpoint( self, - loss_list: list, - ckpt_dir: str, - current_iter: int, - last_ckpt_iter: int + loss_list: list, + train_config: TrainConfig ): - save_path = os.path.join(ckpt_dir, f"iter_{current_iter}") - Checkpoint( - self.model, - self.tokenizer, - self.config, - loss_list, - current_iter - ).save(save_path) - - diff_iter = current_iter - last_ckpt_iter - avg_loss = sum(loss_list[last_ckpt_iter:current_iter]) / diff_iter - self.logger.info(f"iter: {current_iter} loss: {avg_loss}") - - return current_iter - - def load_checkpoint(self, train_checkpoint: Checkpoint) -> Tuple[list, int]: - self.model = train_checkpoint.model - self.tokenizer = train_checkpoint.tokenizer - self.config = train_checkpoint.config - loss_list = train_checkpoint.loss_list - last_ckpt_iter = train_checkpoint.current_iter - - return loss_list, last_ckpt_iter + current_iter = len(loss_list) + save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}") + self.checkpoint.loss_list = loss_list + self.checkpoint.optim_state = train_config.optimizer.state_dict() + self.checkpoint.save(save_path) def train( self, train_config: TrainConfig, schedule_config: ScheduleConfig, - train_checkpoint: Checkpoint = None - ): + train_checkpoint: Optional[Checkpoint] = None + ) -> Checkpoint: assert schedule_config.schedule_type in ["cosine", "sgdr"] assert train_config.train_type in ["seq", "sft", "dpo"] if train_checkpoint: - loss_list, last_ckpt_iter = self.load_checkpoint(train_checkpoint) - current_iter = train_checkpoint.current_iter + 1 - self.logger.info(f"Resuming training 
from checkpoint: iter {current_iter}") - else: - current_iter = 0 - last_ckpt_iter = 0 - loss_list = [] + self.checkpoint = train_checkpoint + train_config.optimizer.load_state_dict(train_checkpoint.optim_state) + loss_list = self.checkpoint.loss_list + current_iter = len(self.checkpoint.loss_list) + last_ckpt_iter = current_iter + lambda_scheduler_fn = SchedulerFactory.load_schedule_fn( **schedule_config.get_kwargs() ) + strategy_kwargs = { + "bos_token_id": self.checkpoint.tokenizer.bos_id, + "eos_token_id": self.checkpoint.tokenizer.eos_id, + "pad_token_id": self.checkpoint.tokenizer.pad_id, + "dpo_beta": train_config.dpo_beta + } + strategy = StrategyFactory.load( - self.model, - train_type=train_config.train_type, - bos_token_id=self.tokenizer.bos_id, - eos_token_id=self.tokenizer.eos_id, - pad_token_id=self.tokenizer.pad_id, - dpo_beta=train_config.dpo_beta + self.checkpoint.model, + train_config.train_type, + **strategy_kwargs ) scheduler = LambdaLR( @@ -104,11 +78,8 @@ class Trainer: sampler = RandomSampler(train_config.dataset, generator=generator) remaining_epochs = train_config.n_epoch - current_iter // (len(train_config.dataset) // train_config.batch_size) - self.logger.info(f"Starting {train_config.train_type.upper()} training for {train_config.n_epoch} epochs") - self.logger.info(f"Checkpoint interval: {train_config.n_iter_ckpt} iterations") - for epoch in range(remaining_epochs): - self.model.train() + self.checkpoint.model.train() dataloader = DataLoader( train_config.dataset, batch_size=train_config.batch_size, @@ -128,7 +99,7 @@ class Trainer: #step if current_iter % train_config.n_iter_step == 0: clip_grad_norm_( - self.model.parameters(), + self.checkpoint.model.parameters(), train_config.max_grad_norm ) train_config.optimizer.step() @@ -142,28 +113,11 @@ class Trainer: }) #save checkpotint if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt: - last_ckpt_iter = self.save_checkpoint( - loss_list, - train_config.ckpt_dir, - current_iter, - last_ckpt_iter - ) + self.save_checkpoint(loss_list, train_config) + last_ckpt_iter = current_iter if current_iter != last_ckpt_iter: - last_ckpt_iter = self.save_checkpoint( - loss_list, - train_config.ckpt_dir, - current_iter, - last_ckpt_iter - ) - - self.logger.info("Training completed") + self.save_checkpoint(loss_list, train_config) + last_ckpt_iter = current_iter - return Checkpoint( - self.model, - self.tokenizer, - self.config, - loss_list, - current_iter, - train_config.optimizer - ) + return self.checkpoint \ No newline at end of file