fix(train.py): 修复参数传递错误

This commit is contained in:
ViperEkura 2025-09-30 17:30:00 +08:00
parent 17f1a12f27
commit dd6a9e4ede
1 changed files with 3 additions and 2 deletions

View File

@ -52,12 +52,13 @@ def train(
"dpo_beta": dpo_beta,
"bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id,
"pad_token_id": parameter.tokenizer.pad_id,
"user_token_id":parameter.tokenizer.user_id,
}
strategy = StrategyFactory.load(
model,
train_type
train_type,
**kwargs
)
@ -66,7 +67,7 @@ def train(
load_path=cache_files,
max_len=parameter.config.m_len,
device=device,
dataset_kwargs=kwargs
**kwargs
)
param_groups = [