feat(trainer): 引入训练器回调机制并重构训练流程

This commit is contained in:
ViperEkura 2025-09-29 11:31:31 +08:00
parent 92999fa9f6
commit 816bc78894
1 changed files with 161 additions and 58 deletions

View File

@ -1,7 +1,8 @@
import os import os
import torch import torch
from abc import abstractmethod
from typing import Optional from typing import Optional, List, override
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
@ -11,12 +12,106 @@ from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig from khaosz.trainer.strategy import SchedulerFactory, TrainConfig, ScheduleConfig
class TrainerCallback:
@abstractmethod
def on_train_begin(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_train_end(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_batch_begin(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_batch_end(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_step_begin(self, trainer: 'Trainer', **kwargs):
pass
@abstractmethod
def on_step_end(self, trainer: 'Trainer', **kwargs):
pass
class ProgressBarCallback(TrainerCallback):
def __init__(self):
self.progress_bar: tqdm = None
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
epoch = kwargs.get('epoch')
dataloader = trainer._create_dataloader()
self.progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
dynamic_ncols=True
)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
loss = kwargs.get('loss')
self.progress_bar.set_postfix({
"loss": f"{loss:.4f}",
"lr": f"{trainer.train_config.optimizer.param_groups[0]['lr']:.2e}"
})
self.progress_bar.update(1)
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
if self.progress_bar:
self.progress_bar.close()
class CheckpointCallback(TrainerCallback):
def __init__(self, checkpoint_interval: int):
self.checkpoint_interval = checkpoint_interval
self.last_ckpt_iter = 0
def on_train_begin(self, trainer: 'Trainer', **kwargs):
checkpoint = kwargs.get('checkpoint')
self.last_ckpt_iter = len(checkpoint.loss_list)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
trainer._save_checkpoint()
self.last_ckpt_iter = current_iter
def on_train_end(self, trainer: 'Trainer', **kwargs):
checkpoint = kwargs.get('checkpoint')
current_iter = len(checkpoint.loss_list)
if current_iter != self.last_ckpt_iter:
trainer._save_checkpoint()
class GradientClippingCallback(TrainerCallback):
def on_step_begin(self, trainer: 'Trainer', **kwargs):
clip_grad_norm_(
trainer.checkpoint.model.parameters(),
trainer.train_config.max_grad_norm
)
class Trainer: class Trainer:
def __init__( def __init__(
self, self,
parameter: ModelParameter, parameter: ModelParameter,
train_config: TrainConfig, train_config: TrainConfig,
schedule_config: ScheduleConfig schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainerCallback]] = None
): ):
self.checkpoint = Checkpoint( self.checkpoint = Checkpoint(
model=parameter.model, model=parameter.model,
@ -26,16 +121,37 @@ class Trainer:
self.train_config = train_config self.train_config = train_config
self.schedule_config = schedule_config self.schedule_config = schedule_config
def save_checkpoint( self.callbacks = callbacks or self._get_default_callbacks()
self,
loss_list: list, def _get_default_callbacks(self) -> List[TrainerCallback]:
): return [
current_iter = len(loss_list) ProgressBarCallback(),
CheckpointCallback(self.train_config.checkpoint_interval),
GradientClippingCallback(),
]
def _create_dataloader(self) -> DataLoader:
seed = self.train_config.random_seed
generator = torch.Generator().manual_seed(seed)
sampler = RandomSampler(self.train_config.dataset, generator=generator)
return DataLoader(
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
def _save_checkpoint(self):
current_iter = len(self.checkpoint.loss_list)
save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}") save_path = os.path.join(self.train_config.checkpoint_dir, f"iter_{current_iter}")
self.checkpoint.loss_list = loss_list
self.checkpoint.optim_state = self.train_config.optimizer.state_dict() self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
self.checkpoint.save(save_path) self.checkpoint.save(save_path)
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
if method:
method(self, **kwargs)
def train( def train(
self, self,
train_checkpoint: Optional[Checkpoint] = None train_checkpoint: Optional[Checkpoint] = None
@ -47,16 +163,13 @@ class Trainer:
self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state) self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
self.checkpoint.optim_state = self.train_config.optimizer.state_dict() self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
loss_list = self.checkpoint.loss_list
current_iter = len(self.checkpoint.loss_list) current_iter = len(self.checkpoint.loss_list)
last_ckpt_iter = current_iter
for group in self.train_config.optimizer.param_groups: for group in self.train_config.optimizer.param_groups:
if "initial_lr" not in group: if "initial_lr" not in group:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
**self.schedule_config.get_kwargs() **self.schedule_config.get_kwargs()
) )
@ -66,52 +179,42 @@ class Trainer:
last_epoch=current_iter - 1 if train_checkpoint else -1 last_epoch=current_iter - 1 if train_checkpoint else -1
) )
seed = self.train_config.random_seed reamining_steps = self.train_config.n_epoch - current_iter
generator = torch.Generator().manual_seed(seed) total_steps = len(self.train_config.dataset) // self.train_config.batch_size
sampler = RandomSampler(self.train_config.dataset, generator=generator) remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
remaining_epochs = self.train_config.n_epoch - current_iter // ( # train
len(self.train_config.dataset) // self.train_config.batch_size) self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
for epoch in range(remaining_epochs): try:
self.checkpoint.model.train() for epoch in range(remaining_epochs):
dataloader = DataLoader( self.checkpoint.model.train()
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch+1}/{self.train_config.n_epoch}",
dynamic_ncols=True
)
for batch in progress_bar:
#forward
loss = self.train_config.strategy(batch)
loss_list.append(loss.item())
#backward
loss.backward()
#step
if current_iter % self.train_config.accumulation_steps == 0:
clip_grad_norm_(
self.checkpoint.model.parameters(),
self.train_config.max_grad_norm
)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
current_iter += 1 # epoch
scheduler.step() self._call_callbacks('on_epoch_begin', epoch=epoch)
progress_bar.set_postfix({
"loss": f"{loss.item():.4f}",
"lr": f"{self.train_config.optimizer.param_groups[0]['lr']:.2e}"
})
#save checkpotint
if current_iter - last_ckpt_iter >= self.train_config.checkpoint_interval:
self.save_checkpoint(loss_list)
last_ckpt_iter = current_iter
if current_iter != last_ckpt_iter: dataloader = self._create_dataloader()
self.save_checkpoint(loss_list)
last_ckpt_iter = current_iter for batch in dataloader:
# batch
self._call_callbacks('on_batch_begin', batch=batch)
loss = self.train_config.strategy(batch)
self.checkpoint.loss_list.append(loss.item())
loss.backward()
self._call_callbacks('on_batch_end', batch=batch, loss=loss.item(), current_iter=current_iter)
if current_iter % self.train_config.accumulation_steps == 0:
# step
self._call_callbacks('on_step_begin', current_iter=current_iter)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', current_iter=current_iter)
current_iter += 1
scheduler.step()
self._call_callbacks('on_epoch_end', epoch=epoch, loss_list=self.checkpoint.loss_list)
finally:
self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
return self.checkpoint return self.checkpoint