fix(trainer): 修复检查点加载逻辑

This commit is contained in:
ViperEkura 2025-10-18 21:45:23 +08:00
parent b67bc9865d
commit 622982364b
5 changed files with 35 additions and 21 deletions

View File

@ -119,6 +119,14 @@ class Checkpoint(BaseModelIO):
default_factory=list, default_factory=list,
metadata={"help": "List of training losses."} metadata={"help": "List of training losses."}
) )
epoch: int = field(
default=0,
metadata={"help": "Current epoch."}
)
batch_iter: int = field(
default=0,
metadata={"help": "Current iteration."}
)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]: def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
paths = self._get_file_paths(directory) paths = self._get_file_paths(directory)
@ -173,11 +181,11 @@ class Checkpoint(BaseModelIO):
if not self.loss_list: if not self.loss_list:
return return
current_iter = len(self.loss_list) batch_iter = len(self.loss_list)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.plot(self.loss_list) plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {current_iter}") plt.title(f"Training Loss - Iteration {batch_iter}")
plt.xlabel("Batch") plt.xlabel("Batch")
plt.ylabel("Loss") plt.ylabel("Loss")
plt.grid(True) plt.grid(True)
@ -234,5 +242,3 @@ class ParameterLoader:
loss_list=loss_list or [], loss_list=loss_list or [],
optimizer_state=optimizer optimizer_state=optimizer
) )

View File

@ -274,13 +274,14 @@ class ResumeableRandomSampler(Sampler[int]):
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(seed) generator.manual_seed(seed)
# consume previous epochs
for _ in range(start_epoch):
torch.randperm(self.num_samples, generator=generator)
self.generator = generator self.generator = generator
self._indices = None self._indices = None
def _get_indices(self): def _get_indices(self):
for _ in range(self.epoch):
_ = torch.randperm(self.num_samples, generator=self.generator)
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist() current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
self._indices = current_epoch_indices[self.iter % self.num_samples:] self._indices = current_epoch_indices[self.iter % self.num_samples:]

View File

@ -96,20 +96,22 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'): def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}") save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
context.checkpoint.optimizer_state = context.optimizer.state_dict() context.checkpoint.optimizer_state = context.optimizer.state_dict()
context.checkpoint.scheduler_state = context.scheduler.state_dict() context.checkpoint.scheduler_state = context.scheduler.state_dict()
context.checkpoint.epoch = context.epoch
context.checkpoint.batch_iter = context.batch_iter
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.current_iter self.last_ckpt_iter = context.batch_iter
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
context.checkpoint.loss_list.append(context.loss) context.checkpoint.loss_list.append(context.loss)
if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval: if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
self._save_checkpoint(trainer, context) self._save_checkpoint(trainer, context)
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'): def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
if context.current_iter != self.last_ckpt_iter: if context.batch_iter != self.last_ckpt_iter:
self._save_checkpoint(trainer, context) self._save_checkpoint(trainer, context)
@ -177,7 +179,7 @@ class StepMonitorCallback(TrainCallback):
log_data = { log_data = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.current_iter, "iter": context.batch_iter,
"metrics": self.metrics, "metrics": self.metrics,
} }
@ -207,7 +209,7 @@ class StepMonitorCallback(TrainCallback):
""" Logs training information to console and file. """ """ Logs training information to console and file. """
log_data = self._handle_info(trainer, context) log_data = self._handle_info(trainer, context)
try: try:
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.current_iter}.json" log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
with open(log_file, 'a') as f: with open(log_file, 'a') as f:
json.dump(log_data, f, indent=4) json.dump(log_data, f, indent=4)
except Exception: except Exception:
@ -218,4 +220,3 @@ class StepMonitorCallback(TrainCallback):
self._handle_log(trainer, context) self._handle_log(trainer, context)
self.step_num += 1 self.step_num += 1

View File

@ -17,7 +17,7 @@ class TrainContext:
scheduler: BaseScheduler = field(default=None) scheduler: BaseScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
current_iter: int = field(default=0) batch_iter: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
def asdict(self) -> dict: def asdict(self) -> dict:
@ -37,6 +37,10 @@ class TrainContextBuilder:
tokenizer=self.trainer.parameter.tokenizer, tokenizer=self.trainer.parameter.tokenizer,
config=self.trainer.parameter.config, config=self.trainer.parameter.config,
) )
else:
self._context.epoch = checkpoint.epoch
self._context.batch_iter = checkpoint.batch_iter
self._context.checkpoint = checkpoint self._context.checkpoint = checkpoint
return self return self
@ -70,10 +74,12 @@ class TrainContextBuilder:
return self return self
def with_dataloader(self) -> Self: def with_dataloader(self) -> Self:
# fix: change batch level batch_iter to sample level offset
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
resumeable_sampler = ResumeableRandomSampler( resumeable_sampler = ResumeableRandomSampler(
data_source=self.trainer.train_config.dataset, data_source=self.trainer.train_config.dataset,
start_epoch=self._context.epoch, start_epoch=self._context.epoch,
start_iter=self._context.current_iter, start_iter=sampler_offset,
seed=self.trainer.train_config.random_seed seed=self.trainer.train_config.random_seed
) )

View File

@ -65,7 +65,7 @@ class Trainer:
self._call_callbacks('on_epoch_begin', context) self._call_callbacks('on_epoch_begin', context)
for batch in context.dataloader: for batch in context.dataloader:
if context.current_iter % self.train_config.accumulation_steps == 0: if context.batch_iter % self.train_config.accumulation_steps == 0:
# 2. step # 2. step
self._call_callbacks('on_step_begin', context) self._call_callbacks('on_step_begin', context)
self.train_config.optimizer.step() self.train_config.optimizer.step()
@ -76,7 +76,7 @@ class Trainer:
self._call_callbacks('on_batch_begin', context) self._call_callbacks('on_batch_begin', context)
loss = self.train_config.strategy(batch) loss = self.train_config.strategy(batch)
context.loss = loss.item() context.loss = loss.item()
context.current_iter += 1 context.batch_iter += 1
# to make the loss normalized by accumulation steps # to make the loss normalized by accumulation steps
normalized_loss = loss / self.train_config.accumulation_steps normalized_loss = loss / self.train_config.accumulation_steps