refactor(trainer): 统一参数命名以提升可读性

This commit is contained in:
ViperEkura 2025-09-28 22:14:24 +08:00
parent fa43ed2943
commit 1c9063fd3d
4 changed files with 37 additions and 37 deletions

View File

@ -54,8 +54,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--n_iter_ckpt=10000 \ --checkpoint_interval=10000 \
--ckpt_dir=checkpoints --checkpoint_dir=checkpoints
``` ```
**Parameters Explanation:** **Parameters Explanation:**
@ -64,11 +64,11 @@ python train.py \
- `--param_path`: Path to the model training parameters - `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs - `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size - `--batch_size`: Batch size
- `--n_iter_step`: Number of batches per training step - `--accumulation_steps`: Number of batches per training step
- `--warning_step`: Number of warmup steps - `--warmup_steps`: Number of warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay) - `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--n_iter_ckpt`: Checkpoint saving interval - `--checkpoint_interval`: Checkpoint saving interval
- `--ckpt_dir`: Directory to save checkpoints - `--checkpoint_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path - `--resume_dir`: Resume training from the specified path
Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation. Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
@ -214,8 +214,8 @@ python train.py \
--n_epoch=5 \ --n_epoch=5 \
--batch_size=8 \ --batch_size=8 \
--max_lr=2e-4 \ --max_lr=2e-4 \
--n_iter_ckpt=10000 \ --checkpoint_interval=10000 \
--ckpt_dir=checkpoints --checkpoint_dir=checkpoints
``` ```
**参数说明:** **参数说明:**
@ -224,11 +224,11 @@ python train.py \
- `--param_path`: 模型训练参数路径 - `--param_path`: 模型训练参数路径
- `--n_epoch`: 总训练轮数 - `--n_epoch`: 总训练轮数
- `--batch_size`: 批量大小 - `--batch_size`: 批量大小
- `--n_iter_step`: 每个训练步骤的 batch 数量 - `--accumulation_steps`: 每个训练步骤的 batch 数量
- `--warning_step`: 预热步数warmup steps - `--warmup_steps`: 预热步数warmup steps
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减) - `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
- `--n_iter_ckpt`: 检查点保存间隔 - `--checkpoint_interval`: 检查点保存间隔
- `--ckpt_dir`: 检查点保存目录 - `--checkpoint_dir`: 检查点保存目录
- `--resume_dir`: 从指定路径恢复训练 - `--resume_dir`: 从指定路径恢复训练
训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。 训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。

View File

@ -190,7 +190,7 @@ class TrainConfig:
default=None, default=None,
metadata={"help": "Optimizer for training."} metadata={"help": "Optimizer for training."}
) )
ckpt_dir: str = field( checkpoint_dir: str = field(
default="./checkpoint", default="./checkpoint",
metadata={"help": "Checkpoint directory."} metadata={"help": "Checkpoint directory."}
) )
@ -202,11 +202,11 @@ class TrainConfig:
default=4, default=4,
metadata={"help": "Batch size for training."} metadata={"help": "Batch size for training."}
) )
n_iter_ckpt: int = field( checkpoint_interval: int = field(
default=5000, default=5000,
metadata={"help": "Number of iterations between checkpoints."} metadata={"help": "Number of iterations between checkpoints."}
) )
n_iter_step: int = field( accumulation_steps: int = field(
default=1, default=1,
metadata={"help": "Number of iterations between steps."} metadata={"help": "Number of iterations between steps."}
) )
@ -256,7 +256,7 @@ class ScheduleConfig(ABC):
@dataclass @dataclass
class CosineScheduleConfig(ScheduleConfig): class CosineScheduleConfig(ScheduleConfig):
total_steps: int = field( # 更准确的命名 total_steps: int = field(
default=None, default=None,
metadata={"help": "Total training steps for cosine schedule."} metadata={"help": "Total training steps for cosine schedule."}
) )

View File

