diff --git a/train.py b/train.py
index bf3dc16..cc233e1 100644
--- a/train.py
+++ b/train.py
@@ -47,18 +47,18 @@ def train(
 
     cache_files = get_files(data_root_path)
 
-    strategy_kwargs = {
+    kwargs = {
         "multi_turn": multi_turn,
+        "dpo_beta": dpo_beta,
         "bos_token_id": parameter.tokenizer.bos_id,
         "eos_token_id": parameter.tokenizer.eos_id,
         "user_token_id": parameter.tokenizer.user_id,
-        "dpo_beta": dpo_beta
     }
 
     strategy = StrategyFactory.load(
         model,
         train_type,
-        **strategy_kwargs
+        **kwargs
     )
 
     dataset = DatasetLoader.load(
@@ -66,7 +66,7 @@ def train(
         load_path=cache_files,
         max_len=parameter.config.m_len,
         device=device,
-        dataset_kwargs=strategy_kwargs
+        dataset_kwargs=kwargs
     )
 
     param_groups = [
@@ -119,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--max_lr", type=float, default=3e-4, help="Max learning rate for training.")
     parser.add_argument("--checkpoint_interval", type=int, default=5000, help="Number of iters between checkpoints.")
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoint", help="Directory to save checkpoints.")
-    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping.")
     parser.add_argument("--adamw_betas", type=tuple, default=(0.9, 0.95), help="Beta values for AdamW optimizer.")
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
@@ -128,7 +127,8 @@ if __name__ == "__main__":
 
     # other configs
     parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn conversation training.")
-
+    parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
+
     args = parser.parse_args()
 
     train(
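
For context, a minimal sketch of how the touched section of `train()` reads once this patch is applied. The `StrategyFactory.load` and `DatasetLoader.load` signatures are taken from the diff context; `parameter`, `model`, `device`, `train_type`, and `cache_files` are assumed to be set up earlier in the function.

# Sketch (assumed context): post-patch shape of the kwargs consolidation
# in train(). Anything not visible in the diff hunks is presumed to be
# defined earlier in the function.
kwargs = {
    "multi_turn": multi_turn,
    "dpo_beta": dpo_beta,
    "bos_token_id": parameter.tokenizer.bos_id,
    "eos_token_id": parameter.tokenizer.eos_id,
    "user_token_id": parameter.tokenizer.user_id,
}

# The same dict now configures both the strategy and the dataset,
# so the two can no longer drift apart.
strategy = StrategyFactory.load(model, train_type, **kwargs)
dataset = DatasetLoader.load(
    load_path=cache_files,
    max_len=parameter.config.m_len,
    device=device,
    dataset_kwargs=kwargs,
)

Note that `--dpo_beta` still parses exactly as before; only its position in the argument list changed, moving it under the "other configs" group alongside `--multi_turn`.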