feat(train): 添加多轮对话训练支持

This commit is contained in:
ViperEkura 2025-09-28 15:38:53 +08:00
parent 1169cfad82
commit 30ac07418c
1 changed files with 17 additions and 2 deletions

View File

@ -33,6 +33,7 @@ def train(
max_grad_norm: float,
embdeding_lr_rate: int,
random_seed: int,
multi_turn: bool,
):
assert train_type in ["seq", "sft", "dpo"]
assert os.path.exists(param_path)
@ -44,11 +45,20 @@ def train(
model = model.to(device=device, dtype=torch.bfloat16)
cache_files = get_files(data_root_path)
dataset_kwargs = {
"multi_turn": multi_turn,
"bos_token_id": parameter.tokenizer.bos_id,
"eos_token_id": parameter.tokenizer.eos_id,
"user_token_id":parameter.tokenizer.encode("<|user|>")[0]
}
dataset = DatasetLoader.load(
train_type=train_type,
load_path=cache_files,
max_len=parameter.config.m_len,
device=device
device=device,
dataset_kwargs=dataset_kwargs
)
param_groups = [
@ -89,6 +99,7 @@ def train(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train the Transformer model.")
# train args
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
@ -106,6 +117,9 @@ if __name__ == "__main__":
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
# other configs
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
args = parser.parse_args()
train(
@ -124,5 +138,6 @@ if __name__ == "__main__":
n_iter_ckpt=args.n_iter_ckpt,
ckpt_dir=args.ckpt_dir,
train_type=args.train_type,
random_seed=args.random_seed
random_seed=args.random_seed,
multi_turn=args.multi_turn
)