feat(khaosz/trainer): introduce TrainContext and TrainContextBuilder to streamline training-context management

ViperEkura 2025-10-03 22:42:11 +08:00
parent 6e1a497c04
commit 240ee00221
2 changed files with 131 additions and 85 deletions

khaosz/trainer/train_context.py (new file)

@@ -0,0 +1,91 @@
+from dataclasses import dataclass
+from typing import Optional, Self, TYPE_CHECKING
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from khaosz.core.parameter import Checkpoint
+from khaosz.trainer.data_util import RandomSampler
+
+if TYPE_CHECKING:
+    from khaosz.trainer.trainer import Trainer
+
+
+@dataclass
+class TrainContext:
+    dataloader: DataLoader
+    optimizer: Optimizer
+    sampler: RandomSampler
+    epoch: int
+    current_iter: int
+    loss: float
+    checkpoint: Checkpoint
+
+
+class TrainContextBuilder:
+    def __init__(self, trainer: 'Trainer'):
+        self.trainer = trainer
+        self._context = TrainContext(
+            dataloader=None,
+            optimizer=None,
+            sampler=None,
+            epoch=0,
+            current_iter=0,
+            loss=0.0,
+            checkpoint=None
+        )
+
+    def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
+        if checkpoint is None:
+            checkpoint = Checkpoint(
+                model=self.trainer.parameter.model,
+                tokenizer=self.trainer.parameter.tokenizer,
+                config=self.trainer.parameter.config,
+                sampler_state=None,
+                optim_state=None,
+                loss_list=[]
+            )
+        self._context.checkpoint = checkpoint
+        return self
+
+    def with_sampler(self) -> Self:
+        seed = self.trainer.train_config.random_seed
+        sampler = RandomSampler(
+            data_source=self.trainer.train_config.dataset,
+            seed=seed
+        )
+        if self._context.checkpoint and self._context.checkpoint.sampler_state:
+            sampler.load_state_dict(self._context.checkpoint.sampler_state)
+        self._context.sampler = sampler
+        self._context.epoch = sampler.epoch
+        self._context.current_iter = sampler.current_iter
+        if self._context.checkpoint:
+            self._context.checkpoint.sampler_state = sampler.state_dict()
+        return self
+
+    def with_optimizer(self) -> Self:
+        optimizer = self.trainer.train_config.optimizer
+        if self._context.checkpoint and self._context.checkpoint.optim_state:
+            optimizer.load_state_dict(self._context.checkpoint.optim_state)
+        self._context.optimizer = optimizer
+        if self._context.checkpoint:
+            self._context.checkpoint.optim_state = optimizer.state_dict()
+        return self
+
+    def with_dataloader(self) -> Self:
+        dataloader = DataLoader(
+            self.trainer.train_config.dataset,
+            batch_size=self.trainer.train_config.batch_size,
+            sampler=self._context.sampler
+        )
+        self._context.dataloader = dataloader
+        return self
+
+    def build(self) -> TrainContext:
+        return self._context
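The build order is significant: with_checkpoint must run first so that with_sampler and with_optimizer can restore saved state from it, and with_dataloader must follow with_sampler because the DataLoader is constructed around the restored sampler. A minimal usage sketch, assuming an already-constructed Trainer instance named trainer (the chain mirrors _build_train_context in trainer.py below):

    # Hedged sketch: `trainer` is an assumed, fully configured Trainer.
    # Passing None makes with_checkpoint create a fresh Checkpoint.
    context = (TrainContextBuilder(trainer)
        .with_checkpoint(None)   # or a Checkpoint restored from disk
        .with_sampler()          # loads sampler_state from the checkpoint if present
        .with_optimizer()        # loads optim_state from the checkpoint if present
        .with_dataloader()       # wires the sampler into the DataLoader
        .build())

    assert context.epoch == 0 and context.current_iter == 0  # fresh run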

khaosz/trainer/trainer.py

@@ -1,9 +1,7 @@
 import logging
-from typing import Optional, List, cast
-from torch.utils.data import DataLoader
+from typing import Optional, List
 from khaosz.core import ModelParameter, Checkpoint
-from khaosz.trainer.data_util import RandomSampler
 from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
 from khaosz.trainer.trainer_callback import (
     TrainerCallback,
@@ -12,6 +10,7 @@ from khaosz.trainer.trainer_callback import (
     GradientClippingCallback,
     SchedulerCallback
 )
+from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
 
 logger = logging.getLogger(__name__)
@@ -27,7 +26,7 @@ class Trainer:
         self.train_config = train_config
         self.schedule_config = schedule_config
         self.callbacks = callbacks or self._get_default_callbacks()
 
     def _get_default_callbacks(self) -> List[TrainerCallback]:
         return [
             ProgressBarCallback(),
@@ -35,106 +34,62 @@ class Trainer:
             GradientClippingCallback(),
             SchedulerCallback(self.schedule_config),
         ]
 
-    def _set_train_kwargs(self, kwargs: dict):
-        seed = self.train_config.random_seed
-        sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
-        optim = self.train_config.optimizer
-        checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
-
-        if checkpoint is None:
-            checkpoint = Checkpoint(
-                model=self.parameter.model,
-                tokenizer=self.parameter.tokenizer,
-                config=self.parameter.config,
-                sampler_state=None,
-                optim_state=None,
-                loss_list=[]
-            )
+    def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
+        return (TrainContextBuilder(self)
+            .with_checkpoint(checkpoint)
+            .with_sampler()
+            .with_optimizer()
+            .with_dataloader()
+            .build())
+
+    def _call_callbacks(self, method_name: str, context: TrainContext):
+        kwargs = {
+            'dataloader': context.dataloader,
+            'optimizer': context.optimizer,
+            'sampler': context.sampler,
+            'epoch': context.epoch,
+            'current_iter': context.current_iter,
+            'loss': context.loss,
+            'checkpoint': context.checkpoint
+        }
-        sampler_state = checkpoint.sampler_state
-        optim_state = checkpoint.optim_state
-
-        if sampler_state:
-            sampler.load_state_dict(sampler_state)
-        if optim_state:
-            optim.load_state_dict(optim_state)
-
-        checkpoint.optim_state = optim.state_dict()
-        checkpoint.sampler_state = sampler.state_dict()
-
-        dataloader = DataLoader(
-            self.train_config.dataset,
-            batch_size=self.train_config.batch_size,
-            sampler=sampler
-        )
-
-        kwargs["dataloader"] = dataloader
-        kwargs["optimizer"] = self.train_config.optimizer
-        kwargs["epoch"] = sampler.epoch
-        kwargs["current_iter"] = sampler.current_iter
-        kwargs["sampler"] = sampler
-        kwargs["checkpoint"] = checkpoint
-
-    def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
             method = getattr(callback, method_name, None)
             if method:
                 method(self, **kwargs)
 
-    def train(
-        self,
-        checkpoint: Optional[Checkpoint] = None
-    ) -> Checkpoint:
+    def train(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
+        context = self._build_train_context(checkpoint)
-        # train
-        train_kwargs = {
-            'checkpoint': checkpoint,
-            'dataloader': None,
-            'optimizer': None,
-            'sampler': None,
-            'epoch': 0,
-            'current_iter': 0,
-            'loss': 0.0,
-        }
-        self._set_train_kwargs(train_kwargs)
-        self._call_callbacks('on_train_begin', **train_kwargs)
-
-        dataloader = train_kwargs['dataloader']
-        checkpoint = train_kwargs['checkpoint']
-        start_epoch = train_kwargs['epoch']
+        self._call_callbacks('on_train_begin', context)
 
         try:
             self.parameter.model.train()
-            for epoch in range(start_epoch, self.train_config.n_epoch):
-                # epoch
-                train_kwargs["epoch"] = epoch
-                self._call_callbacks('on_epoch_begin', **train_kwargs)
-
-                for batch in dataloader:
-                    if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
-                        # step
-                        self._call_callbacks('on_step_begin', **train_kwargs)
+            for epoch in range(context.epoch, self.train_config.n_epoch):
+                context.epoch = epoch
+                self._call_callbacks('on_epoch_begin', context)
+
+                for batch in context.dataloader:
+                    if context.current_iter % self.train_config.accumulation_steps == 0:
+                        self._call_callbacks('on_step_begin', context)
                         self.train_config.optimizer.step()
                         self.train_config.optimizer.zero_grad()
-                        self._call_callbacks('on_step_end', **train_kwargs)
-
-                    # batch
-                    self._call_callbacks('on_batch_begin', **train_kwargs)
+                        self._call_callbacks('on_step_end', context)
+
+                    self._call_callbacks('on_batch_begin', context)
                     loss = self.train_config.strategy(batch)
-                    train_kwargs["loss"] = loss.item()
-                    train_kwargs["current_iter"] += 1
+                    context.loss = loss.item()
+                    context.current_iter += 1
                     loss.backward()
-                    self._call_callbacks('on_batch_end', **train_kwargs)
+                    self._call_callbacks('on_batch_end', context)
 
-                self._call_callbacks('on_epoch_end', **train_kwargs)
+                self._call_callbacks('on_epoch_end', context)
 
         except Exception as e:
             logger.error(f"Training failed: {str(e)}", exc_info=True)
             raise
         finally:
-            self._call_callbacks('on_train_end', **train_kwargs)
-        return checkpoint
+            self._call_callbacks('on_train_end', context)
+        return context.checkpoint
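Note that _call_callbacks unpacks the context into keyword arguments before dispatch, so callback hooks keep their previous (trainer, **kwargs) shape, and a callback is only invoked for the hooks it actually defines (getattr lookup). A hedged sketch of a compatible custom callback; the LossLoggingCallback name is invented here, while the hook names and kwargs keys come straight from the diff:

    # Hypothetical callback; TrainerCallback is the base class imported
    # from khaosz.trainer.trainer_callback in this file.
    class LossLoggingCallback(TrainerCallback):
        def on_batch_end(self, trainer, loss, current_iter, **kwargs):
            # `loss` and `current_iter` arrive unpacked from the TrainContext
            if current_iter % 100 == 0:
                logger.info(f"iter={current_iter} loss={loss:.4f}")

One design consequence of self.callbacks = callbacks or self._get_default_callbacks(): passing any callbacks list to Trainer replaces the defaults entirely, so a custom list should re-include ProgressBarCallback, GradientClippingCallback, and SchedulerCallback if their behavior is still wanted.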