fix(train.py): 修复参数传递错误
This commit is contained in:
parent
17f1a12f27
commit
dd6a9e4ede
5
train.py
5
train.py
|
|
@ -52,12 +52,13 @@ def train(
|
|||
"dpo_beta": dpo_beta,
|
||||
"bos_token_id": parameter.tokenizer.bos_id,
|
||||
"eos_token_id": parameter.tokenizer.eos_id,
|
||||
"pad_token_id": parameter.tokenizer.pad_id,
|
||||
"user_token_id":parameter.tokenizer.user_id,
|
||||
}
|
||||
|
||||
strategy = StrategyFactory.load(
|
||||
model,
|
||||
train_type
|
||||
train_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
|
@ -66,7 +67,7 @@ def train(
|
|||
load_path=cache_files,
|
||||
max_len=parameter.config.m_len,
|
||||
device=device,
|
||||
dataset_kwargs=kwargs
|
||||
**kwargs
|
||||
)
|
||||
|
||||
param_groups = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue