fix(tools/train): 修复参数传递错误

This commit is contained in:
ViperEkura 2025-12-05 13:53:50 +08:00
parent d52685facd
commit 82e65ccc21
1 changed files with 3 additions and 2 deletions

View File

@ -56,7 +56,8 @@ def train(
checkpoint_interval: int, checkpoint_interval: int,
checkpoint_dir: str, checkpoint_dir: str,
dpo_beta: float, dpo_beta: float,
adamw_betas: tuple, adamw_beta1: float,
adamw_beta2: float,
adamw_weight_decay: float, adamw_weight_decay: float,
max_grad_norm: float, max_grad_norm: float,
embdeding_lr_rate: int, embdeding_lr_rate: int,
@ -112,7 +113,7 @@ def train(
optim = AdamW( optim = AdamW(
param_groups, param_groups,
betas=adamw_betas, betas=(adamw_beta1, adamw_beta2),
weight_decay=adamw_weight_decay weight_decay=adamw_weight_decay
) )