From 315ce1990aa700a92e236ba1613302b0366fd456 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 30 Sep 2025 16:33:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(khaosz/trainer):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=99=A8=E5=9B=9E=E8=B0=83=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=E4=B8=8E=E6=95=B0=E6=8D=AE=E9=87=87=E6=A0=B7=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- khaosz/trainer/__init__.py | 2 + khaosz/trainer/data_util.py | 23 +++--- khaosz/trainer/strategy.py | 7 +- khaosz/trainer/trainer.py | 125 +++++++++++++++++------------ khaosz/trainer/trainer_callback.py | 46 ++++++----- 5 files changed, 118 insertions(+), 85 deletions(-) diff --git a/khaosz/trainer/__init__.py b/khaosz/trainer/__init__.py index 39470f0..f7aea3f 100644 --- a/khaosz/trainer/__init__.py +++ b/khaosz/trainer/__init__.py @@ -8,6 +8,7 @@ from khaosz.trainer.strategy import ( SchedulerFactory ) from khaosz.trainer.trainer_callback import ( + TrainerCallback, ProgressBarCallback, CheckpointCallback, TrainerCallback, @@ -25,6 +26,7 @@ __all__ = [ "SchedulerFactory", # callback + "TrainerCallback", "ProgressBarCallback", "CheckpointCallback", "TrainerCallback", diff --git a/khaosz/trainer/data_util.py b/khaosz/trainer/data_util.py index ed19708..ba012f6 100644 --- a/khaosz/trainer/data_util.py +++ b/khaosz/trainer/data_util.py @@ -53,7 +53,8 @@ def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).bool() attention_mask = seq_mask & causal_mask - return attention_mask + # fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast + return attention_mask.unsqueeze(0) class BaseSegmentFetcher: @@ -117,10 +118,13 @@ class BaseDataset(Dataset, ABC): self.total_samples = 0 self.device = device - def save(self, save_path: str): + def save(self, save_path: str): + keys = list(self.segments.keys()) + if not keys: + return + first_item = self.segments[keys[0]] segment_size = len(first_item) - keys = list(self.segments.keys()) for i in range(segment_size): formated_segment = {key: self.segments[key][i] for key in keys} @@ -272,7 +276,7 @@ class RandomSampler(Sampler[int]): self.data_source = data_source self.seed = seed self.epoch = 0 - self.current_index = 0 + self.current_iter = 0 self._indices = None if generator is None: @@ -291,20 +295,21 @@ class RandomSampler(Sampler[int]): if self._indices is None: self._generate_indices() - for i in range(self.current_index, n): + for i in range(self.current_iter, n): yield self._indices[i] + self.current_iter += 1 self.epoch += 1 - self.current_index = 0 self._indices = None def __len__(self): - return len(self.data_source) - self.current_index + n = len(self.data_source) + return n - self.current_iter % n def state_dict(self): return { 'epoch': self.epoch, - 'current_index': self.current_index, + 'current_iter': self.current_iter, 'seed': self.seed, 'generator_state': self.generator.get_state() if self.generator else None, 'indices': self._indices @@ -312,7 +317,7 @@ class RandomSampler(Sampler[int]): def load_state_dict(self, state_dict): self.epoch = state_dict['epoch'] - self.current_index = state_dict['current_index'] + self.current_iter = state_dict['current_iter'] self.seed = state_dict['seed'] if self.generator and state_dict['generator_state'] is not None: diff --git a/khaosz/trainer/strategy.py b/khaosz/trainer/strategy.py index 93a2fea..c14f492 100644 --- a/khaosz/trainer/strategy.py +++ b/khaosz/trainer/strategy.py @@ -159,12 +159,7 @@ class StrategyFactory: def load(model, train_type, **kwargs): train_strategy: Dict[str, Callable[[], BaseStrategy]] = { "seq": lambda: SeqStrategy(model), - "sft": lambda: SftStrategy( - model, - kwargs.get("bos_token_id"), - kwargs.get("eos_token_id"), - kwargs.get("multi_turn") - ), + "sft": lambda: SftStrategy(model), "dpo": lambda: DpoStrategy( model, kwargs.get("pad_token_id"), diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 6f8a46f..9a0005f 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -1,6 +1,4 @@ -import torch -import itertools -from typing import Optional, List +from typing import Optional, List, cast from torch.utils.data import DataLoader from khaosz.core import ModelParameter, Checkpoint @@ -23,11 +21,7 @@ class Trainer: schedule_config: ScheduleConfig, callbacks: Optional[List[TrainerCallback]] = None ): - self.checkpoint = Checkpoint( - model=parameter.model, - tokenizer=parameter.tokenizer, - config=parameter.config, - ) + self.parameter = parameter self.train_config = train_config self.schedule_config = schedule_config self.callbacks = callbacks or self._get_default_callbacks() @@ -39,81 +33,108 @@ class Trainer: GradientClippingCallback(), SchedulerCallback(self.schedule_config), ] - - def _create_dataloader(self, start_index: int = 0) -> DataLoader: + + def _set_train_kwargs(self, kwargs: dict): + used_epochs = 0 + used_iters = 0 seed = self.train_config.random_seed - generator = torch.Generator().manual_seed(seed) - sampler = RandomSampler( - self.train_config.dataset, - generator=generator, - seed=seed - ) + sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed) + optim = self.train_config.optimizer + checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None)) + + if checkpoint is None: + checkpoint = Checkpoint( + model=self.parameter.model, + tokenizer=self.parameter.tokenizer, + config=self.parameter.config, + sampler_state=None, + optim_state=None, + loss_list=[] + ) + + sampler_state = checkpoint.sampler_state + optim_state = checkpoint.optim_state + + if sampler_state: + sampler.load_state_dict(sampler_state) + used_epochs = sampler_state.get('epoch', 0) + used_iters = sampler_state.get('iter', 0) + + if optim_state: + optim.load_state_dict(optim_state) + + checkpoint.optim_state = optim.state_dict() + checkpoint.sampler_state = sampler.state_dict() + dataloader = DataLoader( self.train_config.dataset, batch_size=self.train_config.batch_size, sampler=sampler ) - if start_index > 0: - dataloader = itertools.islice(dataloader, start_index, None) + kwargs["dataloader"] = dataloader + kwargs["optimizer"] = self.train_config.optimizer + kwargs["epoch"] = used_epochs + kwargs["current_iter"] = used_iters + kwargs["sampler"] = sampler + kwargs["checkpoint"] = checkpoint - return dataloader - def _call_callbacks(self, method_name: str, **kwargs): for callback in self.callbacks: method = getattr(callback, method_name, None) if method: method(self, **kwargs) - + def train( self, - train_checkpoint: Optional[Checkpoint] = None + checkpoint: Optional[Checkpoint] = None ) -> Checkpoint: - - if train_checkpoint: - self.checkpoint = train_checkpoint - self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state) - else: - self.checkpoint.optim_state = self.train_config.optimizer.state_dict() - - current_iter = len(self.checkpoint.loss_list) - total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size - total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter - current_epochs = total_reamining_steps // total_steps_per_epoch - current_steps = total_reamining_steps % total_steps_per_epoch - - # train - self._call_callbacks('on_train_begin', checkpoint=self.checkpoint) - self.checkpoint.model.train() + # train + train_kwargs = { + 'checkpoint': checkpoint, + 'dataloader': None, + 'optimizer': None, + 'sampler': None, + 'epoch': 0, + 'current_iter': 0, + 'loss': 0.0, + } + + self._set_train_kwargs(train_kwargs) + self._call_callbacks('on_train_begin', **train_kwargs) + + dataloader = train_kwargs['dataloader'] + checkpoint = train_kwargs['checkpoint'] + start_epoch = train_kwargs['epoch'] try: - for epoch in range(current_epochs): + self.parameter.model.train() + for epoch in range(start_epoch, self.train_config.n_epoch): # epoch - self._call_callbacks('on_epoch_begin', epoch=epoch) - dataloader = self._create_dataloader(start_index=current_steps) + train_kwargs["epoch"] = epoch + self._call_callbacks('on_epoch_begin', **train_kwargs) for batch in dataloader: # batch - self._call_callbacks('on_batch_begin', batch=batch) + self._call_callbacks('on_batch_begin', **train_kwargs) loss = self.train_config.strategy(batch) - self.checkpoint.loss_list.append(loss.item()) loss.backward() - self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter) + train_kwargs["loss"] = loss.item() + self._call_callbacks('on_batch_end', **train_kwargs) - if current_iter % self.train_config.accumulation_steps == 0: + if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0: # step - self._call_callbacks('on_step_begin', current_iter=current_iter) + self._call_callbacks('on_step_begin', **train_kwargs) self.train_config.optimizer.step() self.train_config.optimizer.zero_grad() - self._call_callbacks('on_step_end', current_iter=current_iter) + self._call_callbacks('on_step_end', **train_kwargs) - current_iter += 1 + train_kwargs["current_iter"] += 1 - self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list) + self._call_callbacks('on_epoch_end', **train_kwargs) except Exception as e: raise e - finally: - self._call_callbacks('on_train_end', checkpoint=self.checkpoint) - return self.checkpoint \ No newline at end of file + self._call_callbacks('on_train_end', **train_kwargs) + return checkpoint \ No newline at end of file diff --git a/khaosz/trainer/trainer_callback.py b/khaosz/trainer/trainer_callback.py index ce3927a..e1a9ec8 100644 --- a/khaosz/trainer/trainer_callback.py +++ b/khaosz/trainer/trainer_callback.py @@ -1,9 +1,12 @@ import os +import torch.optim as optim + from tqdm import tqdm -from khaosz.core.parameter import Checkpoint from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import LambdaLR from typing import Optional, cast, TYPE_CHECKING +from khaosz.core.parameter import Checkpoint +from khaosz.trainer.data_util import RandomSampler from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory if TYPE_CHECKING: @@ -76,7 +79,7 @@ class ProgressBarCallback(TrainerCallback): def on_epoch_begin(self, trainer: 'Trainer', **kwargs): epoch = kwargs.get('epoch') - dataloader = trainer._create_dataloader() + dataloader = kwargs.get('dataloader') self.progress_bar = tqdm( dataloader, desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}", @@ -106,28 +109,36 @@ class CheckpointCallback(TrainerCallback): self.last_ckpt_iter = 0 @staticmethod - def _save_checkpoint(trainer: 'Trainer'): - current_iter = len(trainer.checkpoint.loss_list) - save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}") - trainer.checkpoint.optim_state = trainer.train_config.optimizer.state_dict() - trainer.checkpoint.save(save_path) - - def on_train_begin(self, trainer: 'Trainer', **kwargs): - _ = trainer + def _save_checkpoint(trainer: 'Trainer', **kwargs): + current_iter = kwargs.get('current_iter') + random_sampler = cast(RandomSampler, kwargs.get('sampler')) + optimizer = cast(optim.Optimizer, kwargs.get('optimizer')) checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) - self.last_ckpt_iter = len(checkpoint.loss_list) + + save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}") + checkpoint.sampler_state = random_sampler.state_dict() + checkpoint.optim_state = optimizer.state_dict() + + checkpoint.sampler_state['epoch'] = kwargs.get('epoch', 0) + checkpoint.sampler_state['current_iter'] = kwargs.get('current_iter', 0) + + checkpoint.save(save_path) def on_batch_end(self, trainer: 'Trainer', **kwargs): current_iter = kwargs.get('current_iter') + checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) + loss = kwargs.get('loss') + checkpoint.loss_list.append(loss) + if current_iter - self.last_ckpt_iter >= self.checkpoint_interval: - CheckpointCallback._save_checkpoint(trainer) + CheckpointCallback._save_checkpoint(trainer, **kwargs) self.last_ckpt_iter = current_iter def on_train_end(self, trainer: 'Trainer', **kwargs): - checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) - current_iter = len(checkpoint.loss_list) + current_iter = kwargs.get('current_iter') if current_iter != self.last_ckpt_iter: - CheckpointCallback._save_checkpoint(trainer) + CheckpointCallback._save_checkpoint(trainer, **kwargs) + self.last_ckpt_iter = current_iter class GradientClippingCallback(TrainerCallback): @@ -137,7 +148,7 @@ class GradientClippingCallback(TrainerCallback): def on_step_begin(self, trainer: 'Trainer', **kwargs): _ = kwargs clip_grad_norm_( - trainer.checkpoint.model.parameters(), + trainer.parameter.model.parameters(), trainer.train_config.max_grad_norm ) @@ -152,8 +163,7 @@ class SchedulerCallback(TrainerCallback): self.current_iter = 0 def on_train_begin(self, trainer: 'Trainer', **kwargs): - checkpoint = cast(Checkpoint, kwargs.get('checkpoint')) - self.current_iter = len(checkpoint.loss_list) + self.current_iter = kwargs.get('current_iter') for group in trainer.train_config.optimizer.param_groups: if "initial_lr" not in group: