test(trainer): 调整测试参数以提高训练和断言的稳定性
This commit is contained in:
parent
83c08cfbb9
commit
0a03a15679
|
|
@ -22,13 +22,13 @@ def test_env():
|
||||||
test_dir = tempfile.mkdtemp()
|
test_dir = tempfile.mkdtemp()
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
||||||
n_dim_choices = [16, 32, 64]
|
n_dim_choices = [8, 16, 32]
|
||||||
n_head_choices = [4, 8, 16]
|
n_head_choices = [2, 4]
|
||||||
|
|
||||||
n_dim = int(np.random.choice(n_dim_choices))
|
n_dim = int(np.random.choice(n_dim_choices))
|
||||||
n_head = int(np.random.choice(n_head_choices))
|
n_head = int(np.random.choice(n_head_choices))
|
||||||
n_kvhead = n_head // 4
|
n_kvhead = n_head // 2
|
||||||
d_ffn = n_dim * 4
|
d_ffn = n_dim * 2
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
|
|
@ -80,7 +80,6 @@ def test_env():
|
||||||
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
|
||||||
loss_mask = build_loss_mask(input_ids, 0, 1)
|
loss_mask = build_loss_mask(input_ids, 0, 1)
|
||||||
attn_mask = build_attention_mask(input_ids, 2, True)
|
attn_mask = build_attention_mask(input_ids, 2, True)
|
||||||
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|
@ -220,12 +219,12 @@ def test_multi_turn_training(test_env):
|
||||||
dataset=test_env["multi_turn_dataset"],
|
dataset=test_env["multi_turn_dataset"],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
checkpoint_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=3,
|
checkpoint_interval=3,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=np.random.randint(1000)
|
random_seed=int(np.random.randint(1000))
|
||||||
)
|
)
|
||||||
|
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
|
|
@ -289,7 +288,7 @@ def test_gradient_accumulation(test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(model_parameter, train_config, schedule_config)
|
trainer = Trainer(model_parameter, train_config, schedule_config)
|
||||||
checkpoint = trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
assert train_config.accumulation_steps == accumulation_steps
|
assert train_config.accumulation_steps == accumulation_steps
|
||||||
|
|
||||||
|
|
@ -432,7 +431,7 @@ def test_early_stopping_simulation(test_env):
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
checkpoint_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
checkpoint_interval=1,
|
checkpoint_interval=1,
|
||||||
accumulation_steps=1,
|
accumulation_steps=1,
|
||||||
|
|
@ -459,7 +458,7 @@ def test_early_stopping_simulation(test_env):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
checkpoint = trainer.train(checkpoint)
|
checkpoint = trainer.train(checkpoint)
|
||||||
assert len(checkpoint.loss_list) == 5 + 1
|
assert len(checkpoint.loss_list) == 10 + 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue