feat(trainer): 添加训练起始轮次和批次配置支持
This commit is contained in:
parent
613edd7a14
commit
98efca7b9d
|
|
@ -34,6 +34,14 @@ class TrainConfig:
|
||||||
default=4,
|
default=4,
|
||||||
metadata={"help": "Batch size for training."}
|
metadata={"help": "Batch size for training."}
|
||||||
)
|
)
|
||||||
|
start_epoch: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Start epoch for training."}
|
||||||
|
)
|
||||||
|
start_batch: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={"help": "Start batch iteration for training."}
|
||||||
|
)
|
||||||
checkpoint_interval: int = field(
|
checkpoint_interval: int = field(
|
||||||
default=5000,
|
default=5000,
|
||||||
metadata={"help": "Number of iterations between checkpoints."}
|
metadata={"help": "Number of iterations between checkpoints."}
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,8 @@ class StepMonitorCallback(TrainCallback):
|
||||||
Args:
|
Args:
|
||||||
log_dir: Directory to save log files. If None, logs won't be saved to file.
|
log_dir: Directory to save log files. If None, logs won't be saved to file.
|
||||||
log_interval: Log every N steps
|
log_interval: Log every N steps
|
||||||
metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std', 'grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
|
metrics: List of metrics to log. Supported: ['loss', 'lr', 'grad_norm', 'grad_std',
|
||||||
|
'grad_max', 'grad_min', 'grad_mean', 'grad_nan_num']
|
||||||
custom_handlers: List of custom log handler functions
|
custom_handlers: List of custom log handler functions
|
||||||
json_log: Whether to save logs in JSON format
|
json_log: Whether to save logs in JSON format
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,10 @@ class TrainContext:
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, trainer: 'Trainer'):
|
def __init__(self, trainer: 'Trainer'):
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
self._context = TrainContext()
|
self._context: TrainContext = None
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
|
self._context = TrainContext()
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
model=self.trainer.parameter.model,
|
model=self.trainer.parameter.model,
|
||||||
|
|
@ -38,13 +39,17 @@ class TrainContextBuilder:
|
||||||
config=self.trainer.parameter.config,
|
config=self.trainer.parameter.config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._context.epoch = checkpoint.epoch
|
# resume from the assigned checkpoint or assigned iteration
|
||||||
self._context.batch_iter = checkpoint.batch_iter
|
self._context.epoch = max(checkpoint.epoch, self.trainer.train_config.start_epoch)
|
||||||
|
self._context.batch_iter = max(checkpoint.batch_iter, self.trainer.train_config.start_batch)
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_optimizer(self) -> Self:
|
def with_optimizer(self) -> Self:
|
||||||
|
if self._context is None:
|
||||||
|
raise RuntimeError("Must call with_checkpoint() before with_optimizer()")
|
||||||
|
|
||||||
optimizer = self.trainer.train_config.optimizer
|
optimizer = self.trainer.train_config.optimizer
|
||||||
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
if self._context.checkpoint and self._context.checkpoint.optimizer_state:
|
||||||
|
|
@ -58,7 +63,9 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_scheduler(self) -> Self:
|
def with_scheduler(self) -> Self:
|
||||||
# Is there any problem with the build order?
|
if not hasattr(self._context, 'optimizer') or self._context.optimizer is None:
|
||||||
|
raise RuntimeError("Must call with_optimizer() before with_scheduler()")
|
||||||
|
|
||||||
optimizer = self.trainer.train_config.optimizer
|
optimizer = self.trainer.train_config.optimizer
|
||||||
schedule_config = self.trainer.schedule_config
|
schedule_config = self.trainer.schedule_config
|
||||||
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
scheduler = SchedulerFactory.load_scheduler(optimizer, schedule_config)
|
||||||
|
|
|
||||||
6
train.py
6
train.py
|
|
@ -24,6 +24,8 @@ def train(
|
||||||
max_lr: int,
|
max_lr: int,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
start_epoch: int,
|
||||||
|
start_batch: int,
|
||||||
accumulation_steps: int,
|
accumulation_steps: int,
|
||||||
warmup_steps: int,
|
warmup_steps: int,
|
||||||
checkpoint_interval: int,
|
checkpoint_interval: int,
|
||||||
|
|
@ -94,6 +96,8 @@ def train(
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
start_epoch=start_epoch,
|
||||||
|
start_batch=start_batch,
|
||||||
checkpoint_interval=checkpoint_interval,
|
checkpoint_interval=checkpoint_interval,
|
||||||
accumulation_steps=accumulation_steps,
|
accumulation_steps=accumulation_steps,
|
||||||
max_grad_norm=max_grad_norm,
|
max_grad_norm=max_grad_norm,
|
||||||
|
|
@ -135,6 +139,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
|
|
||||||
# other configs
|
# other configs
|
||||||
|
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||||
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue