From 622982364b69514352827be8b1f362d08059133b Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 18 Oct 2025 21:45:23 +0800
Subject: [PATCH] fix(trainer): fix checkpoint loading logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 khaosz/config/param_config.py    | 16 +++++++++++-----
 khaosz/data/data_util.py         |  9 +++++----
 khaosz/trainer/train_callback.py | 17 +++++++++--------
 khaosz/trainer/train_context.py  | 10 ++++++++--
 khaosz/trainer/trainer.py        |  4 ++--
 5 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/khaosz/config/param_config.py b/khaosz/config/param_config.py
index b7b3f05..de8fd13 100644
--- a/khaosz/config/param_config.py
+++ b/khaosz/config/param_config.py
@@ -119,6 +119,14 @@ class Checkpoint(BaseModelIO):
         default_factory=list,
         metadata={"help": "List of training losses."}
     )
+    epoch: int = field(
+        default=0,
+        metadata={"help": "Current epoch."}
+    )
+    batch_iter: int = field(
+        default=0,
+        metadata={"help": "Current batch iteration."}
+    )
 
     def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
         paths = self._get_file_paths(directory)
@@ -173,11 +181,11 @@ class Checkpoint(BaseModelIO):
         if not self.loss_list:
             return
 
-        current_iter = len(self.loss_list)
+        batch_iter = len(self.loss_list)
 
         plt.figure(figsize=(10, 6))
         plt.plot(self.loss_list)
-        plt.title(f"Training Loss - Iteration {current_iter}")
+        plt.title(f"Training Loss - Iteration {batch_iter}")
         plt.xlabel("Batch")
         plt.ylabel("Loss")
         plt.grid(True)
@@ -233,6 +241,4 @@ class ParameterLoader:
             config=config,
             loss_list=loss_list or [],
             optimizer_state=optimizer
-        )
-
-
+        )
\ No newline at end of file
diff --git a/khaosz/data/data_util.py b/khaosz/data/data_util.py
index 19d811c..1a1c317 100644
--- a/khaosz/data/data_util.py
+++ b/khaosz/data/data_util.py
@@ -273,14 +273,15 @@ class ResumeableRandomSampler(Sampler[int]):
         generator = torch.Generator()
         generator.manual_seed(seed)
+
+        # advance the generator past the permutations of previous epochs
+        for _ in range(start_epoch):
+            torch.randperm(self.num_samples, generator=generator)
 
         self.generator = generator
-        self._indices = None
+        self._indices = None
 
     def _get_indices(self):
-        for _ in range(self.epoch):
-            _ = torch.randperm(self.num_samples, generator=self.generator)
-
         current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
         self._indices = current_epoch_indices[self.iter % self.num_samples:]
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 73dc53d..5053bc0 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -96,20 +96,22 @@ class CheckpointCallback(TrainCallback):
         self.last_ckpt_iter = 0
 
     def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
-        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
+        save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
         context.checkpoint.optimizer_state = context.optimizer.state_dict()
         context.checkpoint.scheduler_state = context.scheduler.state_dict()
+        context.checkpoint.epoch = context.epoch
+        context.checkpoint.batch_iter = context.batch_iter
         context.checkpoint.save(save_path)
-        self.last_ckpt_iter = context.current_iter
+        self.last_ckpt_iter = context.batch_iter
 
     def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
         context.checkpoint.loss_list.append(context.loss)
-        if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
+        if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
             self._save_checkpoint(trainer, context)
 
     def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
-        if context.current_iter != self.last_ckpt_iter:
+        if context.batch_iter != self.last_ckpt_iter:
             self._save_checkpoint(trainer, context)
 
 
@@ -177,7 +179,7 @@ class StepMonitorCallback(TrainCallback):
         log_data = {
             "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
             "epoch": context.epoch,
-            "iter": context.current_iter,
+            "iter": context.batch_iter,
             "metrics": self.metrics,
         }
@@ -207,7 +209,7 @@
         """ Logs training information to console and file. """
         log_data = self._handle_info(trainer, context)
         try:
-            log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.current_iter}.json"
+            log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
             with open(log_file, 'a') as f:
                 json.dump(log_data, f, indent=4)
         except Exception:
@@ -217,5 +219,4 @@
         if self.step_num % self.log_interval == 0:
             self._handle_log(trainer, context)
 
-        self.step_num += 1
-
\ No newline at end of file
+        self.step_num += 1
\ No newline at end of file
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index b1ed4e0..4248ce1 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -17,7 +17,7 @@ class TrainContext:
     scheduler: BaseScheduler = field(default=None)
     checkpoint: Checkpoint = field(default=None)
     epoch: int = field(default=0)
-    current_iter: int = field(default=0)
+    batch_iter: int = field(default=0)
     loss: float = field(default=0.0)
 
     def asdict(self) -> dict:
@@ -37,6 +37,10 @@ class TrainContextBuilder:
                 tokenizer=self.trainer.parameter.tokenizer,
                 config=self.trainer.parameter.config,
             )
+        else:
+            self._context.epoch = checkpoint.epoch
+            self._context.batch_iter = checkpoint.batch_iter
+
         self._context.checkpoint = checkpoint
         return self
@@ -70,10 +74,12 @@
     def with_dataloader(self) -> Self:
+        # fix: convert the batch-level batch_iter to a sample-level offset
+        sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
         resumeable_sampler = ResumeableRandomSampler(
             data_source=self.trainer.train_config.dataset,
             start_epoch=self._context.epoch,
-            start_iter=self._context.current_iter,
+            start_iter=sampler_offset,
             seed=self.trainer.train_config.random_seed
         )
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 0241ded..1565584 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -65,7 +65,7 @@ class Trainer:
             self._call_callbacks('on_epoch_begin', context)
 
             for batch in context.dataloader:
-                if context.current_iter % self.train_config.accumulation_steps == 0:
+                if context.batch_iter % self.train_config.accumulation_steps == 0:
                     # 2. step
                     self._call_callbacks('on_step_begin', context)
                     self.train_config.optimizer.step()
@@ -76,7 +76,7 @@
                 self._call_callbacks('on_batch_begin', context)
                 loss = self.train_config.strategy(batch)
                 context.loss = loss.item()
-                context.current_iter += 1
+                context.batch_iter += 1
 
                 # to make the loss normalized by accumulation steps
                 normalized_loss = loss / self.train_config.accumulation_steps
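
Note on the resume technique in data_util.py: rather than persisting the
shuffled index order in the checkpoint, the sampler re-seeds a fresh
torch.Generator and replays one randperm per completed epoch, so the
generator reaches the same state it had when training was interrupted.
A minimal standalone sketch of the same idea (resume_indices is an
illustrative helper, not part of the khaosz API):

    import torch

    def resume_indices(num_samples: int, start_epoch: int, start_iter: int, seed: int) -> list:
        generator = torch.Generator()
        generator.manual_seed(seed)
        # Burn one permutation per completed epoch to reproduce the RNG state.
        for _ in range(start_epoch):
            torch.randperm(num_samples, generator=generator)
        # Draw the current epoch's permutation and drop already-consumed samples.
        epoch_perm = torch.randperm(num_samples, generator=generator).tolist()
        return epoch_perm[start_iter % num_samples:]

    # Resuming after 2 full epochs plus 5 batches of size 4 (20 samples consumed):
    remaining = resume_indices(num_samples=100, start_epoch=2, start_iter=5 * 4, seed=42)
    assert len(remaining) == 80

Replaying permutations costs O(start_epoch * num_samples) work at resume
time, but it keeps the checkpoint small: only epoch and batch_iter need to
be stored instead of the full index order.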