test(trainer): 调整测试参数以提高训练和断言的稳定性
This commit is contained in:
parent
83c08cfbb9
commit
0a03a15679
|
|
@ -22,13 +22,13 @@ def test_env():
|
|||
test_dir = tempfile.mkdtemp()
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
||||
n_dim_choices = [16, 32, 64]
|
||||
n_head_choices = [4, 8, 16]
|
||||
n_dim_choices = [8, 16, 32]
|
||||
n_head_choices = [2, 4]
|
||||
|
||||
n_dim = int(np.random.choice(n_dim_choices))
|
||||
n_head = int(np.random.choice(n_head_choices))
|
||||
n_kvhead = n_head // 4
|
||||
d_ffn = n_dim * 4
|
||||
n_kvhead = n_head // 2
|
||||
d_ffn = n_dim * 2
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
|
|
@ -81,7 +81,6 @@ def test_env():
|
|||
loss_mask = build_loss_mask(input_ids, 0, 1)
|
||||
attn_mask = build_attention_mask(input_ids, 2, True)
|
||||
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"target_ids": target_ids,
|
||||
|
|
@ -220,12 +219,12 @@ def test_multi_turn_training(test_env):
|
|||
dataset=test_env["multi_turn_dataset"],
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
checkpoint_interval=3,
|
||||
accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=np.random.randint(1000)
|
||||
random_seed=int(np.random.randint(1000))
|
||||
)
|
||||
|
||||
schedule_config = CosineScheduleConfig(
|
||||
|
|
@ -289,7 +288,7 @@ def test_gradient_accumulation(test_env):
|
|||
)
|
||||
|
||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
||||
checkpoint = trainer.train()
|
||||
trainer.train()
|
||||
|
||||
assert train_config.accumulation_steps == accumulation_steps
|
||||
|
||||
|
|
@ -432,7 +431,7 @@ def test_early_stopping_simulation(test_env):
|
|||
dataset=dataset,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
checkpoint_interval=1,
|
||||
accumulation_steps=1,
|
||||
|
|
@ -459,7 +458,7 @@ def test_early_stopping_simulation(test_env):
|
|||
pass
|
||||
|
||||
checkpoint = trainer.train(checkpoint)
|
||||
assert len(checkpoint.loss_list) == 5 + 1
|
||||
assert len(checkpoint.loss_list) == 10 + 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue