diff --git a/khaosz/config/train_config.py b/khaosz/config/train_config.py
index 6f8162f..28214da 100644
--- a/khaosz/config/train_config.py
+++ b/khaosz/config/train_config.py
@@ -34,6 +34,14 @@ class TrainConfig:
         default=4,
         metadata={"help": "Batch size for training."}
     )
+    start_epoch: int = field(
+        default=0,
+        metadata={"help": "Start epoch for training."}
+    )
+    start_batch: int = field(
+        default=0,
+        metadata={"help": "Start batch iteration for training."}
+    )
     checkpoint_interval: int = field(
         default=5000,
         metadata={"help": "Number of iterations between checkpoints."}
diff --git a/khaosz/trainer/train_callback.py b/khaosz/trainer/train_callback.py
index 5053bc0..1673221 100644
--- a/khaosz/trainer/train_callback.py
+++ b/khaosz/trainer/train_callback.py
@@ -161,7 +161,8 @@ class StepMonitorCallback(TrainCallback):
         Args:
             log_dir: Directory to save log files. If None, logs won't be saved to file.
             log_interval: Log every N steps
-            metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std', grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
+            metrics: List of metrics to log.
+                Supported: ['loss', 'lr', 'grad_norm', 'grad_std', 'grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
             custom_handlers: List of custom log handler functions
             json_log: Whether to save logs in JSON format
         """
diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py
index 4248ce1..f0009b7 100644
--- a/khaosz/trainer/train_context.py
+++ b/khaosz/trainer/train_context.py
@@ -28,9 +28,10 @@ class TrainContext:
 class TrainContextBuilder:
     def __init__(self, trainer: 'Trainer'):
         self.trainer = trainer
-        self._context = TrainContext()
+        self._context: Optional[TrainContext] = None
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
+        self._context = TrainContext()
         if checkpoint is None:
             checkpoint = Checkpoint(
                 model=self.trainer.parameter.model,
@@ -38,13 +39,17 @@ class TrainContextBuilder:
                 config=self.trainer.parameter.config,
             )
         else:
-            self._context.epoch = checkpoint.epoch
-            self._context.batch_iter = checkpoint.batch_iter
+            # resume from the assigned checkpoint or assigned iteration
+            self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
+            self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
 
         self._context.checkpoint = checkpoint
         return self
 
     def with_optimizer(self) -> Self:
+        if self._context is None:
+            raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
+
         optimizer = self.trainer.train_config.optimizer
 
         if self._context.checkpoint and self._context.checkpoint.optimizer_state:
@@ -58,7 +63,9 @@ class TrainContextBuilder:
         return self
 
     def with_scheduler(self) -> Self:
-        # the build order has any problem ? 
+        if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
+            raise RuntimeError("Must call with_optimizer() before with_scheduler()")
+
         optimizer = self.trainer.train_config.optimizer
         schedule_config = self.trainer.schedule_config
         scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
diff --git a/train.py b/train.py
index 93e3c88..c230f6d 100644
--- a/train.py
+++ b/train.py
@@ -24,6 +24,8 @@ def train(
     max_lr: int,
     n_epoch: int,
     batch_size: int,
+    start_epoch: int,
+    start_batch: int,
     accumulation_steps: int,
     warmup_steps: int,
     checkpoint_interval: int,
@@ -94,6 +96,8 @@ def train(
         checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
+        start_epoch=start_epoch,
+        start_batch=start_batch,
         checkpoint_interval=checkpoint_interval,
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
@@ -135,6 +139,8 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
 
     # other configs
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
+    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
     parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")