test(trainer): 增加训练中断与检查点恢复测试

This commit is contained in:
ViperEkura 2025-09-28 14:38:23 +08:00
parent 25ec56a1f5
commit 0b96b11a6e
1 changed file with 33 additions and 8 deletions

View File

@@ -170,12 +170,30 @@ def test_checkpoint(test_env):
def test_checkpoint_train(test_env): def test_checkpoint_train(test_env):
temp_dir = test_env["test_dir"]
config = test_env["transformer_config"] config = test_env["transformer_config"]
model = test_env["model"] model = test_env["model"]
tokenizer = test_env["tokenizer"] tokenizer = test_env["tokenizer"]
dataset = test_env["dataset"]
class InterruptDataset(Dataset):
    """Synthetic dataset that fails exactly once to simulate an interrupted run.

    Every item is a dict of two random integer tensors. The first time the
    item at ``interrupt_idx`` is requested the dataset raises instead of
    returning data; after that the trigger is disarmed and the same index
    yields a normal sample.

    Args:
        length: Number of items reported by ``len()``.
        interrupt_idx: Index whose first access raises. Any value outside
            ``[0, length)`` (for example ``-1``) means "never interrupt".
        seq_len: Length of each generated tensor.
        vocab_size: Exclusive upper bound of the generated token ids.
    """

    def __init__(self, length, interrupt_idx=0, *, seq_len=64, vocab_size=1000):
        self.length = length
        self.interrupt_idx = interrupt_idx
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a random sample, raising once at ``interrupt_idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, length)``.
            RuntimeError: On the first access to ``interrupt_idx``.
        """
        # Without a bounds check, plain iteration (which stops on IndexError)
        # would never end, and a negative index could collide with the -1
        # "disarmed" sentinel used below.
        if not 0 <= idx < self.length:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.length}"
            )
        if idx == self.interrupt_idx:
            # Disarm before raising so that a resumed run can read this index.
            self.interrupt_idx = -1
            raise RuntimeError("Interrupt")
        shape = (self.seq_len,)
        return {
            "input_ids": torch.randint(0, self.vocab_size, shape),
            "target_ids": torch.randint(0, self.vocab_size, shape),
        }
dataset = InterruptDataset(length=10, interrupt_idx=3)
param = ModelParameter(model, tokenizer, config) param = ModelParameter(model, tokenizer, config)
trainer = Trainer(param) trainer = Trainer(param)
@@ -185,7 +203,7 @@ def test_checkpoint_train(test_env):
dataset=dataset, dataset=dataset,
optimizer=optimizer, optimizer=optimizer,
ckpt_dir=test_env["test_dir"], ckpt_dir=test_env["test_dir"],
n_epoch=1, n_epoch=2,
batch_size=2, batch_size=2,
n_iter_ckpt=5, n_iter_ckpt=5,
n_iter_step=1, n_iter_step=1,
@@ -196,10 +214,17 @@ def test_checkpoint_train(test_env):
warning_step=100, warning_step=100,
total_iters=1000 total_iters=1000
) )
try:
trainer.train( trainer.train(
train_config=train_config, train_config=train_config,
schedule_config=schedule_config, schedule_config=schedule_config,
) )
except Exception:
checkpoint = trainer.checkpoint
trainer.train(
train_config=train_config,
schedule_config=schedule_config,
train_checkpoint=checkpoint
)