refactor(train.py): 重命名策略参数变量名
This commit is contained in:
parent
6d5176a11c
commit
c104a400e7
12
train.py
12
train.py
|
|
@ -47,18 +47,18 @@ def train(
|
||||||
|
|
||||||
cache_files = get_files(data_root_path)
|
cache_files = get_files(data_root_path)
|
||||||
|
|
||||||
strategy_kwargs = {
|
kwargs = {
|
||||||
"multi_turn": multi_turn,
|
"multi_turn": multi_turn,
|
||||||
|
"dpo_beta": dpo_beta,
|
||||||
"bos_token_id": parameter.tokenizer.bos_id,
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
"eos_token_id": parameter.tokenizer.eos_id,
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
"user_token_id":parameter.tokenizer.user_id,
|
"user_token_id":parameter.tokenizer.user_id,
|
||||||
"dpo_beta": dpo_beta
|
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = StrategyFactory.load(
|
strategy = StrategyFactory.load(
|
||||||
model,
|
model,
|
||||||
train_type
|
train_type
|
||||||
**strategy_kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
|
|
@ -66,7 +66,7 @@ def train(
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=parameter.config.m_len,
|
||||||
device=device,
|
device=device,
|
||||||
dataset_kwargs=strategy_kwargs
|
dataset_kwargs=kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
param_groups = [
|
||||||
|
|
@ -119,7 +119,6 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
|
parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
|
||||||
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
|
||||||
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
|
||||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
|
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
|
||||||
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
|
parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
|
||||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
||||||
|
|
@ -128,7 +127,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# other configs
|
# other configs
|
||||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
||||||
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
train(
|
train(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue