feat(train): 添加多轮对话训练支持
This commit is contained in:
parent
1169cfad82
commit
30ac07418c
19
train.py
19
train.py
|
|
@ -33,6 +33,7 @@ def train(
|
|||
max_grad_norm: float,
|
||||
embdeding_lr_rate: int,
|
||||
random_seed: int,
|
||||
multi_turn: bool,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
|
@ -44,11 +45,20 @@ def train(
|
|||
model = model.to(device=device, dtype=torch.bfloat16)
|
||||
|
||||
cache_files = get_files(data_root_path)
|
||||
|
||||
dataset_kwargs = {
|
||||
"multi_turn": multi_turn,
|
||||
"bos_token_id": parameter.tokenizer.bos_id,
|
||||
"eos_token_id": parameter.tokenizer.eos_id,
|
||||
"user_token_id":parameter.tokenizer.encode("<|user|>")[0]
|
||||
}
|
||||
|
||||
dataset = DatasetLoader.load(
|
||||
train_type=train_type,
|
||||
load_path=cache_files,
|
||||
max_len=parameter.config.m_len,
|
||||
device=device
|
||||
device=device,
|
||||
dataset_kwargs=dataset_kwargs
|
||||
)
|
||||
|
||||
param_groups = [
|
||||
|
|
@ -89,6 +99,7 @@ def train(
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
# train args
|
||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
||||
|
|
@ -106,6 +117,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
|
||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||
|
||||
# other configs
|
||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
train(
|
||||
|
|
@ -124,5 +138,6 @@ if __name__ == "__main__":
|
|||
n_iter_ckpt=args.n_iter_ckpt,
|
||||
ckpt_dir=args.ckpt_dir,
|
||||
train_type=args.train_type,
|
||||
random_seed=args.random_seed
|
||||
random_seed=args.random_seed,
|
||||
multi_turn=args.multi_turn
|
||||
)
|
||||
Loading…
Reference in New Issue