@ -32,7 +32,7 @@ class Trainer:
train_config: TrainConfig train_config: TrainConfig
): ):
current_iter = len(loss_list) current_iter = len(loss_list)
save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}") save_path = os.path.join(train_config.checkpoint_dir, f"iter_{current_iter}")
self.checkpoint.loss_list = loss_list self.checkpoint.loss_list = loss_list
self.checkpoint.optim_state = train_config.optimizer.state_dict() self.checkpoint.optim_state = train_config.optimizer.state_dict()
self.checkpoint.save(save_path) self.checkpoint.save(save_path)
@ -93,7 +93,7 @@ class Trainer:
#backward #backward
loss.backward() loss.backward()
#step #step
if current_iter % train_config.n_iter_step == 0: if current_iter % train_config.accumulation_steps == 0:
clip_grad_norm_( clip_grad_norm_(
self.checkpoint.model.parameters(), self.checkpoint.model.parameters(),
train_config.max_grad_norm train_config.max_grad_norm
@ -108,7 +108,7 @@ class Trainer:
"lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}" "lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
}) })
#save checkpotint #save checkpotint
if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt: if current_iter - last_ckpt_iter >= train_config.checkpoint_interval:
self.save_checkpoint(loss_list, train_config) self.save_checkpoint(loss_list, train_config)
last_ckpt_iter = current_iter last_ckpt_iter = current_iter

View File

@ -21,13 +21,13 @@ def train(
train_type: str, train_type: str,
param_path: str, param_path: str,
data_root_path: str, data_root_path: str,
max_lr: int,
n_epoch: int, n_epoch: int,
batch_size: int, batch_size: int,
n_iter_step: int, accumulation_steps: int,
warning_step: int, warmup_steps: int,
max_lr: int, checkpoint_interval: int,
n_iter_ckpt: int, checkpoint_dir: str,
ckpt_dir: str,
dpo_beta: float, dpo_beta: float,
adamw_betas: tuple, adamw_betas: tuple,
adamw_weight_decay: float, adamw_weight_decay: float,
@ -84,18 +84,18 @@ def train(
strategy=strategy, strategy=strategy,
dataset=dataset, dataset=dataset,
optimizer=optim, optimizer=optim,
ckpt_dir=ckpt_dir, checkpoint_dir=checkpoint_dir,
n_epoch=n_epoch, n_epoch=n_epoch,
batch_size=batch_size, batch_size=batch_size,
n_iter_ckpt=n_iter_ckpt, checkpoint_interval=checkpoint_interval,
n_iter_step=n_iter_step, accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
random_seed=random_seed, random_seed=random_seed,
) )
schedule_config = CosineScheduleConfig( schedule_config = CosineScheduleConfig(
warning_step=warning_step, warmup_steps=warmup_steps,
total_iters=len(dataset) * n_epoch // batch_size, total_steps=len(dataset) * n_epoch // batch_size,
) )
trainer = Trainer( trainer = Trainer(
@ -114,11 +114,11 @@ if __name__ == "__main__":
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.") parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.") parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.") parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.") parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.") parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.") parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.") parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.") parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.") parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
@ -136,16 +136,16 @@ if __name__ == "__main__":
data_root_path=args.data_root_path, data_root_path=args.data_root_path,
n_epoch=args.n_epoch, n_epoch=args.n_epoch,
batch_size=args.batch_size, batch_size=args.batch_size,
n_iter_step=args.n_iter_step, accumulation_steps=args.accumulation_steps,
warning_step=args.warning_step, warmup_steps=args.warmup_steps,
max_lr=args.max_lr, max_lr=args.max_lr,
dpo_beta=args.dpo_beta, dpo_beta=args.dpo_beta,
adamw_betas=args.adamw_betas, adamw_betas=args.adamw_betas,
adamw_weight_decay=args.adamw_weight_decay, adamw_weight_decay=args.adamw_weight_decay,
max_grad_norm=args.max_grad_norm, max_grad_norm=args.max_grad_norm,
embdeding_lr_rate=args.embdeding_lr_rate, embdeding_lr_rate=args.embdeding_lr_rate,
n_iter_ckpt=args.n_iter_ckpt, checkpoint_interval=args.checkpoint_interval,
ckpt_dir=args.ckpt_dir, checkpoint_dir=args.checkpoint_dir,
train_type=args.train_type, train_type=args.train_type,
random_seed=args.random_seed, random_seed=args.random_seed,
multi_turn=args.multi_turn multi_turn=args.multi_turn