fix(tools/train): 修复参数传递错误
This commit is contained in:
parent
d52685facd
commit
82e65ccc21
|
|
@ -56,7 +56,8 @@ def train(
|
||||||
checkpoint_interval: int,
|
checkpoint_interval: int,
|
||||||
checkpoint_dir: str,
|
checkpoint_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
adamw_betas: tuple,
|
adamw_beta1: float,
|
||||||
|
adamw_beta2: float,
|
||||||
adamw_weight_decay: float,
|
adamw_weight_decay: float,
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
embdeding_lr_rate: int,
|
embdeding_lr_rate: int,
|
||||||
|
|
@ -112,7 +113,7 @@ def train(
|
||||||
|
|
||||||
optim = AdamW(
|
optim = AdamW(
|
||||||
param_groups,
|
param_groups,
|
||||||
betas=adamw_betas,
|
betas=(adamw_beta1, adamw_beta2),
|
||||||
weight_decay=adamw_weight_decay
|
weight_decay=adamw_weight_decay
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue