feat(khaosz/trainer): 引入 TrainContext 和 TrainContextBuilder 优化训练上下文管理

This commit is contained in:
ViperEkura 2025-10-03 22:42:11 +08:00
parent 6e1a497c04
commit 240ee00221
2 changed files with 131 additions and 85 deletions

View File

@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Optional, Self, TYPE_CHECKING
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler
if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
@dataclass
class TrainContext:
dataloader: DataLoader
optimizer: Optimizer
sampler: RandomSampler
epoch: int
current_iter: int
loss: float
checkpoint: Checkpoint
class TrainContextBuilder:
def __init__(self, trainer: 'Trainer'):
self.trainer = trainer
self._context = TrainContext(
dataloader=None,
optimizer=None,
sampler=None,
epoch=0,
current_iter=0,
loss=0.0,
checkpoint=None
)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
model=self.trainer.parameter.model,
tokenizer=self.trainer.parameter.tokenizer,
config=self.trainer.parameter.config,
sampler_state=None,
optim_state=None,
loss_list=[]
)
self._context.checkpoint = checkpoint
return self
def with_sampler(self) -> Self:
seed = self.trainer.train_config.random_seed
sampler = RandomSampler(
data_source=self.trainer.train_config.dataset,
seed=seed
)
if self._context.checkpoint and self._context.checkpoint.sampler_state:
sampler.load_state_dict(self._context.checkpoint.sampler_state)
self._context.sampler = sampler
self._context.epoch = sampler.epoch
self._context.current_iter = sampler.current_iter
if self._context.checkpoint:
self._context.checkpoint.sampler_state = sampler.state_dict()
return self
def with_optimizer(self) -> Self:
optimizer = self.trainer.train_config.optimizer
if self._context.checkpoint and self._context.checkpoint.optim_state:
optimizer.load_state_dict(self._context.checkpoint.optim_state)
self._context.optimizer = optimizer
if self._context.checkpoint:
self._context.checkpoint.optim_state = optimizer.state_dict()
return self
def with_dataloader(self) -> Self:
dataloader = DataLoader(
self.trainer.train_config.dataset,
batch_size=self.trainer.train_config.batch_size,
sampler=self._context.sampler
)
self._context.dataloader = dataloader
return self
def build(self) -> TrainContext:
return self._context

View File

@ -1,9 +1,7 @@
import logging import logging
from typing import Optional, List, cast from typing import Optional, List
from torch.utils.data import DataLoader
from khaosz.core import ModelParameter, Checkpoint from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
from khaosz.trainer.trainer_callback import ( from khaosz.trainer.trainer_callback import (
TrainerCallback, TrainerCallback,
@ -12,6 +10,7 @@ from khaosz.trainer.trainer_callback import (
GradientClippingCallback, GradientClippingCallback,
SchedulerCallback SchedulerCallback
) )
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +26,7 @@ class Trainer:
self.train_config = train_config self.train_config = train_config
self.schedule_config = schedule_config self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks() self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainerCallback]: def _get_default_callbacks(self) -> List[TrainerCallback]:
return [ return [
ProgressBarCallback(), ProgressBarCallback(),
@ -35,106 +34,62 @@ class Trainer:
GradientClippingCallback(), GradientClippingCallback(),
SchedulerCallback(self.schedule_config), SchedulerCallback(self.schedule_config),
] ]
def _set_train_kwargs(self, kwargs: dict):
seed = self.train_config.random_seed
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
optim = self.train_config.optimizer
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
if checkpoint is None: def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
checkpoint = Checkpoint( return (TrainContextBuilder(self)
model=self.parameter.model, .with_checkpoint(checkpoint)
tokenizer=self.parameter.tokenizer, .with_sampler()
config=self.parameter.config, .with_optimizer()
sampler_state=None, .with_dataloader()
optim_state=None, .build())
loss_list=[]
) def _call_callbacks(self, method_name: str, context: TrainContext):
kwargs = {
'dataloader': context.dataloader,
'optimizer': context.optimizer,
'sampler': context.sampler,
'epoch': context.epoch,
'current_iter': context.current_iter,
'loss': context.loss,
'checkpoint': context.checkpoint
}
sampler_state = checkpoint.sampler_state
optim_state = checkpoint.optim_state
if sampler_state:
sampler.load_state_dict(sampler_state)
if optim_state:
optim.load_state_dict(optim_state)
checkpoint.optim_state = optim.state_dict()
checkpoint.sampler_state = sampler.state_dict()
dataloader = DataLoader(
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
kwargs["dataloader"] = dataloader
kwargs["optimizer"] = self.train_config.optimizer
kwargs["epoch"] = sampler.epoch
kwargs["current_iter"] = sampler.current_iter
kwargs["sampler"] = sampler
kwargs["checkpoint"] = checkpoint
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks: for callback in self.callbacks:
method = getattr(callback, method_name, None) method = getattr(callback, method_name, None)
if method: if method:
method(self, **kwargs) method(self, **kwargs)
def train( def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
self, context = self._build_train_context(checkpoint)
checkpoint: Optional[Checkpoint] = None
) -> Checkpoint:
# train self._call_callbacks('on_train_begin', context)
train_kwargs = {
'checkpoint': checkpoint,
'dataloader': None,
'optimizer': None,
'sampler': None,
'epoch': 0,
'current_iter': 0,
'loss': 0.0,
}
self._set_train_kwargs(train_kwargs)
self._call_callbacks('on_train_begin', **train_kwargs)
dataloader = train_kwargs['dataloader']
checkpoint = train_kwargs['checkpoint']
start_epoch = train_kwargs['epoch']
try: try:
self.parameter.model.train() self.parameter.model.train()
for epoch in range(start_epoch, self.train_config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
# epoch context.epoch = epoch
train_kwargs["epoch"] = epoch self._call_callbacks('on_epoch_begin', context)
self._call_callbacks('on_epoch_begin', **train_kwargs)
for batch in dataloader: for batch in context.dataloader:
if context.current_iter % self.train_config.accumulation_steps == 0:
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0: self._call_callbacks('on_step_begin', context)
# step
self._call_callbacks('on_step_begin', **train_kwargs)
self.train_config.optimizer.step() self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad() self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', **train_kwargs) self._call_callbacks('on_step_end', context)
# batch self._call_callbacks('on_batch_begin', context)
self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch) loss = self.train_config.strategy(batch)
train_kwargs["loss"] = loss.item() context.loss = loss.item()
train_kwargs["current_iter"] += 1 context.current_iter += 1
loss.backward() loss.backward()
self._call_callbacks('on_batch_end', **train_kwargs) self._call_callbacks('on_batch_end', context)
self._call_callbacks('on_epoch_end', **train_kwargs) self._call_callbacks('on_epoch_end', context)
except Exception as e: except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True) logger.error(f"Training failed: {str(e)}", exc_info=True)
raise raise
finally: finally:
self._call_callbacks('on_train_end', **train_kwargs) self._call_callbacks('on_train_end', context)
return checkpoint return context.checkpoint