From 0b96b11a6e5d2e3c357dabe51c46ee2c8d4efa03 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 28 Sep 2025 14:38:23 +0800 Subject: [PATCH] =?UTF-8?q?test(trainer):=20=E5=A2=9E=E5=8A=A0=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E4=B8=AD=E6=96=AD=E4=B8=8E=E6=A3=80=E6=9F=A5=E7=82=B9?= =?UTF-8?q?=E6=81=A2=E5=A4=8D=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_trainer.py | 41 +++++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 227c9c3..eb0412d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -170,12 +170,30 @@ def test_checkpoint(test_env): def test_checkpoint_train(test_env): - temp_dir = test_env["test_dir"] config = test_env["transformer_config"] model = test_env["model"] tokenizer = test_env["tokenizer"] - dataset = test_env["dataset"] + class InterruptDataset(Dataset): + def __init__(self, length, interrupt_idx=0): + self.length = length + self.interrupt_idx = interrupt_idx + + def __len__(self): + return self.length + + def __getitem__(self, idx): + if idx == self.interrupt_idx: + self.interrupt_idx = -1 + raise Exception("Interrupt") + + return { + "input_ids": torch.randint(0, 1000, (64,)), + "target_ids": torch.randint(0, 1000, (64,)) + } + + + dataset = InterruptDataset(length=10, interrupt_idx=3) param = ModelParameter(model, tokenizer, config) trainer = Trainer(param) @@ -185,7 +203,7 @@ def test_checkpoint_train(test_env): dataset=dataset, optimizer=optimizer, ckpt_dir=test_env["test_dir"], - n_epoch=1, + n_epoch=2, batch_size=2, n_iter_ckpt=5, n_iter_step=1, @@ -196,10 +214,17 @@ def test_checkpoint_train(test_env): warning_step=100, total_iters=1000 ) - - trainer.train( - train_config=train_config, - schedule_config=schedule_config, - ) + try: + trainer.train( + train_config=train_config, + schedule_config=schedule_config, + ) + except Exception: + checkpoint = trainer.checkpoint + trainer.train( + train_config=train_config, + schedule_config=schedule_config, + train_checkpoint=checkpoint + ) \ No newline at end of file