refactor(test): update training config parameter names and optimize test logic

ViperEkura 2025-09-28 22:14:39 +08:00
parent 1c9063fd3d
commit 0ebf53008e
1 changed file with 40 additions and 40 deletions
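In summary, the diff below renames the following keyword arguments:

- TrainConfig: ckpt_dir → checkpoint_dir, n_iter_ckpt → checkpoint_interval, n_iter_step → accumulation_steps; the train_type="seq" argument is removed.
- CosineScheduleConfig: warning_step → warmup_steps, total_iters → total_steps; get_kwargs() now reports lr_decay_steps instead of lr_decay_iters (900 for warmup_steps=100 and total_steps=1000, i.e. total_steps minus warmup_steps).
- SgdrScheduleConfig: warning_step → warmup_steps, T_mult → t_mult.

The test logic changes accordingly: the sequence strategy is now attached explicitly via StrategyFactory.load(model, "seq", pad_token_id=...), and Trainer receives train_config and schedule_config in its constructor, so trainer.train() is called without config arguments (resuming still passes train_checkpoint=...).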


@@ -78,14 +78,13 @@ def test_dataset_loader(test_env):
 def test_training_config(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
@@ -94,51 +93,56 @@ def test_training_config(test_env):
 def test_cosine_schedule(test_env):
     assert test_env is not None
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
-    assert kwargs["lr_decay_iters"] == 900
+    assert kwargs["warmup_steps"] == 100
+    assert kwargs["lr_decay_steps"] == 900

 def test_sgdr_schedule(test_env):
     assert test_env is not None
     schedule_config = SgdrScheduleConfig(
-        warning_step=100,
+        warmup_steps=100,
         cycle_length=200,
-        T_mult=2
+        t_mult=2
     )
     kwargs = schedule_config.get_kwargs()
-    assert kwargs["warning_step"] == 100
+    assert kwargs["warmup_steps"] == 100
     assert kwargs["cycle_length"] == 200
-    assert kwargs["T_mult"] == 2
+    assert kwargs["t_mult"] == 2

 def test_trainer_train(test_env):
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=test_env["dataset"],
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=1,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
     schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+        warmup_steps=100,
+        total_steps=1000
+    )
+    train_config.strategy = StrategyFactory.load(
+        test_env["model"],
+        "seq",
+        pad_token_id=test_env["tokenizer"].pad_id
     )
     model_parameter = ModelParameter(
         test_env["model"],
         test_env["tokenizer"],
         test_env["transformer_config"]
     )
-    trainer = Trainer(model_parameter)
-    trainer.train(train_config, schedule_config)
+    trainer = Trainer(model_parameter, train_config, schedule_config)
+    trainer.train()

 def test_checkpoint(test_env):
     temp_dir = test_env["test_dir"]
@@ -195,36 +199,32 @@ def test_checkpoint_train(test_env):
     dataset = InterruptDataset(length=10, interrupt_idx=3)
     param = ModelParameter(model, tokenizer, config)
-    trainer = Trainer(param)
     optimizer = torch.optim.AdamW(test_env["model"].parameters())
     train_config = TrainConfig(
-        train_type="seq",
         dataset=dataset,
         optimizer=optimizer,
-        ckpt_dir=test_env["test_dir"],
+        checkpoint_dir=test_env["test_dir"],
         n_epoch=2,
         batch_size=2,
-        n_iter_ckpt=5,
-        n_iter_step=1,
+        checkpoint_interval=5,
+        accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
     )
-    schedule_config = CosineScheduleConfig(
-        warning_step=100,
-        total_iters=1000
+    train_config.strategy = StrategyFactory.load(
+        test_env["model"],
+        "seq",
+        pad_token_id=test_env["tokenizer"].pad_id
     )
+    schedule_config = CosineScheduleConfig(
+        warmup_steps=100,
+        total_steps=1000
+    )
+    trainer = Trainer(param, train_config, schedule_config)
     try:
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-        )
+        trainer.train()
     except Exception:
         checkpoint = trainer.checkpoint
-        trainer.train(
-            train_config=train_config,
-            schedule_config=schedule_config,
-            train_checkpoint=checkpoint
-        )
+        trainer.train(train_checkpoint=checkpoint)
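For context, a minimal sketch of how the updated API fits together, reconstructed only from the new side of this diff. The import path, the tokenizer's pad_id attribute, and the run_training helper are assumptions for illustration; model, tokenizer, dataset, and transformer_config stand in for the test fixture's objects:

import torch
# Assumed import path -- the actual module layout is not shown in this diff.
from trainer import (ModelParameter, TrainConfig, CosineScheduleConfig,
                     StrategyFactory, Trainer)

def run_training(model, tokenizer, dataset, transformer_config, ckpt_dir):
    optimizer = torch.optim.AdamW(model.parameters())
    train_config = TrainConfig(
        dataset=dataset,
        optimizer=optimizer,
        checkpoint_dir=ckpt_dir,      # renamed from ckpt_dir
        n_epoch=1,
        batch_size=2,
        checkpoint_interval=5,        # renamed from n_iter_ckpt
        accumulation_steps=1,         # renamed from n_iter_step
        max_grad_norm=1.0,
        random_seed=42,
    )
    # train_type="seq" is gone from TrainConfig; the strategy is attached
    # explicitly instead (pad_id is an assumed tokenizer attribute).
    train_config.strategy = StrategyFactory.load(
        model, "seq", pad_token_id=tokenizer.pad_id
    )
    schedule_config = CosineScheduleConfig(warmup_steps=100, total_steps=1000)

    # Configs now go to the constructor; train() takes no config arguments.
    trainer = Trainer(ModelParameter(model, tokenizer, transformer_config),
                      train_config, schedule_config)
    try:
        trainer.train()
    except Exception:
        # Resume from the trainer's last checkpoint, as in test_checkpoint_train.
        trainer.train(train_checkpoint=trainer.checkpoint)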