refactor(trainer): unify parameter naming for readability

ViperEkura 2025-09-28 22:14:24 +08:00
parent fa43ed2943
commit 1c9063fd3d
4 changed files with 37 additions and 37 deletions

View File

@@ -54,8 +54,8 @@ python train.py \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
- --n_iter_ckpt=10000 \
- --ckpt_dir=checkpoints
+ --checkpoint_interval=10000 \
+ --checkpoint_dir=checkpoints
```
**Parameters Explanation:**
@@ -64,11 +64,11 @@ python train.py \
- `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size
- - `--n_iter_step`: Number of batches per training step
- - `--warning_step`: Number of warmup steps
+ - `--accumulation_steps`: Number of batches per training step
+ - `--warmup_steps`: Number of warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- - `--n_iter_ckpt`: Checkpoint saving interval
- - `--ckpt_dir`: Directory to save checkpoints
+ - `--checkpoint_interval`: Checkpoint saving interval
+ - `--checkpoint_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path
Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation.
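
As documented above, `accumulation_steps` controls how many batches are accumulated before each optimizer step, so the effective batch size is `batch_size * accumulation_steps`. A minimal sketch of that arithmetic (the helper and its `n_samples` argument are illustrative only, not part of this repository):

```python
# Illustrative only: derives effective batch size and step counts from the CLI
# flags documented above. `n_samples` stands in for the dataset size.
def training_schedule_summary(n_samples: int, batch_size: int,
                              accumulation_steps: int, n_epoch: int) -> dict:
    batches_per_epoch = n_samples // batch_size
    # Gradients are accumulated over `accumulation_steps` batches before each optimizer step.
    return {
        "effective_batch_size": batch_size * accumulation_steps,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_steps_per_epoch": batches_per_epoch // accumulation_steps,
        "total_batches": batches_per_epoch * n_epoch,
    }

# Example using the values from the command above plus an assumed accumulation_steps=4.
print(training_schedule_summary(n_samples=100_000, batch_size=8,
                                accumulation_steps=4, n_epoch=5))
```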
@@ -214,8 +214,8 @@ python train.py \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
- --n_iter_ckpt=10000 \
- --ckpt_dir=checkpoints
+ --checkpoint_interval=10000 \
+ --checkpoint_dir=checkpoints
```
**参数说明:**
@@ -224,11 +224,11 @@ python train.py \
- `--param_path`: 模型训练参数路径
- `--n_epoch`: 总训练轮数
- `--batch_size`: 批量大小
- - `--n_iter_step`: 每个训练步骤的 batch 数量
- - `--warning_step`: 预热步数（warmup steps）
+ - `--accumulation_steps`: 每个训练步骤的 batch 数量
+ - `--warmup_steps`: 预热步数（warmup steps）
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
- - `--n_iter_ckpt`: 检查点保存间隔
- - `--ckpt_dir`: 检查点保存目录
+ - `--checkpoint_interval`: 检查点保存间隔
+ - `--checkpoint_dir`: 检查点保存目录
- `--resume_dir`: 从指定路径恢复训练
训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。

View File

@@ -190,7 +190,7 @@ class TrainConfig:
default=None,
metadata={"help": "Optimizer for training."}
)
- ckpt_dir: str = field(
+ checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
@@ -202,11 +202,11 @@ class TrainConfig:
default=4,
metadata={"help": "Batch size for training."}
)
- n_iter_ckpt: int = field(
+ checkpoint_interval: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
)
- n_iter_step: int = field(
+ accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
@@ -256,7 +256,7 @@ class ScheduleConfig(ABC):
@dataclass
class CosineScheduleConfig(ScheduleConfig):
- total_steps: int = field( # 更准确的命名
+ total_steps: int = field(
default=None,
metadata={"help": "Total training steps for cosine schedule."}
)
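
`--max_lr` is documented above as using warmup plus cosine decay, and `CosineScheduleConfig` carries the `total_steps` for that schedule. The scheduler implementation itself is not part of this diff, so the following is only a sketch, under the assumption of linear warmup followed by cosine decay to zero:

```python
import math

# Sketch only: the actual schedule behind CosineScheduleConfig is not shown in this diff.
def lr_at(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    if step < warmup_steps:
        # Linear warmup from 0 up to max_lr.
        return max_lr * step / max(1, warmup_steps)
    # Cosine decay from max_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
```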

View File

@@ -32,7 +32,7 @@ class Trainer:
train_config: TrainConfig
):
current_iter = len(loss_list)
- save_path = os.path.join(train_config.ckpt_dir, f"iter_{current_iter}")
+ save_path = os.path.join(train_config.checkpoint_dir, f"iter_{current_iter}")
self.checkpoint.loss_list = loss_list
self.checkpoint.optim_state = train_config.optimizer.state_dict()
self.checkpoint.save(save_path)
@@ -93,7 +93,7 @@ class Trainer:
#backward
loss.backward()
#step
- if current_iter % train_config.n_iter_step == 0:
+ if current_iter % train_config.accumulation_steps == 0:
clip_grad_norm_(
self.checkpoint.model.parameters(),
train_config.max_grad_norm
@@ -108,7 +108,7 @@ class Trainer:
"lr": f"{train_config.optimizer.param_groups[0]['lr']:.2e}"
})
#save checkpoint
- if current_iter - last_ckpt_iter >= train_config.n_iter_ckpt:
+ if current_iter - last_ckpt_iter >= train_config.checkpoint_interval:
self.save_checkpoint(loss_list, train_config)
last_ckpt_iter = current_iter
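
The loop above calls `optimizer.step()` (with gradient clipping) only every `accumulation_steps` iterations and saves a checkpoint once `checkpoint_interval` iterations have elapsed since the last save. A simplified, self-contained sketch of that pattern; `model`, `optimizer`, `loader`, and `save_fn` are placeholders, and the gradient zeroing is an assumption since the hunk does not show it:

```python
from torch.nn.utils import clip_grad_norm_

def train_epoch(model, optimizer, loader, accumulation_steps=1,
                checkpoint_interval=5000, max_grad_norm=1.0, save_fn=None):
    # Placeholder loop mirroring the accumulation/checkpoint logic in the diff above.
    last_ckpt_iter = 0
    for current_iter, batch in enumerate(loader, start=1):
        loss = model(**batch)   # assumes the model returns a scalar loss for a dict batch
        loss.backward()         # the diff backpropagates the unscaled loss each iteration
        if current_iter % accumulation_steps == 0:
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()  # assumed; not shown in the hunk above
        if save_fn is not None and current_iter - last_ckpt_iter >= checkpoint_interval:
            save_fn(current_iter)  # e.g. write model/optimizer state under checkpoint_dir
            last_ckpt_iter = current_iter
```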

View File

@@ -21,13 +21,13 @@ def train(
train_type: str,
param_path: str,
data_root_path: str,
+ max_lr: int,
n_epoch: int,
batch_size: int,
- n_iter_step: int,
- warning_step: int,
- max_lr: int,
- n_iter_ckpt: int,
- ckpt_dir: str,
+ accumulation_steps: int,
+ warmup_steps: int,
+ checkpoint_interval: int,
+ checkpoint_dir: str,
dpo_beta: float,
adamw_betas: tuple,
adamw_weight_decay: float,
@@ -84,18 +84,18 @@ def train(
strategy=strategy,
dataset=dataset,
optimizer=optim,
- ckpt_dir=ckpt_dir,
+ checkpoint_dir=checkpoint_dir,
n_epoch=n_epoch,
batch_size=batch_size,
- n_iter_ckpt=n_iter_ckpt,
- n_iter_step=n_iter_step,
+ checkpoint_interval=checkpoint_interval,
+ accumulation_steps=accumulation_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
)
schedule_config = CosineScheduleConfig(
- warning_step=warning_step,
- total_iters=len(dataset) * n_epoch // batch_size,
+ warmup_steps=warmup_steps,
+ total_steps=len(dataset) * n_epoch // batch_size,
)
trainer = Trainer(
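
Note that `total_steps` passed to `CosineScheduleConfig` above counts batches (`len(dataset) * n_epoch // batch_size`). Whether it should additionally be divided by `accumulation_steps` depends on whether the scheduler is stepped once per batch or once per optimizer step, which this diff does not show; both conventions in a small sketch:

```python
# Sketch of the two common conventions; the hunk above uses the per-batch count.
def cosine_total_steps(n_samples: int, n_epoch: int, batch_size: int,
                       accumulation_steps: int = 1, per_optimizer_step: bool = False) -> int:
    per_batch = n_samples * n_epoch // batch_size
    return per_batch // accumulation_steps if per_optimizer_step else per_batch
```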
@@ -114,11 +114,11 @@ if __name__ == "__main__":
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
parser.add_argument("--n_epoch", type=int, default=1, help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--n_iter_step", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warning_step", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of iterations between each optimizer step.")
parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of iters between warnings.")
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
parser.add_argument("--n_iter_ckpt", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--ckpt_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
@@ -136,16 +136,16 @@ if __name__ == "__main__":
data_root_path=args.data_root_path,
n_epoch=args.n_epoch,
batch_size=args.batch_size,
- n_iter_step=args.n_iter_step,
- warning_step=args.warning_step,
+ accumulation_steps=args.accumulation_steps,
+ warmup_steps=args.warmup_steps,
max_lr=args.max_lr,
dpo_beta=args.dpo_beta,
adamw_betas=args.adamw_betas,
adamw_weight_decay=args.adamw_weight_decay,
max_grad_norm=args.max_grad_norm,
embdeding_lr_rate=args.embdeding_lr_rate,
- n_iter_ckpt=args.n_iter_ckpt,
- ckpt_dir=args.ckpt_dir,
+ checkpoint_interval=args.checkpoint_interval,
+ checkpoint_dir=args.checkpoint_dir,
train_type=args.train_type,
random_seed=args.random_seed,
multi_turn=args.multi_turn