feat(trainer): add start epoch and start batch configuration support for training

Author: ViperEkura
Date: 2025-10-19 21:47:10 +08:00
Commit: 98efca7b9d (parent 613edd7a14)
4 changed files with 27 additions and 5 deletions

File 1 of 4

@@ -34,6 +34,14 @@ class TrainConfig:
         default=4,
         metadata={"help": "Batch size for training."}
     )
+    start_epoch: int = field(
+        default=0,
+        metadata={"help": "Start epoch for training."}
+    )
+    start_batch: int = field(
+        default=0,
+        metadata={"help": "Start batch iteration for training."}
+    )
     checkpoint_interval: int = field(
         default=5000,
         metadata={"help": "Number of iterations between checkpoints."}

File 2 of 4

@@ -161,7 +161,8 @@ class StepMonitorCallback(TrainCallback):
         Args:
             log_dir: Directory to save log files. If None, logs won't be saved to file.
             log_interval: Log every N steps
-            metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std', grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
+            metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std',
+                'grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
             custom_handlers: List of custom log handler functions
             json_log: Whether to save logs in JSON format
         """

File 3 of 4

@@ -28,9 +28,10 @@ class TrainContext:
 class TrainContextBuilder:
     def __init__(self, trainer: 'Trainer'):
         self.trainer = trainer
-        self._context = TrainContext()
+        self._context: TrainContext = None

     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
+        self._context = TrainContext()
         if checkpoint is None:
             checkpoint = Checkpoint(
                 model=self.trainer.parameter.model,
@@ -38,13 +39,17 @@ class TrainContextBuilder:
                 config=self.trainer.parameter.config,
             )
         else:
-            self._context.epoch = checkpoint.epoch
-            self._context.batch_iter = checkpoint.batch_iter
+            # resume from the checkpoint, or from an explicitly assigned start epoch/iteration
+            self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
+            self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
         self._context.checkpoint = checkpoint
         return self

     def with_optimizer(self) -> Self:
+        if self._context is None:
+            raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
         optimizer = self.trainer.train_config.optimizer
         if self._context.checkpoint and self._context.checkpoint.optimizer_state:
@@ -58,7 +63,9 @@
         return self

     def with_scheduler(self) -> Self:
-        # the build order has any problem ?
+        if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
+            raise RuntimeError("Must call with_optimizer() before with_scheduler()")
         optimizer = self.trainer.train_config.optimizer
         schedule_config = self.trainer.schedule_config
         scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
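These guards replace the old "# the build order has any problem ?" comment by failing fast when the builder is used out of order; the max(...) calls also mean a configured start_epoch/start_batch can move the resume point forward past what the checkpoint recorded, but never rewind behind it. A stripped-down, runnable sketch of the same ordering pattern (simplified stand-in types, not the repo's actual classes):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Context:
    epoch: int = 0
    batch_iter: int = 0
    optimizer: Optional[object] = None

class Builder:
    def __init__(self):
        self._context: Optional[Context] = None

    def with_checkpoint(self, epoch: int = 0, batch_iter: int = 0) -> "Builder":
        # First stage: creates the context every later stage depends on.
        self._context = Context(epoch=epoch, batch_iter=batch_iter)
        return self

    def with_optimizer(self, optimizer: object) -> "Builder":
        if self._context is None:
            raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
        self._context.optimizer = optimizer
        return self

    def with_scheduler(self) -> "Builder":
        if self._context is None or self._context.optimizer is None:
            raise RuntimeError("Must call with_optimizer() before with_scheduler()")
        return self

# Correct order chains cleanly; a wrong order raises immediately instead of
# failing later with an AttributeError deep inside training.
Builder().with_checkpoint(epoch=1).with_optimizer(object()).with_scheduler()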

File 4 of 4

@@ -24,6 +24,8 @@ def train(
     max_lr: int,
     n_epoch: int,
     batch_size: int,
+    start_epoch: int,
+    start_batch: int,
     accumulation_steps: int,
     warmup_steps: int,
     checkpoint_interval: int,
@@ -94,6 +96,8 @@ def train(
         checkpoint_dir=checkpoint_dir,
         n_epoch=n_epoch,
         batch_size=batch_size,
+        start_epoch=start_epoch,
+        start_batch=start_batch,
         checkpoint_interval=checkpoint_interval,
         accumulation_steps=accumulation_steps,
         max_grad_norm=max_grad_norm,
@@ -135,6 +139,8 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")

     # other configs
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
+    parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
     parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="Train from checkpoint or not.")
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn conversation training.")
     parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")