refactor(train.py): rename strategy parameter variable

ViperEkura 2025-09-29 17:26:01 +08:00
parent 6d5176a11c
commit c104a400e7
1 changed file with 6 additions and 6 deletions

train.py

@@ -47,18 +47,18 @@ def train(
     cache_files = get_files(data_root_path)
 
-    strategy_kwargs = {
+    kwargs = {
         "multi_turn": multi_turn,
+        "dpo_beta": dpo_beta,
         "bos_token_id": parameter.tokenizer.bos_id,
         "eos_token_id": parameter.tokenizer.eos_id,
         "user_token_id":parameter.tokenizer.user_id,
-        "dpo_beta": dpo_beta
     }
 
     strategy = StrategyFactory.load(
         model,
         train_type
-        **strategy_kwargs
+        **kwargs
     )
 
     dataset = DatasetLoader.load(
@@ -66,7 +66,7 @@ def train(
         load_path=cache_files,
         max_len=parameter.config.m_len,
         device=device,
-        dataset_kwargs=strategy_kwargs
+        dataset_kwargs=kwargs
     )
 
     param_groups = [
@@ -119,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
     parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
-    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
     parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
@@ -128,7 +127,8 @@ if __name__ == "__main__":
 
     # other configs
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
+    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
 
     args = parser.parse_args()
 
     train(
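
For context, the pattern the rename touches: train() builds one dict of shared settings, splats it into StrategyFactory.load via **kwargs, and hands the same dict whole to DatasetLoader.load as dataset_kwargs. The sketch below is a minimal, hypothetical illustration of that fan-out, not the repository's real API; the Fake* classes and all values are stand-ins. It also includes a comma after train_type, which the first hunk appears to lack (without it, Python parses train_type **kwargs as the exponentiation train_type ** kwargs).

# Hypothetical stand-ins for train.py's StrategyFactory and DatasetLoader,
# used only to show the shared-kwargs fan-out; none of this is the real API.

class FakeStrategyFactory:
    @staticmethod
    def load(model, train_type, **kwargs):
        # A real factory would dispatch on train_type; this stub just
        # echoes everything back so the splatted dict is visible.
        return {"model": model, "train_type": train_type, **kwargs}

class FakeDatasetLoader:
    @staticmethod
    def load(load_path, max_len, device, dataset_kwargs):
        # The same dict arrives here un-splatted, under a single keyword.
        return {"load_path": load_path, "max_len": max_len,
                "device": device, **dataset_kwargs}

# Built once, consumed twice (mirrors the renamed `kwargs`; values made up).
kwargs = {
    "multi_turn": False,
    "dpo_beta": 0.1,  # the entry this commit moves next to multi_turn
    "bos_token_id": 1,
    "eos_token_id": 2,
    "user_token_id": 3,
}

strategy = FakeStrategyFactory.load("model", "dpo", **kwargs)  # comma added
dataset = FakeDatasetLoader.load("cache/", 1024, "cpu", dataset_kwargs=kwargs)
print(strategy["dpo_beta"], dataset["dpo_beta"])  # -> 0.1 0.1

Keeping a single dict for both consumers means the strategy and the dataset can never disagree on settings like dpo_beta; the shorter name arguably just reflects that the dict was never strategy-specific, since DatasetLoader consumed it too.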