From 46b2a0f86f6b1dee349efd3724aa952514e05601 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 29 Oct 2025 12:32:17 +0800 Subject: [PATCH] =?UTF-8?q?feat(train):=20=E6=B7=BB=E5=8A=A0=20max=5Flen?= =?UTF-8?q?=20=E5=92=8C=20step=5Fsize=20=E5=8F=82=E6=95=B0=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index cd6e2ff..83c1bb8 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,8 @@ def train( max_grad_norm: float, embdeding_lr_rate: int, random_seed: int, - multi_turn: bool, + max_len: int, + step_size: int, resume_from_checkpoint: bool ): assert train_type in ["seq", "sft", "dpo"] @@ -47,16 +48,16 @@ def train( if isinstance(parameter, Checkpoint) and resume_from_checkpoint: checkpoint = parameter - + + if max_len is None: + max_len = parameter.config.m_len + model = parameter.model - device = torch.device("cuda") model = model.to(device=device, dtype=torch.bfloat16) - cache_files = get_files(data_root_path) kwargs = { - "multi_turn": multi_turn, "dpo_beta": dpo_beta, "bos_token_id": parameter.tokenizer.bos_id, "eos_token_id": parameter.tokenizer.eos_id, @@ -73,7 +74,8 @@ def train( dataset = DatasetLoader.load( train_type=train_type, load_path=cache_files, - max_len=parameter.config.m_len, + max_len=max_len, + step_size=step_size, **kwargs ) @@ -138,10 +140,11 @@ if __name__ == "__main__": parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.") # other configs + parser.add_argument("--max_len", type=int, default=None, help="The max length of the input sequence.") + parser.add_argument("--step_size", type=int, default=None, help="The step size of the input sequence.") parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.") parser.add_argument("--start_batch", type=int, 
default=0, help="Start batch for training.") parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.") - parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") args = parser.parse_args()