feat(train): 添加多轮对话训练支持
This commit is contained in:
parent
1169cfad82
commit
30ac07418c
19
train.py
19
train.py
|
|
@@ -33,6 +33,7 @@ def train(
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
embdeding_lr_rate: int,
|
embdeding_lr_rate: int,
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
|
multi_turn: bool,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
@@ -44,11 +45,20 @@ def train(
|
||||||
model = model.to(device=device, dtype=torch.bfloat16)
|
model = model.to(device=device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
cache_files = get_files(data_root_path)
|
cache_files = get_files(data_root_path)
|
||||||
|
|
||||||
|
dataset_kwargs = {
|
||||||
|
"multi_turn": multi_turn,
|
||||||
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
|
"user_token_id":parameter.tokenizer.encode("<|user|>")[0]
|
||||||
|
}
|
||||||
|
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=parameter.config.m_len,
|
||||||
device=device
|
device=device,
|
||||||
|
dataset_kwargs=dataset_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
param_groups = [
|
||||||
|
|
@@ -89,6 +99,7 @@ def train(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||||
|
# train args
|
||||||
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
parser.add_argument("--train_type",choices=["seq", "sft", "dpo"], help="Train type.")
|
||||||
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
parser.add_argument("--data_root_path", type=str, required=True, help="Path to the root directory of the dataset.")
|
||||||
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
parser.add_argument("--param_path", type=str, required=True, help="Path to the model parameters or resume checkpoint.")
|
||||||
|
|
@@ -105,6 +116,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
parser.add_argument("--adamw_weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer.")
|
||||||
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
|
parser.add_argument("--embdeding_lr_rate", type=float, default=1.0, help="The rate between the embedding layers lr rate and the max lr rate.")
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
|
|
||||||
|
# other configs
|
||||||
|
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn conversation training.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@@ -124,5 +138,6 @@ if __name__ == "__main__":
|
||||||
n_iter_ckpt=args.n_iter_ckpt,
|
n_iter_ckpt=args.n_iter_ckpt,
|
||||||
ckpt_dir=args.ckpt_dir,
|
ckpt_dir=args.ckpt_dir,
|
||||||
train_type=args.train_type,
|
train_type=args.train_type,
|
||||||
random_seed=args.random_seed
|
random_seed=args.random_seed,
|
||||||
|
multi_turn=args.multi_turn
|
||||||
)
|
)
|
||||||
Loading…
Reference in New Issue