feat(train): 添加 max_len 和 step_size 参数支持
This commit is contained in:
parent
d94fc5a87a
commit
46b2a0f86f
17
train.py
17
train.py
|
|
@ -36,7 +36,8 @@ def train(
|
||||||
max_grad_norm: float,
|
max_grad_norm: float,
|
||||||
embdeding_lr_rate: int,
|
embdeding_lr_rate: int,
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
multi_turn: bool,
|
max_len: int,
|
||||||
|
step_size: int,
|
||||||
resume_from_checkpoint: bool
|
resume_from_checkpoint: bool
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
|
|
@ -47,16 +48,16 @@ def train(
|
||||||
|
|
||||||
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
if isinstance(parameter, Checkpoint) and resume_from_checkpoint:
|
||||||
checkpoint = parameter
|
checkpoint = parameter
|
||||||
|
|
||||||
|
if max_len is None:
|
||||||
|
max_len = parameter.config.m_len
|
||||||
|
|
||||||
model = parameter.model
|
model = parameter.model
|
||||||
|
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
model = model.to(device=device, dtype=torch.bfloat16)
|
model = model.to(device=device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
cache_files = get_files(data_root_path)
|
cache_files = get_files(data_root_path)
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"multi_turn": multi_turn,
|
|
||||||
"dpo_beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
"bos_token_id": parameter.tokenizer.bos_id,
|
"bos_token_id": parameter.tokenizer.bos_id,
|
||||||
"eos_token_id": parameter.tokenizer.eos_id,
|
"eos_token_id": parameter.tokenizer.eos_id,
|
||||||
|
|
@ -73,7 +74,8 @@ def train(
|
||||||
dataset = DatasetLoader.load(
|
dataset = DatasetLoader.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=cache_files,
|
load_path=cache_files,
|
||||||
max_len=parameter.config.m_len,
|
max_len=max_len,
|
||||||
|
step_size=step_size,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -138,10 +140,11 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
parser.add_argument("--random_seed", type=int, default=3407, help="Random seed for reproducibility.")
|
||||||
|
|
||||||
# other configs
|
# other configs
|
||||||
|
parser.add_argument("--max_len", type=int, default=None, help="the max length of the input sequence.")
|
||||||
|
parser.add_argument("--step_size", type=int, default=None, help="the step size of the input sequence.")
|
||||||
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch for training.")
|
||||||
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
parser.add_argument("--start_batch", type=int, default=0, help="Start batch for training.")
|
||||||
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
parser.add_argument("--resume_from_checkpoint", type=bool, default=False, help="train from checkpoint or not.")
|
||||||
parser.add_argument("--multi_turn", type=bool, default=False, help="Whether to use multi-turn convsersation training.")
|
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue