fix(khaosz/trainer): correct the training step count calculation logic

ViperEkura 2025-09-29 19:05:26 +08:00
parent c104a400e7
commit e467420475
3 changed files with 38 additions and 28 deletions

View File

@@ -154,6 +154,11 @@ class SchedulerCallback(TrainerCallback):
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         self.current_iter = len(checkpoint.loss_list)
+
+        for group in trainer.train_config.optimizer.param_groups:
+            if "initial_lr" not in group:
+                group["initial_lr"] = group["lr"]
+
         self.schedule_config.validate()
         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             self.schedule_config
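
The `initial_lr` seeding matters when the scheduler is rebuilt mid-run. A minimal sketch of the failure mode, assuming the factory ultimately builds a torch scheduler such as `torch.optim.lr_scheduler.LambdaLR` (the real `SchedulerFactory` and `Checkpoint` types are not shown here, and the model/optimizer below are placeholders):

import torch

# Stand-ins for the project's model and optimizer; only param_groups matter.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
resumed_iter = 10  # e.g. len(checkpoint.loss_list) after loading a checkpoint

# Without this loop, constructing any torch scheduler with last_epoch >= 0
# raises KeyError: "param 'initial_lr' is not specified in param_groups[0]
# when resuming an optimizer" -- a freshly created optimizer lacks the key.
for group in optimizer.param_groups:
    if "initial_lr" not in group:
        group["initial_lr"] = group["lr"]

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, step / 100),  # placeholder warmup lambda
    last_epoch=resumed_iter - 1,                  # resume mid-schedule
)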

View File

@@ -1,4 +1,5 @@
 import torch
+import itertools
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler

@@ -38,16 +39,21 @@ class Trainer:
             SchedulerCallback(self.schedule_config),
         ]
 
-    def _create_dataloader(self) -> DataLoader:
+    def _create_dataloader(self, start_index: int = 0) -> DataLoader:
         seed = self.train_config.random_seed
         generator = torch.Generator().manual_seed(seed)
         sampler = RandomSampler(self.train_config.dataset, generator=generator)
-        return DataLoader(
+        dataloader = DataLoader(
             self.train_config.dataset,
             batch_size=self.train_config.batch_size,
             sampler=sampler
         )
+
+        if start_index > 0:
+            dataloader = itertools.islice(dataloader, start_index, None)
+
+        return dataloader
 
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
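
A small self-contained sketch of the skip-ahead path, with a toy dataset in place of `train_config.dataset`, shows the behaviour and its trade-offs: the skipped batches are still produced and discarded by the `DataLoader`, and the returned object is a plain iterator, so `len()` no longer works on it. Because the sampler is reseeded with the same `random_seed` on every call, the skipped prefix reproduces the batch order of the interrupted epoch:

import itertools

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))      # 10 samples -> 5 batches of 2
generator = torch.Generator().manual_seed(42)  # fixed seed = reproducible order
sampler = RandomSampler(dataset, generator=generator)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Skip the 3 batches already consumed before the interruption;
# they are still loaded under the hood, just not yielded.
remaining = itertools.islice(loader, 3, None)
print(sum(1 for _ in remaining))  # 2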
@@ -58,34 +64,29 @@ class Trainer:
         self,
         train_checkpoint: Optional[Checkpoint] = None
     ) -> Checkpoint:
-        assert self.schedule_config.schedule_type in ["cosine", "sgdr"]
-
         if train_checkpoint:
             self.checkpoint = train_checkpoint
             self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
-
-        self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
+        else:
+            self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
 
         current_iter = len(self.checkpoint.loss_list)
-
-        for group in self.train_config.optimizer.param_groups:
-            if "initial_lr" not in group:
-                group["initial_lr"] = group["lr"]
-
-        reamining_steps = self.train_config.n_epoch - current_iter
-        total_steps = len(self.train_config.dataset) // self.train_config.batch_size
-        remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
+        total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
+        total_reamining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
+        current_epochs = total_reamining_steps // total_steps_per_epoch
+        current_steps = total_reamining_steps % total_steps_per_epoch
 
         # train
         self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
-        self.checkpoint.model.train()
         try:
-            for epoch in range(remaining_epochs):
+            for epoch in range(current_epochs):
+                self.checkpoint.model.train()
                 # epoch
                 self._call_callbacks('on_epoch_begin', epoch=epoch)
-
-                dataloader = self._create_dataloader()
+                dataloader = self._create_dataloader(start_index=current_steps)
 
                 for batch in dataloader:
                     # batch
                     self._call_callbacks('on_batch_begin', batch=batch)
@@ -110,5 +111,4 @@ class Trainer:
         finally:
             self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
-
         return self.checkpoint
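
To see what the new bookkeeping computes, a worked example with made-up numbers; the names are cleaned-up counterparts of the diff's variables (`total_remaining_steps` for its `total_reamining_steps`):

# Made-up resume scenario: 10 samples, batch_size 2 (5 steps per epoch),
# n_epoch 2, and a checkpoint whose loss_list already holds 3 entries.
dataset_len, batch_size, n_epoch = 10, 2, 2
current_iter = 3  # len(checkpoint.loss_list)

total_steps_per_epoch = dataset_len // batch_size                       # 5
total_remaining_steps = total_steps_per_epoch * n_epoch - current_iter  # 7
current_epochs = total_remaining_steps // total_steps_per_epoch         # 1 epoch left
current_steps = total_remaining_steps % total_steps_per_epoch           # offset of 2

# current_steps is the value train() passes as _create_dataloader(start_index=...).
assert (current_epochs, current_steps) == (1, 2)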

View File

@@ -204,9 +204,9 @@ def test_checkpoint_train(test_env):
         dataset=dataset,
         optimizer=optimizer,
         checkpoint_dir=test_env["test_dir"],
-        n_epoch=2,
+        n_epoch=1,
         batch_size=2,
-        checkpoint_interval=5,
+        checkpoint_interval=1,
         accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
@ -218,13 +218,18 @@ def test_checkpoint_train(test_env):
pad_token_id=test_env["tokenizer"].pad_id pad_token_id=test_env["tokenizer"].pad_id
) )
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
warmup_steps=100, warmup_steps=1,
total_steps=1000 total_steps=5
) )
trainer = Trainer(param, train_config, schedule_config) trainer = Trainer(param, train_config, schedule_config)
checkpoint = None
try: try:
trainer.train() checkpoint = trainer.train()
except Exception: except Exception:
checkpoint = trainer.checkpoint pass
trainer.train(train_checkpoint=checkpoint)
checkpoint = trainer.train(train_checkpoint=checkpoint)
assert len(checkpoint.loss_list) == 5 - 1