diff --git a/train.py b/train.py index cc233e1..feb7f8b 100644 --- a/train.py +++ b/train.py @@ -52,12 +52,13 @@ def train( "dpo_beta": dpo_beta, "bos_token_id": parameter.tokenizer.bos_id, "eos_token_id": parameter.tokenizer.eos_id, + "pad_token_id": parameter.tokenizer.pad_id, "user_token_id":parameter.tokenizer.user_id, } strategy = StrategyFactory.load( model, - train_type + train_type, **kwargs ) @@ -66,7 +67,7 @@ def train( load_path=cache_files, max_len=parameter.config.m_len, device=device, - dataset_kwargs=kwargs + **kwargs ) param_groups = [