diff --git a/tools/train.py b/tools/train.py index d4bb15d..8704092 100644 --- a/tools/train.py +++ b/tools/train.py @@ -56,7 +56,8 @@ def train( checkpoint_interval: int, checkpoint_dir: str, dpo_beta: float, - adamw_betas: tuple, + adamw_beta1: float, + adamw_beta2: float, adamw_weight_decay: float, max_grad_norm: float, embdeding_lr_rate: int, @@ -112,7 +113,7 @@ def train( optim = AdamW( param_groups, - betas=adamw_betas, + betas=(adamw_beta1, adamw_beta2), weight_decay=adamw_weight_decay )