diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index aba4a54..a941583 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -26,17 +26,22 @@ class TrainContext: iteration: int = field(default=0) loss: float = field(default=0.0) - wolrd_size: int = field(default=1) + world_size: int = field(default=1) rank: int = field(default=0) class TrainContextBuilder: def __init__(self, config: TrainConfig): self.config = config - self._context: TrainContext = None + self._context = TrainContext( + model=config.model, + optimizer=config.optimizer, + scheduler=config.scheduler, + world_size=get_world_size(), + rank=get_rank(), + ) def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: - self._context = TrainContext() if checkpoint is None: checkpoint = Checkpoint( optimizer_state=self.config.optimizer.state_dict(), @@ -46,37 +51,12 @@ class TrainContextBuilder: # resume from the assigned checkpoint or assigned iteration self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) self._context.iteration = max(checkpoint.iteration, self.config.start_batch) + self._context.optimizer.load_state_dict(checkpoint.optimizer_state) + self._context.scheduler.load_state_dict(checkpoint.scheduler_state) self._context.checkpoint = checkpoint return self - def with_optimizer(self) -> Self: - optimizer = self.config.optimizer - if self._context is None: - raise RuntimeError("Must call with_checkpoint() before with_optimizer()") - - if self._context.checkpoint and self._context.checkpoint.optimizer_state: - optimizer.load_state_dict(self._context.checkpoint.optimizer_state) - - self._context.optimizer = optimizer - - if self._context.checkpoint: - self._context.checkpoint.optimizer_state = optimizer.state_dict() - - return self - - def with_scheduler(self) -> Self: - scheduler = self.config.scheduler - if self._context.checkpoint and self._context.checkpoint.scheduler_state: - scheduler.load_state_dict(self._context.checkpoint.scheduler_state) - - self._context.scheduler = scheduler - - if self._context.checkpoint: - self._context.checkpoint.scheduler_state = scheduler.state_dict() - - return self - def with_dataloader(self) -> Self: # fix: change batch level iteration to sample level offset config = self.config @@ -110,10 +90,4 @@ class TrainContextBuilder: return self def build(self) -> TrainContext: - self._context.model = self.config.model - - if self.config.nprocs > 1: - self._context.wolrd_size = get_world_size() - self._context.rank = get_rank() - return self._context \ No newline at end of file diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index c6a61c9..8456ad8 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -34,8 +34,6 @@ class Trainer: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: return (TrainContextBuilder(self.train_config) .with_checkpoint(checkpoint) - .with_optimizer() - .with_scheduler() .with_dataloader() .with_strategy() .build())