fix(trainer): 修复检查点加载逻辑
This commit is contained in:
parent
b67bc9865d
commit
622982364b
|
|
@ -119,6 +119,14 @@ class Checkpoint(BaseModelIO):
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
metadata={"help": "List of training losses."}
|
metadata={"help": "List of training losses."}
|
||||||
)
|
)
|
||||||
|
epoch: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Current epoch."}
|
||||||
|
)
|
||||||
|
batch_iter: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Current iteration."}
|
||||||
|
)
|
||||||
|
|
||||||
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
|
||||||
paths = self._get_file_paths(directory)
|
paths = self._get_file_paths(directory)
|
||||||
|
|
@ -173,11 +181,11 @@ class Checkpoint(BaseModelIO):
|
||||||
if not self.loss_list:
|
if not self.loss_list:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_iter = len(self.loss_list)
|
batch_iter = len(self.loss_list)
|
||||||
|
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(self.loss_list)
|
plt.plot(self.loss_list)
|
||||||
plt.title(f"Training Loss - Iteration {current_iter}")
|
plt.title(f"Training Loss - Iteration {batch_iter}")
|
||||||
plt.xlabel("Batch")
|
plt.xlabel("Batch")
|
||||||
plt.ylabel("Loss")
|
plt.ylabel("Loss")
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
@ -234,5 +242,3 @@ class ParameterLoader:
|
||||||
loss_list=loss_list or [],
|
loss_list=loss_list or [],
|
||||||
optimizer_state=optimizer
|
optimizer_state=optimizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -274,13 +274,14 @@ class ResumeableRandomSampler(Sampler[int]):
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
|
|
||||||
|
# consume previous epochs
|
||||||
|
for _ in range(start_epoch):
|
||||||
|
torch.randperm(self.num_samples, generator=generator)
|
||||||
|
|
||||||
self.generator = generator
|
self.generator = generator
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def _get_indices(self):
|
def _get_indices(self):
|
||||||
for _ in range(self.epoch):
|
|
||||||
_ = torch.randperm(self.num_samples, generator=self.generator)
|
|
||||||
|
|
||||||
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||||
self._indices = current_epoch_indices[self.iter % self.num_samples:]
|
self._indices = current_epoch_indices[self.iter % self.num_samples:]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,20 +96,22 @@ class CheckpointCallback(TrainCallback):
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
|
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.batch_iter}")
|
||||||
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
context.checkpoint.optimizer_state = context.optimizer.state_dict()
|
||||||
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
context.checkpoint.scheduler_state = context.scheduler.state_dict()
|
||||||
|
context.checkpoint.epoch = context.epoch
|
||||||
|
context.checkpoint.batch_iter = context.batch_iter
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.current_iter
|
self.last_ckpt_iter = context.batch_iter
|
||||||
|
|
||||||
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_batch_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
context.checkpoint.loss_list.append(context.loss)
|
context.checkpoint.loss_list.append(context.loss)
|
||||||
|
|
||||||
if context.current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
if context.batch_iter - self.last_ckpt_iter >= self.checkpoint_interval:
|
||||||
self._save_checkpoint(trainer, context)
|
self._save_checkpoint(trainer, context)
|
||||||
|
|
||||||
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
def on_train_end(self, trainer: 'Trainer', context: 'TrainContext'):
|
||||||
if context.current_iter != self.last_ckpt_iter:
|
if context.batch_iter != self.last_ckpt_iter:
|
||||||
self._save_checkpoint(trainer, context)
|
self._save_checkpoint(trainer, context)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -177,7 +179,7 @@ class StepMonitorCallback(TrainCallback):
|
||||||
log_data = {
|
log_data = {
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.current_iter,
|
"iter": context.batch_iter,
|
||||||
"metrics": self.metrics,
|
"metrics": self.metrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -207,7 +209,7 @@ class StepMonitorCallback(TrainCallback):
|
||||||
""" Logs training information to console and file. """
|
""" Logs training information to console and file. """
|
||||||
log_data = self._handle_info(trainer, context)
|
log_data = self._handle_info(trainer, context)
|
||||||
try:
|
try:
|
||||||
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.current_iter}.json"
|
log_file = self.log_dir / f"log_epoch_{context.epoch}_iter_{context.batch_iter}.json"
|
||||||
with open(log_file, 'a') as f:
|
with open(log_file, 'a') as f:
|
||||||
json.dump(log_data, f, indent=4)
|
json.dump(log_data, f, indent=4)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -218,4 +220,3 @@ class StepMonitorCallback(TrainCallback):
|
||||||
self._handle_log(trainer, context)
|
self._handle_log(trainer, context)
|
||||||
|
|
||||||
self.step_num += 1
|
self.step_num += 1
|
||||||
|
|
||||||
|
|
@ -17,7 +17,7 @@ class TrainContext:
|
||||||
scheduler: BaseScheduler = field(default=None)
|
scheduler: BaseScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
current_iter: int = field(default=0)
|
batch_iter: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
|
||||||
def asdict(self) -> dict:
|
def asdict(self) -> dict:
|
||||||
|
|
@ -37,6 +37,10 @@ class TrainContextBuilder:
|
||||||
tokenizer=self.trainer.parameter.tokenizer,
|
tokenizer=self.trainer.parameter.tokenizer,
|
||||||
config=self.trainer.parameter.config,
|
config=self.trainer.parameter.config,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self._context.epoch = checkpoint.epoch
|
||||||
|
self._context.batch_iter = checkpoint.batch_iter
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -70,10 +74,12 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
|
# fix: change batch level batch_iter to sample level offset
|
||||||
|
sampler_offset = self._context.batch_iter * self.trainer.train_config.batch_size
|
||||||
resumeable_sampler = ResumeableRandomSampler(
|
resumeable_sampler = ResumeableRandomSampler(
|
||||||
data_source=self.trainer.train_config.dataset,
|
data_source=self.trainer.train_config.dataset,
|
||||||
start_epoch=self._context.epoch,
|
start_epoch=self._context.epoch,
|
||||||
start_iter=self._context.current_iter,
|
start_iter=sampler_offset,
|
||||||
seed=self.trainer.train_config.random_seed
|
seed=self.trainer.train_config.random_seed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ class Trainer:
|
||||||
self._call_callbacks('on_epoch_begin', context)
|
self._call_callbacks('on_epoch_begin', context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
if context.current_iter % self.train_config.accumulation_steps == 0:
|
if context.batch_iter % self.train_config.accumulation_steps == 0:
|
||||||
# 2. step
|
# 2. step
|
||||||
self._call_callbacks('on_step_begin', context)
|
self._call_callbacks('on_step_begin', context)
|
||||||
self.train_config.optimizer.step()
|
self.train_config.optimizer.step()
|
||||||
|
|
@ -76,7 +76,7 @@ class Trainer:
|
||||||
self._call_callbacks('on_batch_begin', context)
|
self._call_callbacks('on_batch_begin', context)
|
||||||
loss = self.train_config.strategy(batch)
|
loss = self.train_config.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.current_iter += 1
|
context.batch_iter += 1
|
||||||
|
|
||||||
# to make the loss normalized by accumulation steps
|
# to make the loss normalized by accumulation steps
|
||||||
normalized_loss = loss / self.train_config.accumulation_steps
|
normalized_loss = loss / self.train_config.accumulation_steps
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue