"""Training context shared with trainer callbacks, and the builder that assembles it."""
from dataclasses import dataclass
from typing import Optional, Self, TYPE_CHECKING

from torch.optim import Optimizer
from torch.utils.data import DataLoader

from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler

if TYPE_CHECKING:
    # Needed for annotations only.
    from khaosz.trainer.trainer import Trainer


@dataclass
class TrainContext:
    """State of one training run, as produced by ``TrainContextBuilder.build``."""

    dataloader: DataLoader   # iterates the training dataset through ``sampler``
    optimizer: Optimizer     # the optimizer taken from the trainer's train_config
    sampler: RandomSampler   # seeded sampler, restored from the checkpoint when it had state
    epoch: int               # ``sampler.epoch`` as read right after the sampler was set up
    current_iter: int        # ``sampler.current_iter`` as read right after the sampler was set up
    loss: float              # starts at 0.0
    checkpoint: Checkpoint   # the supplied checkpoint, or a fresh one built from trainer.parameter


class TrainContextBuilder:
    """Assemble a ``TrainContext`` for a trainer, one component at a time.

    Every ``with_*`` method returns the builder so calls can be chained.
    Order matters: call ``with_checkpoint`` first so that ``with_sampler`` and
    ``with_optimizer`` can restore saved state, and call ``with_dataloader``
    after ``with_sampler`` because the data loader wraps the sampler.
    """

    def __init__(self, trainer: 'Trainer'):
        """Remember the trainer whose ``parameter`` and ``train_config`` are read."""
        self.trainer = trainer
        # Components stay None until their with_* step has run; build() checks them.
        self._checkpoint: Optional[Checkpoint] = None
        self._sampler: Optional[RandomSampler] = None
        self._optimizer: Optional[Optimizer] = None
        self._dataloader: Optional[DataLoader] = None
        self._epoch = 0
        self._current_iter = 0

    def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
        """Use ``checkpoint``, or create an empty one from the trainer's parameter.

        A freshly created checkpoint carries the trainer's model, tokenizer and
        config, no sampler/optimizer state and an empty loss list.
        """
        if checkpoint is None:
            parameter = self.trainer.parameter
            checkpoint = Checkpoint(
                model=parameter.model,
                tokenizer=parameter.tokenizer,
                config=parameter.config,
                sampler_state=None,
                optim_state=None,
                loss_list=[],
            )
        self._checkpoint = checkpoint
        return self

    def with_sampler(self) -> Self:
        """Create the seeded sampler and restore its state from the checkpoint.

        The sampler's ``epoch`` and ``current_iter`` become the context's
        starting position, and the sampler's state is written back to the
        checkpoint (when one is set).
        """
        config = self.trainer.train_config
        sampler = RandomSampler(data_source=config.dataset, seed=config.random_seed)

        checkpoint = self._checkpoint
        # Truthiness on the *state* is deliberate: None/empty means nothing to restore.
        if checkpoint is not None and checkpoint.sampler_state:
            sampler.load_state_dict(checkpoint.sampler_state)

        self._sampler = sampler
        self._epoch = sampler.epoch
        self._current_iter = sampler.current_iter

        if checkpoint is not None:
            checkpoint.sampler_state = sampler.state_dict()
        return self

    def with_optimizer(self) -> Self:
        """Take the optimizer from the train config and restore its state.

        The optimizer's state is written back to the checkpoint (when one is set).
        """
        optimizer = self.trainer.train_config.optimizer

        checkpoint = self._checkpoint
        if checkpoint is not None and checkpoint.optim_state:
            optimizer.load_state_dict(checkpoint.optim_state)

        self._optimizer = optimizer

        if checkpoint is not None:
            checkpoint.optim_state = optimizer.state_dict()
        return self

    def with_dataloader(self) -> Self:
        """Wrap the training dataset in a ``DataLoader`` driven by the sampler.

        Raises:
            RuntimeError: if ``with_sampler`` has not been called yet. Passing
                ``sampler=None`` to ``DataLoader`` would silently fall back to
                its default ordering and lose the restored position.
        """
        if self._sampler is None:
            raise RuntimeError("with_sampler() must be called before with_dataloader()")

        config = self.trainer.train_config
        self._dataloader = DataLoader(
            config.dataset,
            batch_size=config.batch_size,
            sampler=self._sampler,
        )
        return self

    def build(self) -> TrainContext:
        """Return the assembled context.

        Raises:
            RuntimeError: if any ``with_*`` step was skipped, instead of
                handing out a context whose ``None`` fields fail later on.
        """
        components = {
            "with_checkpoint": self._checkpoint,
            "with_sampler": self._sampler,
            "with_optimizer": self._optimizer,
            "with_dataloader": self._dataloader,
        }
        # `is None` rather than truthiness: DataLoader defines __len__.
        skipped = [step for step, value in components.items() if value is None]
        if skipped:
            raise RuntimeError(
                f"TrainContext is incomplete; call {', '.join(skipped)} before build()"
            )

        return TrainContext(
            dataloader=self._dataloader,
            optimizer=self._optimizer,
            sampler=self._sampler,
            epoch=self._epoch,
            current_iter=self._current_iter,
            loss=0.0,
            checkpoint=self._checkpoint,
        )
