diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 227c9c3..eb0412d 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -170,12 +170,30 @@ def test_checkpoint(test_env):
 
 
 def test_checkpoint_train(test_env):
-    temp_dir = test_env["test_dir"]
     config = test_env["transformer_config"]
     model = test_env["model"]
    tokenizer = test_env["tokenizer"]
-    dataset = test_env["dataset"]
+
+    class InterruptDataset(Dataset):
+        def __init__(self, length, interrupt_idx=0):
+            self.length = length
+            self.interrupt_idx = interrupt_idx
+
+        def __len__(self):
+            return self.length
+
+        def __getitem__(self, idx):
+            if idx == self.interrupt_idx:
+                self.interrupt_idx = -1
+                raise Exception("Interrupt")
+
+            return {
+                "input_ids": torch.randint(0, 1000, (64,)),
+                "target_ids": torch.randint(0, 1000, (64,))
+            }
+
+
+    dataset = InterruptDataset(length=10, interrupt_idx=3)
 
     param = ModelParameter(model, tokenizer, config)
     trainer = Trainer(param)
@@ -185,7 +203,7 @@ def test_checkpoint_train(test_env):
         dataset=dataset,
         optimizer=optimizer,
         ckpt_dir=test_env["test_dir"],
-        n_epoch=1,
+        n_epoch=2,
         batch_size=2,
         n_iter_ckpt=5,
         n_iter_step=1,
@@ -196,10 +214,17 @@ def test_checkpoint_train(test_env):
         warning_step=100,
         total_iters=1000
     )
-
-    trainer.train(
-        train_config=train_config,
-        schedule_config=schedule_config,
-    )
+    try:
+        trainer.train(
+            train_config=train_config,
+            schedule_config=schedule_config,
+        )
+    except Exception:
+        checkpoint = trainer.checkpoint
+        trainer.train(
+            train_config=train_config,
+            schedule_config=schedule_config,
+            train_checkpoint=checkpoint
+        )
\ No newline at end of file