From 30ac07418c515ebafcfa11c98156b5b850391287 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 28 Sep 2025 15:38:53 +0800
Subject: [PATCH] =?UTF-8?q?feat(train):=20=E6=B7=BB=E5=8A=A0=E5=A4=9A?=
 =?UTF-8?q?=E8=BD=AE=E5=AF=B9=E8=AF=9D=E8=AE=AD=E7=BB=83=E6=94=AF=E6=8C=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (this section sits after the "---" separator and is not part
of the commit message):
* --multi_turn used type=bool, so every non-empty value, including the
  string "False", was parsed as True. The option now parses explicit
  true/false style values and can also be given as a bare flag.
  Values outside 1/true/yes/y/on are treated as False.
* Line breaks and indentation of this patch were reconstructed from a
  whitespace-collapsed copy. Blank context lines and indent widths are
  inferred from the hunk line counts; verify them against train.py
  (or apply with --ignore-whitespace).
* "user_token_id" takes the first id returned by encode("<|user|>").
  This presumably relies on the marker being a single token in the
  vocabulary; confirm against the tokenizer.
* The post-image blob id on the "index" line predates the edits to the
  --multi_turn and "user_token_id" lines.

 train.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 707ad15..73695bf 100644
--- a/train.py
+++ b/train.py
@@ -33,6 +33,7 @@ def train(
     max_grad_norm: float,
     embdeding_lr_rate: int,
     random_seed: int,
+    multi_turn: bool,
 ):
     assert train_type in ["seq", "sft", "dpo"]
     assert os.path.exists(param_path)
@@ -44,11 +45,20 @@ def train(
     model = model.to(device=device, dtype=torch.bfloat16)
 
     cache_files = get_files(data_root_path)
+
+    dataset_kwargs = {
+        "multi_turn": multi_turn,
+        "bos_token_id": parameter.tokenizer.bos_id,
+        "eos_token_id": parameter.tokenizer.eos_id,
+        "user_token_id": parameter.tokenizer.encode("<|user|>")[0],
+    }
+
     dataset = DatasetLoader.load(
         train_type=train_type,
         load_path=cache_files,
         max_len=parameter.config.m_len,
-        device=device
+        device=device,
+        dataset_kwargs=dataset_kwargs
     )
 
     param_groups = [
@@ -89,6 +99,7 @@ def train(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train the Transformer model.")
+    # train args
     parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
     parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
     parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
@@ -105,6 +116,9 @@ if __name__ == "__main__":
     parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
     parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
     parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
+
+    # other configs
+    parser.add_argument("--multi_turn", nargs="?", const=True, default=False, type=lambda s: s.strip().lower() in ("1", "true", "yes", "y", "on"), help="Enable multi-turn conversation training. Use as a bare flag or pass an explicit true/false value.")
 
     args = parser.parse_args()
 
@@ -124,5 +138,6 @@ if __name__ == "__main__":
         n_iter_ckpt=args.n_iter_ckpt,
         ckpt_dir=args.ckpt_dir,
         train_type=args.train_type,
-        random_seed=args.random_seed
+        random_seed=args.random_seed,
+        multi_turn=args.multi_turn
     )
\ No newline at end of file