refactor(test): update training config parameter names and streamline test logic
parent 1c9063fd3d
commit 0ebf53008e
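The renames this commit applies, read off the hunks below (the config classes themselves live outside this file, so the mapping is inferred from the test changes only):

    # Old test-facing parameter names -> new names, per this diff.
    RENAMES = {
        "ckpt_dir": "checkpoint_dir",          # TrainConfig
        "n_iter_ckpt": "checkpoint_interval",  # TrainConfig
        "n_iter_step": "accumulation_steps",   # TrainConfig
        "warning_step": "warmup_steps",        # Cosine/SgdrScheduleConfig
        "total_iters": "total_steps",          # CosineScheduleConfig
        "T_mult": "t_mult",                    # SgdrScheduleConfig
        "lr_decay_iters": "lr_decay_steps",    # key in get_kwargs() output
    }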
@@ -78,14 +78,13 @@ def test_dataset_loader(test_env):

 def test_training_config(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
         train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
@@ -94,51 +93,56 @@ def test_training_config(test_env):

 def test_cosine_schedule(test_env):
     assert test_env is not None
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
-    assert kwargs["lr_decay_iters"] == 900
+    assert kwargs["warmup_steps"] == 100
+    assert kwargs["lr_decay_steps"] == 900


 def test_sgdr_schedule(test_env):
     assert test_env is not None
     schedule_config = SgdrScheduleConfig(
-        warning_step=100,
+        warmup_steps=100,
         cycle_length=200,
-        T_mult=2
+        t_mult=2
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
+    assert kwargs["warmup_steps"] == 100
     assert kwargs["cycle_length"] == 200
-    assert kwargs["T_mult"] == 2
+    assert kwargs["t_mult"] == 2


 def test_trainer_train(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
         train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
     )

     train_config.strategy = StrategyFactory.load(
         test_env["model"],
         "seq",
         pad_token_id=test_env["tokenizer"].pad_id
     )
     model_parameter = ModelParameter(
         test_env["model"],
         test_env["tokenizer"],
         test_env["transformer_config"]
     )
-    trainer = Trainer(model_parameter)
-    trainer.train(train_config, schedule_config)
+    trainer = Trainer(model_parameter, train_config, schedule_config)
+    trainer.train()


 def test_checkpoint(test_env):
     temp_dir = test_env["test_dir"]
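The updated cosine assertions also pin down how get_kwargs() derives the decay length; this is inferred from the asserted values alone, not from the scheduler source:

    # With warmup_steps=100 and total_steps=1000 the test expects 900,
    # i.e. lr_decay_steps = total_steps - warmup_steps.
    warmup_steps, total_steps = 100, 1000
    assert total_steps - warmup_steps == 900  # matches kwargs["lr_decay_steps"]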
@@ -195,36 +199,32 @@ def test_checkpoint_train(test_env):

     dataset = InterruptDataset(length=10, interrupt_idx=3)
     param = ModelParameter(model, tokenizer, config)
-    trainer = Trainer(param)

     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
         train_type="seq",
         dataset=dataset,
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=2,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )

     train_config.strategy = StrategyFactory.load(
         test_env["model"],
         "seq",
         pad_token_id=test_env["tokenizer"].pad_id
     )
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
     )
+    trainer = Trainer(param, train_config, schedule_config)

     try:
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-        )
+        trainer.train()
     except Exception:
         checkpoint = trainer.checkpoint
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-            train_checkpoint=checkpoint
-        )
+        trainer.train(train_checkpoint=checkpoint)
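The Trainer change threaded through all three hunks: the configs now travel in the constructor, so train() runs without arguments and a resume only needs the checkpoint. A minimal sketch of the new call pattern, assuming only the signatures visible in this diff:

    # Construct once with both configs (new API per this commit).
    trainer = Trainer(model_parameter, train_config, schedule_config)
    try:
        trainer.train()  # fresh run, no arguments
    except Exception:
        # On interruption, resume from the trainer's own checkpoint.
        trainer.train(train_checkpoint=trainer.checkpoint)

Moving the configs into the constructor means the resume path no longer repeats them, which is exactly what test_checkpoint_train now exercises.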