refactor(test): 更新训练配置参数名称并优化测试逻辑
This commit is contained in:
parent
1c9063fd3d
commit
0ebf53008e
|
|
@ -78,14 +78,13 @@ def test_dataset_loader(test_env):
|
||||||
def test_training_config(test_env):
|
def test_training_config(test_env):
|
||||||
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
train_type="seq",
|
|
||||||
dataset=test_env["dataset"],
|
dataset=test_env["dataset"],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
ckpt_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
n_iter_ckpt=5,
|
checkpoint_interval=5,
|
||||||
n_iter_step=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
|
|
@ -94,51 +93,56 @@ def test_training_config(test_env):
|
||||||
def test_cosine_schedule(test_env):
|
def test_cosine_schedule(test_env):
|
||||||
assert test_env is not None
|
assert test_env is not None
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
warning_step=100,
|
warmup_steps=100,
|
||||||
total_iters=1000
|
total_steps=1000
|
||||||
)
|
)
|
||||||
kwargs = schedule_config.get_kwargs()
|
kwargs = schedule_config.get_kwargs()
|
||||||
assert kwargs["warning_step"] == 100
|
assert kwargs["warmup_steps"] == 100
|
||||||
assert kwargs["lr_decay_iters"] == 900
|
assert kwargs["lr_decay_steps"] == 900
|
||||||
|
|
||||||
|
|
||||||
def test_sgdr_schedule(test_env):
|
def test_sgdr_schedule(test_env):
|
||||||
assert test_env is not None
|
assert test_env is not None
|
||||||
schedule_config = SgdrScheduleConfig(
|
schedule_config = SgdrScheduleConfig(
|
||||||
warning_step=100,
|
warmup_steps=100,
|
||||||
cycle_length=200,
|
cycle_length=200,
|
||||||
T_mult=2
|
t_mult=2
|
||||||
)
|
)
|
||||||
kwargs = schedule_config.get_kwargs()
|
kwargs = schedule_config.get_kwargs()
|
||||||
assert kwargs["warning_step"] == 100
|
assert kwargs["warmup_steps"] == 100
|
||||||
assert kwargs["cycle_length"] == 200
|
assert kwargs["cycle_length"] == 200
|
||||||
assert kwargs["T_mult"] == 2
|
assert kwargs["t_mult"] == 2
|
||||||
|
|
||||||
def test_trainer_train(test_env):
|
def test_trainer_train(test_env):
|
||||||
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
train_type="seq",
|
|
||||||
dataset=test_env["dataset"],
|
dataset=test_env["dataset"],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
ckpt_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
n_iter_ckpt=5,
|
checkpoint_interval=5,
|
||||||
n_iter_step=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
warning_step=100,
|
warmup_steps=100,
|
||||||
total_iters=1000
|
total_steps=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
train_config.strategy = StrategyFactory.load(
|
||||||
|
test_env["model"],
|
||||||
|
"seq",
|
||||||
|
pad_token_id=test_env["tokenizer"].pad_id
|
||||||
)
|
)
|
||||||
model_parameter = ModelParameter(
|
model_parameter = ModelParameter(
|
||||||
test_env["model"],
|
test_env["model"],
|
||||||
test_env["tokenizer"],
|
test_env["tokenizer"],
|
||||||
test_env["transformer_config"]
|
test_env["transformer_config"]
|
||||||
)
|
)
|
||||||
trainer = Trainer(model_parameter)
|
trainer = Trainer(model_parameter, train_config, schedule_config)
|
||||||
trainer.train(train_config, schedule_config)
|
trainer.train()
|
||||||
|
|
||||||
def test_checkpoint(test_env):
|
def test_checkpoint(test_env):
|
||||||
temp_dir = test_env["test_dir"]
|
temp_dir = test_env["test_dir"]
|
||||||
|
|
@ -195,36 +199,32 @@ def test_checkpoint_train(test_env):
|
||||||
|
|
||||||
dataset = InterruptDataset(length=10, interrupt_idx=3)
|
dataset = InterruptDataset(length=10, interrupt_idx=3)
|
||||||
param = ModelParameter(model, tokenizer, config)
|
param = ModelParameter(model, tokenizer, config)
|
||||||
trainer = Trainer(param)
|
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(test_env["model"].parameters())
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
train_type="seq",
|
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
ckpt_dir=test_env["test_dir"],
|
checkpoint_dir=test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
n_iter_ckpt=5,
|
checkpoint_interval=5,
|
||||||
n_iter_step=1,
|
accumulation_steps=1,
|
||||||
max_grad_norm=1.0,
|
max_grad_norm=1.0,
|
||||||
random_seed=42
|
random_seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_config.strategy = StrategyFactory.load(
|
||||||
|
test_env["model"],
|
||||||
|
"seq",
|
||||||
|
pad_token_id=test_env["tokenizer"].pad_id
|
||||||
|
)
|
||||||
schedule_config = CosineScheduleConfig(
|
schedule_config = CosineScheduleConfig(
|
||||||
warning_step=100,
|
warmup_steps=100,
|
||||||
total_iters=1000
|
total_steps=1000
|
||||||
)
|
)
|
||||||
|
trainer = Trainer(param, train_config, schedule_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trainer.train(
|
trainer.train()
|
||||||
train_config=train_config,
|
|
||||||
schedule_config=schedule_config,
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
checkpoint = trainer.checkpoint
|
checkpoint = trainer.checkpoint
|
||||||
trainer.train(
|
trainer.train(train_checkpoint=checkpoint)
|
||||||
train_config=train_config,
|
|
||||||
schedule_config=schedule_config,
|
|
||||||
train_checkpoint=checkpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue