fix(tools/train): 修复参数传递错误
This commit is contained in:
parent
d52685facd
commit
82e65ccc21
|
|
@ -56,7 +56,8 @@ def train(
|
|||
checkpoint_interval: int,
|
||||
checkpoint_dir: str,
|
||||
dpo_beta: float,
|
||||
adamw_betas: tuple,
|
||||
adamw_beta1: float,
|
||||
adamw_beta2: float,
|
||||
adamw_weight_decay: float,
|
||||
max_grad_norm: float,
|
||||
embdeding_lr_rate: int,
|
||||
|
|
@ -112,7 +113,7 @@ def train(
|
|||
|
||||
optim = AdamW(
|
||||
param_groups,
|
||||
betas=adamw_betas,
|
||||
betas=(adamw_beta1, adamw_beta2),
|
||||
weight_decay=adamw_weight_decay
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue