feat(trainer): add support for configuring the training start epoch and start batch
This commit is contained in:
parent
613edd7a14
commit
98efca7b9d
|
|
@ -34,6 +34,14 @@ class TrainConfig:
|
|||
default=4,
|
||||
metadata={"help": "Batch size for training."}
|
||||
)
|
||||
start_epoch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Start epoch for training."}
|
||||
)
|
||||
start_batch: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Start batch iteration for training."}
|
||||
)
|
||||
checkpoint_interval: int = field(
|
||||
default=5000,
|
||||
metadata={"help": "Number of iterations between checkpoints."}
|
||||
|
|
|
|||
|
|
@ -161,7 +161,8 @@ class StepMonitorCallback(TrainCallback):
|
|||
Args:
|
||||
log_dir: Directory to save log files. If None, logs won't be saved to file.
|
||||
log_interval: Log every N steps
|
||||
metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std', grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
|
||||
metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std',
|
||||
'grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
|
||||
custom_handlers: List of custom log handler functions
|
||||
json_log: Whether to save logs in JSON format
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -28,9 +28,10 @@ class TrainContext:
|
|||
class TrainContextBuilder:
|
||||
def __init__(self, trainer: 'Trainer'):
|
||||
self.trainer = trainer
|
||||
self._context = TrainContext()
|
||||
self._context: TrainContext = None
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._context = TrainContext()
|
||||
if checkpoint is None:
|
||||
checkpoint = Checkpoint(
|
||||
model=self.trainer.parameter.model,
|
||||
|
|
@ -38,13 +39,17 @@ class TrainContextBuilder:
|
|||
config=self.trainer.parameter.config,
|
||||
)
|
||||
else:
|
||||
self._context.epoch = checkpoint.epoch
|
||||
self._context.batch_iter = checkpoint.batch_iter
|
||||
# resume from the assigned checkpoint or assigned iteration
|
||||
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
|
||||
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
|
||||
|
||||
self._context.checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
def with_optimizer(self) -> Self:
|
||||
if self._context is None:
|
||||
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
||||
|
||||
optimizer = self.trainer.train_config.optimizer
|
||||
|
||||
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
||||
|
|
@ -58,7 +63,9 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def with_scheduler(self) -> Self:
|
||||
# NOTE(review): confirm the required builder call order (with_optimizer() must run before with_scheduler()) is enforced correctly here
|
||||
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
|
||||
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
|
||||
|
||||
optimizer = self.trainer.train_config.optimizer
|
||||
schedule_config = self.trainer.schedule_config
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
||||
|
|
|
|||
6
train.py
6
train.py
|
|
@ -24,6 +24,8 @@ def train(
|
|||
max_lr: int,
|
||||
n_epoch: int,
|
||||
batch_size: int,
|
||||
start_epoch: int,
|
||||
start_batch: int,
|
||||
accumulation_steps: int,
|
||||
warmup_steps: int,
|
||||
checkpoint_interval: int,
|
||||
|
|
@ -94,6 +96,8 @@ def train(
|
|||
checkpoint_dir=checkpoint_dir,
|
||||
n_epoch=n_epoch,
|
||||
batch_size=batch_size,
|
||||
start_epoch=start_epoch,
|
||||
start_batch=start_batch,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
accumulation_steps=accumulation_steps,
|
||||
max_grad_norm=max_grad_norm,
|
||||
|
|
@ -135,6 +139,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||
|
||||
# other configs
|
||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn conversation training.")
|
||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue