fix(khaosz/trainer): 修正训练步数计算逻辑

This commit is contained in:
ViperEkura 2025-09-29 19:05:26 +08:00
parent c104a400e7
commit e467420475
3 changed files with 38 additions and 28 deletions

View File

@ -154,6 +154,11 @@ class SchedulerCallback(TrainerCallback):
def on_train_begin(self, trainer: 'Trainer', **kwargs):
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
self.current_iter = len(checkpoint.loss_list)
for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
self.schedule_config.validate()
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
self.schedule_config

View File

@ -1,4 +1,5 @@
import torch
import itertools
from typing import Optional, List
from torch.utils.data import DataLoader, RandomSampler
@ -38,15 +39,20 @@ class Trainer:
SchedulerCallback(self.schedule_config),
]
def _create_dataloader(self) -> DataLoader:
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
seed = self.train_config.random_seed
generator = torch.Generator().manual_seed(seed)
sampler = RandomSampler(self.train_config.dataset, generator=generator)
return DataLoader(
dataloader = DataLoader(
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
if start_index > 0:
dataloader = itertools.islice(dataloader, start_index, None)
return dataloader
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks:
@ -58,34 +64,29 @@ class Trainer:
self,
train_checkpoint: Optional[Checkpoint] = None
) -> Checkpoint:
assert self.schedule_config.schedule_type in ["cosine", "sgdr"]
if train_checkpoint:
self.checkpoint = train_checkpoint
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
else:
self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
current_iter = len(self.checkpoint.loss_list)
for group in self.train_config.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
reamining_steps = self.train_config.n_epoch - current_iter
total_steps = len(self.train_config.dataset) // self.train_config.batch_size
remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
current_epochs = total_reamining_steps // total_steps_per_epoch
current_steps = total_reamining_steps % total_steps_per_epoch
# train
self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
self.checkpoint.model.train()
try:
for epoch in range(remaining_epochs):
self.checkpoint.model.train()
for epoch in range(current_epochs):
# epoch
self._call_callbacks('on_epoch_begin', epoch=epoch)
dataloader = self._create_dataloader()
dataloader = self._create_dataloader(start_index=current_steps)
for batch in dataloader:
# batch
self._call_callbacks('on_batch_begin', batch=batch)
@ -110,5 +111,4 @@ class Trainer:
finally:
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
return self.checkpoint
return self.checkpoint

View File

@ -204,9 +204,9 @@ def test_checkpoint_train(test_env):
dataset=dataset,
optimizer=optimizer,
checkpoint_dir=test_env["test_dir"],
n_epoch=2,
n_epoch=1,
batch_size=2,
checkpoint_interval=5,
checkpoint_interval=1,
accumulation_steps=1,
max_grad_norm=1.0,
random_seed=42
@ -218,13 +218,18 @@ def test_checkpoint_train(test_env):
pad_token_id=test_env["tokenizer"].pad_id
)
schedule_config = CosineScheduleConfig(
warmup_steps=100,
total_steps=1000
warmup_steps=1,
total_steps=5
)
trainer = Trainer(param, train_config, schedule_config)
checkpoint = None
try:
trainer.train()
checkpoint = trainer.train()
except Exception:
checkpoint = trainer.checkpoint
trainer.train(train_checkpoint=checkpoint)
pass
checkpoint = trainer.train(train_checkpoint=checkpoint)
assert len(checkpoint.loss_list) == 5 - 1