test(trainer): 增加训练中断与检查点恢复测试
This commit is contained in:
parent
25ec56a1f5
commit
0b96b11a6e
|
|
@ -170,12 +170,30 @@ def test_checkpoint(test_env):
|
|||
|
||||
|
||||
def test_checkpoint_train(test_env):
|
||||
temp_dir = test_env["test_dir"]
|
||||
config = test_env["transformer_config"]
|
||||
model = test_env["model"]
|
||||
tokenizer = test_env["tokenizer"]
|
||||
dataset = test_env["dataset"]
|
||||
|
||||
class InterruptDataset(Dataset):
    """Synthetic dataset that fails exactly once, to simulate an interrupted run.

    The first access to ``interrupt_idx`` raises ``Exception("Interrupt")``;
    after that the dataset is disarmed and every valid index returns a sample.
    This lets a test abort training mid-epoch and then resume from the
    checkpoint using the very same dataset object.

    Args:
        length: Number of samples reported by ``len()``.
        interrupt_idx: Index whose first access raises. Defaults to 0.
    """

    def __init__(self, length, interrupt_idx=0):
        self.length = length
        self.interrupt_idx = interrupt_idx

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return one random sample, raising once at the interrupt index.

        Raises:
            IndexError: If ``idx`` is outside ``[0, length)``.
            Exception: With message ``"Interrupt"`` on the first access to
                ``interrupt_idx``.
        """
        # Honour the sequence protocol: without this, plain iteration over the
        # dataset never terminates and out-of-range reads silently succeed.
        if not 0 <= idx < self.length:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.length}"
            )

        if idx == self.interrupt_idx:
            # Disarm first so the resumed run can read this index normally.
            # -1 can never match again because negative indices are rejected above.
            self.interrupt_idx = -1
            raise Exception("Interrupt")

        return {
            "input_ids": torch.randint(0, 1000, (64,)),
            "target_ids": torch.randint(0, 1000, (64,)),
        }
|
||||
|
||||
|
||||
dataset = InterruptDataset(length=10, interrupt_idx=3)
|
||||
param = ModelParameter(model, tokenizer, config)
|
||||
trainer = Trainer(param)
|
||||
|
||||
|
|
@ -185,7 +203,7 @@ def test_checkpoint_train(test_env):
|
|||
dataset=dataset,
|
||||
optimizer=optimizer,
|
||||
ckpt_dir=test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
n_epoch=2,
|
||||
batch_size=2,
|
||||
n_iter_ckpt=5,
|
||||
n_iter_step=1,
|
||||
|
|
@ -196,10 +214,17 @@ def test_checkpoint_train(test_env):
|
|||
warning_step=100,
|
||||
total_iters=1000
|
||||
)
|
||||
|
||||
try:
|
||||
trainer.train(
|
||||
train_config=train_config,
|
||||
schedule_config=schedule_config,
|
||||
)
|
||||
except Exception:
|
||||
checkpoint = trainer.checkpoint
|
||||
trainer.train(
|
||||
train_config=train_config,
|
||||
schedule_config=schedule_config,
|
||||
train_checkpoint=checkpoint
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue