From 9d5aa952e0f6d5b6570fdf408e60e4bfa0b49713 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 4 Oct 2025 21:31:39 +0800
Subject: [PATCH] feat(tests): refactor test environment to support pickle
 serialization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/conftest.py            | 102 ++++++++++++++++++++++-------------
 tests/test_callbacks.py      |   2 +-
 tests/test_early_stopping.py |  24 +-------
 3 files changed, 67 insertions(+), 61 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b03f98b..cba44cd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,66 @@ from khaosz.trainer.data_util import *
 
 matplotlib.use("Agg")
 
+
+class RandomDataset(Dataset):
+    def __init__(self, length=None, max_length=64, vocab_size=1000):
+        self.length = length or int(np.random.randint(100, 200))
+        self.max_length = max_length
+        self.vocab_size = vocab_size
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
+            "target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
+        }
+
+
+class MultiTurnDataset(Dataset):
+    def __init__(self, length=None, max_length=64, vocab_size=1000):
+        self.length = length or int(np.random.randint(100, 200))
+        self.max_length = max_length
+        self.vocab_size = vocab_size
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
+        target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
+        loss_mask = build_loss_mask(input_ids, 0, 1)
+        attn_mask = build_attention_mask(input_ids, 2, True)
+
+        return {
+            "input_ids": input_ids,
+            "target_ids": target_ids,
+            "loss_mask": loss_mask,
+            "attn_mask": attn_mask,
+        }
+
+
+class EarlyStoppingDataset(Dataset):
+    def __init__(self, length=10, stop_after=5):
+        self.length = length
+        self.stop_after = stop_after
+        self.count = 0
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        self.count += 1
+        if self.count == self.stop_after:
+            raise RuntimeError("Simulated early stopping")
+
+        return {
+            "input_ids": torch.randint(0, 1000, (64,)),
+            "target_ids": torch.randint(0, 1000, (64,))
+        }
+
+
 @pytest.fixture
 def base_test_env(request: pytest.FixtureRequest):
     func_name = request.function.__name__
@@ -60,49 +120,15 @@ def base_test_env(request: pytest.FixtureRequest):
 
 @pytest.fixture
 def random_dataset():
-    class RandomDataset(Dataset):
-        def __init__(self, length=None, max_length=64, vocab_size=1000):
-            self.length = length or int(np.random.randint(100, 200))
-            self.max_length = max_length
-            self.vocab_size = vocab_size
-
-        def __len__(self):
-            return self.length
-
-        def __getitem__(self, idx):
-            return {
-                "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
-                "target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
-            }
-
     dataset = RandomDataset()
-
     yield dataset
 
 
 @pytest.fixture
 def multi_turn_dataset():
-    class MultiTurnDataset(Dataset):
-        def __init__(self, length=None, max_length=64, vocab_size=1000):
-            self.length = length or int(np.random.randint(100, 200))
-            self.max_length = max_length
-            self.vocab_size = vocab_size
-
-        def __len__(self):
-            return self.length
-
-        def __getitem__(self, idx):
-            input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
-            target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
-            loss_mask = build_loss_mask(input_ids, 0, 1)
-            attn_mask = build_attention_mask(input_ids, 2, True)
-
-            return {
-                "input_ids": input_ids,
-                "target_ids": target_ids,
-                "loss_mask": loss_mask,
-                "attn_mask": attn_mask,
-            }
-
     dataset = MultiTurnDataset()
+    yield dataset
+@pytest.fixture
+def early_stopping_dataset():
+    dataset = EarlyStoppingDataset()
     yield dataset
\ No newline at end of file
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index ba4cc9b..d48d548 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -27,7 +27,7 @@ def test_callback_integration(base_test_env, random_dataset):
 
     # Create custom callbacks to track calls
     callback_calls = []
-    class TrackingCallback(TrainerCallback):
+    class TrackingCallback(TrainCallback):
         def on_train_begin(self, trainer, **kwargs):
             callback_calls.append('on_train_begin')
 
diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py
index b9cf322..b342494 100644
--- a/tests/test_early_stopping.py
+++ b/tests/test_early_stopping.py
@@ -5,32 +5,12 @@ from khaosz.core import *
 from khaosz.trainer import *
 from khaosz.trainer.data_util import *
 
-def test_early_stopping_simulation(base_test_env):
+def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
-    class EarlyStoppingDataset(Dataset):
-        def __init__(self, length=10, stop_after=5):
-            self.length = length
-            self.stop_after = stop_after
-            self.count = 0
-
-        def __len__(self):
-            return self.length
-
-        def __getitem__(self, idx):
-            self.count += 1
-            if self.count == self.stop_after:
-                raise RuntimeError("Simulated early stopping")
-
-            return {
-                "input_ids": torch.randint(0, 1000, (64,)),
-                "target_ids": torch.randint(0, 1000, (64,))
-            }
-
-    dataset = EarlyStoppingDataset()
     optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
     train_config = TrainConfig(
-        dataset=dataset,
+        dataset=early_stopping_dataset,
         optimizer=optimizer,
         checkpoint_dir=base_test_env["test_dir"],
         n_epoch=2,
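
Background on the motivation stated in the subject: torch.utils.data.DataLoader
pickles its dataset when num_workers > 0 under the "spawn" worker start method
(the default on Windows and macOS), and pickle serializes an instance by
reference to a class that must be importable via its module path. Classes
defined inside a fixture function are therefore unpicklable; hoisting
RandomDataset, MultiTurnDataset, and EarlyStoppingDataset to module level in
conftest.py lifts that restriction. A minimal standalone sketch of the
difference (illustrative names, not from the khaosz codebase):

    import pickle

    class ModuleLevel:                 # importable as <module>.ModuleLevel: picklable
        pass

    def make_local():
        class Local:                   # only reachable as make_local.<locals>.Local
            pass
        return Local()

    pickle.dumps(ModuleLevel())        # succeeds
    try:
        pickle.dumps(make_local())     # fails: class not importable by module path
    except (pickle.PicklingError, AttributeError) as exc:
        print(exc)                     # e.g. "Can't pickle local object 'make_local.<locals>.Local'"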