fix(trainer): 修复检查点加载逻辑
This commit is contained in:
parent
b67bc9865d
commit
622982364b
|
|
@ -119,6 +119,14 @@ class Checkpoint(BaseModelIO):
|
|||
default_factory=list,
|
||||
metadata={"help": "List of training losses."}
|
||||
)
|
||||
epoch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Current epoch."}
|
||||
)
|
||||
batch_iter: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Current iteration."}
|
||||
)
|
||||
|
||||
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||
paths = self._get_file_paths(directory)
|
||||
|
|
@ -173,11 +181,11 @@ class Checkpoint(BaseModelIO):
|
|||
if not self.loss_list:
|
||||
return
|
||||
|
||||
current_iter = len(self.loss_list)
|
||||
batch_iter = len(self.loss_list)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(self.loss_list)
|
||||
plt.title(f"Training Loss - Iteration {current_iter}")
|
||||
plt.title(f"Training Loss - Iteration {batch_iter}")
|
||||
plt.xlabel("Batch")
|
||||
plt.ylabel("Loss")
|
||||
plt.grid(True)
|
||||
|
|
@ -233,6 +241,4 @@ class ParameterLoader:
|
|||
config=config,
|
||||
loss_list=loss_list or [],
|
||||
optimizer_state=optimizer
|
||||
)
|
||||
|
||||
|
||||
)
|
||||
|
|
@ -273,14 +273,15 @@ class ResumeableRandomSampler(Sampler[int]):
|
|||
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# consume previous epochs
|
||||
for _ in range(start_epoch):
|
||||
torch.randperm(self.num_samples, generator=generator)
|
||||
|
||||
self.generator = generator
|
||||
self._indices = None
|
||||
self._indices = None
|
||||
|
||||
def _get_indices(self):
|
||||
for _ in range(self.epoch):
|
||||
_ = torch.randperm(self.num_samples, generator=self.generator)
|
||||
|
||||
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
self._indices = current_epoch_indices[self.iter % self.num_samples:]
|
||||
|
||||
|
|
|
|||
|
|
@ -96,20 +96,22 @@ class CheckpointCallback(TrainCallback):
|
|||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
|
||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
|
||||
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
||||
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
||||
context.checkpoint.epoch = context.epoch
|
||||
context.checkpoint.batch_iter = context.batch_iter
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.current_iter
|
||||
self.last_ckpt_iter = context.batch_iter
|
||||
|
||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
context.checkpoint.loss_list.append(context.loss)
|
||||
|
||||
if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||
self._save_checkpoint(trainer, context)
|
||||
|
||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||
if context.current_iter != self.last_ckpt_iter:
|
||||
if context.batch_iter != self.last_ckpt_iter:
|
||||
self._save_checkpoint(trainer, context)
|
||||
|
||||
|
||||
|
|
@ -177,7 +179,7 @@ class StepMonitorCallback(TrainCallback):
|
|||
log_data = {
|
||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"epoch": context.epoch,
|
||||
"iter": context.current_iter,
|
||||
"iter": context.batch_iter,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
|
||||
|
|
@ -207,7 +209,7 @@ class StepMonitorCallback(TrainCallback):
|
|||
""" Logs training information to console and file. """
|
||||
log_data = self._handle_info(trainer, context)
|
||||
try:
|
||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.current_iter}.json"
|
||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
|
||||
with open(log_file, 'a') as f:
|
||||
json.dump(log_data, f, indent=4)
|
||||
except Exception:
|
||||
|
|
@ -217,5 +219,4 @@ class StepMonitorCallback(TrainCallback):
|
|||
if self.step_num % self.log_interval == 0:
|
||||
self._handle_log(trainer, context)
|
||||
|
||||
self.step_num += 1
|
||||
|
||||
self.step_num += 1
|
||||
|
|
@ -17,7 +17,7 @@ class TrainContext:
|
|||
scheduler: BaseScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
epoch: int = field(default=0)
|
||||
current_iter: int = field(default=0)
|
||||
batch_iter: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
|
||||
def asdict(self) -> dict:
|
||||
|
|
@ -37,6 +37,10 @@ class TrainContextBuilder:
|
|||
tokenizer=self.trainer.parameter.tokenizer,
|
||||
config=self.trainer.parameter.config,
|
||||
)
|
||||
else:
|
||||
self._context.epoch = checkpoint.epoch
|
||||
self._context.batch_iter = checkpoint.batch_iter
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
|
|
@ -70,10 +74,12 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def with_dataloader(self) -> Self:
|
||||
# fix: convert the batch-level batch_iter into a sample-level offset for the sampler
|
||||
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
||||
resumeable_sampler = ResumeableRandomSampler(
|
||||
data_source=self.trainer.train_config.dataset,
|
||||
start_epoch=self._context.epoch,
|
||||
start_iter=self._context.current_iter,
|
||||
start_iter=sampler_offset,
|
||||
seed=self.trainer.train_config.random_seed
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class Trainer:
|
|||
self._call_callbacks('on_epoch_begin', context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
if context.current_iter % self.train_config.accumulation_steps == 0:
|
||||
if context.batch_iter % self.train_config.accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks('on_step_begin', context)
|
||||
self.train_config.optimizer.step()
|
||||
|
|
@ -76,7 +76,7 @@ class Trainer:
|
|||
self._call_callbacks('on_batch_begin', context)
|
||||
loss = self.train_config.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.current_iter += 1
|
||||
context.batch_iter += 1
|
||||
|
||||
# normalize the loss by the number of accumulation steps
|
||||
normalized_loss = loss / self.train_config.accumulation_steps
|
||||
|
|
|
|||
Loading…
Reference in New Issue