class Trainer:
    # NOTE(review): khaosz/trainer/trainer.py -- only the methods that are fully
    # visible in this patch are reproduced here. ``__init__`` and
    # ``_get_default_callbacks`` sit in partially shown hunks and are
    # intentionally left out; they are unchanged apart from whitespace.
    # Project types are quoted in annotations so they resolve lazily.

    def _build_train_context(self, checkpoint: "Optional[Checkpoint]") -> "TrainContext":
        """Assemble the context for a run: checkpoint, sampler, optimizer, dataloader.

        The order is significant: the checkpoint must be set before the sampler
        and optimizer so their saved state can be restored, and the dataloader
        is created last because it wraps the sampler.
        """
        return (
            TrainContextBuilder(self)
            .with_checkpoint(checkpoint)
            .with_sampler()
            .with_optimizer()
            .with_dataloader()
            .build()
        )

    def _call_callbacks(self, method_name: str, context: "TrainContext"):
        """Invoke ``method_name`` on every callback that defines it.

        Each hook is called as ``hook(trainer, **kwargs)`` where ``kwargs`` is a
        snapshot of the context's fields taken at call time. Callbacks without
        the hook are skipped.
        """
        kwargs = {
            'dataloader': context.dataloader,
            'optimizer': context.optimizer,
            'sampler': context.sampler,
            'epoch': context.epoch,
            'current_iter': context.current_iter,
            'loss': context.loss,
            'checkpoint': context.checkpoint,
        }

        for callback in self.callbacks:
            hook = getattr(callback, method_name, None)
            if hook:
                hook(self, **kwargs)

    def train(self, checkpoint: "Optional[Checkpoint]" = None) -> "Checkpoint":
        """Run the training loop and return the run's checkpoint.

        Args:
            checkpoint: checkpoint to resume from; ``None`` starts a fresh one.

        Returns:
            The checkpoint held by the training context.

        Raises:
            Exception: anything raised inside the loop is logged and re-raised;
                ``on_train_end`` callbacks run in either case.
        """
        context = self._build_train_context(checkpoint)
        self._call_callbacks('on_train_begin', context)

        try:
            self.parameter.model.train()

            for epoch in range(context.epoch, self.train_config.n_epoch):
                context.epoch = epoch
                self._call_callbacks('on_epoch_begin', context)

                for batch in context.dataloader:
                    # An optimizer step is taken at the start of every
                    # accumulation window, applying gradients accumulated by
                    # the preceding batches.
                    # NOTE(review): this also fires when current_iter is 0,
                    # before any backward pass of this call -- confirm intended.
                    if context.current_iter % self.train_config.accumulation_steps == 0:
                        self._call_callbacks('on_step_begin', context)
                        # context.optimizer is the object the builder took from
                        # train_config.optimizer; use the context consistently.
                        context.optimizer.step()
                        context.optimizer.zero_grad()
                        self._call_callbacks('on_step_end', context)

                    self._call_callbacks('on_batch_begin', context)
                    loss = self.train_config.strategy(batch)
                    context.loss = loss.item()
                    context.current_iter += 1
                    loss.backward()
                    self._call_callbacks('on_batch_end', context)

                self._call_callbacks('on_epoch_end', context)

        except Exception as e:
            # logger.exception logs at ERROR level with the traceback attached.
            logger.exception("Training failed: %s", e)
            raise
        finally:
            self._call_callbacks('on_train_end', context)

        # Returned after the try statement on purpose: a `return` inside
        # `finally` would discard the exception re-raised above.
        return context.checkpoint