feat(train): 添加 max_len 和 step_size 参数支持

This commit is contained in:
ViperEkura 2025-10-29 12:32:17 +08:00
parent d94fc5a87a
commit 46b2a0f86f
1 changed files with 10 additions and 7 deletions

View File

@ -36,7 +36,8 @@ def train(
max_grad_norm: float,
embdeding_lr_rate: int,
random_seed: int,
multi_turn: bool,
max_len: int,
step_size: int,
resume_from_checkpoint: bool
):
assert train_type in ["seq", "sft", "dpo"]
@ -48,15 +49,15 @@ def train(
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
checkpoint = parameter
model = parameter.model
if max_len is None:
max_len = parameter.config.m_len
model = parameter.model
device = torch.device("cuda")
model = model.to(device=device, dtype=torch.bfloat16)
cache_files = get_files(data_root_path)
kwargs = {
"multi_turn": multi_turn,
"dpo_beta": dpo_beta,
"bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id,
@ -73,7 +74,8 @@ def train(
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
max_len=parameter.config.m_len,
max_len=max_len,
step_size=step_size
**kwargs
)
@ -138,10 +140,11 @@ if __name__ == "__main__":
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
args = parser.parse_args()