fix(train.py): 修复参数传递错误
This commit is contained in:
parent
17f1a12f27
commit
dd6a9e4ede
5
train.py
5
train.py
|
|
@ -52,12 +52,13 @@ def train(
|
||||||
"dpo_beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
"bos_token_id": parameter.tokenizer.bos_id,
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
"eos_token_id": parameter.tokenizer.eos_id,
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
|
"pad_token_id": parameter.tokenizer.pad_id,
|
||||||
"user_token_id":parameter.tokenizer.user_id,
|
"user_token_id":parameter.tokenizer.user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = StrategyFactory.load(
|
strategy = StrategyFactory.load(
|
||||||
model,
|
model,
|
||||||
train_type
|
train_type,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -66,7 +67,7 @@ def train(
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=parameter.config.m_len,
|
||||||
device=device,
|
device=device,
|
||||||
dataset_kwargs=kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
param_groups = [
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue