feat(tests): 重构测试环境, 便于pickle 序列化

This commit is contained in:
ViperEkura 2025-10-04 21:31:39 +08:00
parent 2ccd7bd583
commit 9d5aa952e0
3 changed files with 67 additions and 61 deletions

View File

@ -15,6 +15,66 @@ from khaosz.trainer.data_util import *
matplotlib.use("Agg") matplotlib.use("Agg")
class RandomDataset(Dataset):
    """Synthetic dataset yielding random (input_ids, target_ids) pairs.

    Defined at module level (rather than inside a fixture) so instances can
    be pickled, e.g. when handed to DataLoader worker processes.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        """
        Args:
            length: number of samples; if None, a random size in [100, 200).
            max_length: sequence length of every sample.
            vocab_size: exclusive upper bound for token ids.
        """
        # `length or ...` would silently replace an explicit length=0 with a
        # random size; compare against None so 0 is honored.
        self.length = length if length is not None else int(np.random.randint(100, 200))
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Samples are independent of idx: fresh random tokens on every access.
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
            "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
        }
class MultiTurnDataset(Dataset):
    """Synthetic multi-turn dataset with loss and attention masks.

    Defined at module level (rather than inside a fixture) so instances can
    be pickled, e.g. when handed to DataLoader worker processes.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        """
        Args:
            length: number of samples; if None, a random size in [100, 200).
            max_length: sequence length of every sample.
            vocab_size: exclusive upper bound for token ids.
        """
        # `length or ...` would silently replace an explicit length=0 with a
        # random size; compare against None so 0 is honored.
        self.length = length if length is not None else int(np.random.randint(100, 200))
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Samples are independent of idx: fresh random tokens on every access.
        input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
        target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
        # NOTE(review): the mask helpers come from khaosz.trainer.data_util;
        # the magic token ids (0, 1, 2) presumably mark turn/pad boundaries —
        # confirm against build_loss_mask / build_attention_mask signatures.
        loss_mask = build_loss_mask(input_ids, 0, 1)
        attn_mask = build_attention_mask(input_ids, 2, True)
        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "loss_mask": loss_mask,
            "attn_mask": attn_mask,
        }
class EarlyStoppingDataset(Dataset):
    """Dataset that raises partway through iteration to simulate early stopping.

    Defined at module level so instances remain picklable. The failure is
    keyed to the number of accesses (not the index): the stop_after-th call
    to __getitem__ raises RuntimeError exactly once.
    """

    def __init__(self, length=10, stop_after=5):
        self.length = length
        self.stop_after = stop_after
        # Running tally of __getitem__ calls on this instance.
        self.count = 0

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self.count = self.count + 1
        if self.count == self.stop_after:
            raise RuntimeError("Simulated early stopping")
        sample = {
            "input_ids": torch.randint(0, 1000, (64,)),
            "target_ids": torch.randint(0, 1000, (64,)),
        }
        return sample
@pytest.fixture @pytest.fixture
def base_test_env(request: pytest.FixtureRequest): def base_test_env(request: pytest.FixtureRequest):
func_name = request.function.__name__ func_name = request.function.__name__
@ -60,49 +120,15 @@ def base_test_env(request: pytest.FixtureRequest):
@pytest.fixture @pytest.fixture
def random_dataset(): def random_dataset():
class RandomDataset(Dataset):
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length
self.vocab_size = vocab_size
def __len__(self):
return self.length
def __getitem__(self, idx):
return {
"input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
"target_ids": torch.randint(0, self.vocab_size, (self.max_length,))
}
dataset = RandomDataset() dataset = RandomDataset()
yield dataset yield dataset
@pytest.fixture @pytest.fixture
def multi_turn_dataset(): def multi_turn_dataset():
class MultiTurnDataset(Dataset):
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length
self.vocab_size = vocab_size
def __len__(self):
return self.length
def __getitem__(self, idx):
input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
loss_mask = build_loss_mask(input_ids, 0, 1)
attn_mask = build_attention_mask(input_ids, 2, True)
return {
"input_ids": input_ids,
"target_ids": target_ids,
"loss_mask": loss_mask,
"attn_mask": attn_mask,
}
dataset = MultiTurnDataset() dataset = MultiTurnDataset()
yield dataset
@pytest.fixture
def early_stopping_dataset():
dataset = EarlyStoppingDataset()
yield dataset yield dataset

View File

@ -27,7 +27,7 @@ def test_callback_integration(base_test_env, random_dataset):
# Create custom callbacks to track calls # Create custom callbacks to track calls
callback_calls = [] callback_calls = []
class TrackingCallback(TrainerCallback): class TrackingCallback(TrainCallback):
def on_train_begin(self, trainer, **kwargs): def on_train_begin(self, trainer, **kwargs):
callback_calls.append('on_train_begin') callback_calls.append('on_train_begin')

View File

@ -5,32 +5,12 @@ from khaosz.core import *
from khaosz.trainer import * from khaosz.trainer import *
from khaosz.trainer.data_util import * from khaosz.trainer.data_util import *
def test_early_stopping_simulation(base_test_env): def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""
class EarlyStoppingDataset(Dataset):
def __init__(self, length=10, stop_after=5):
self.length = length
self.stop_after = stop_after
self.count = 0
def __len__(self):
return self.length
def __getitem__(self, idx):
self.count += 1
if self.count == self.stop_after:
raise RuntimeError("Simulated early stopping")
return {
"input_ids": torch.randint(0, 1000, (64,)),
"target_ids": torch.randint(0, 1000, (64,))
}
dataset = EarlyStoppingDataset()
optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
train_config = TrainConfig( train_config = TrainConfig(
dataset=dataset, dataset=early_stopping_dataset,
optimizer=optimizer, optimizer=optimizer,
checkpoint_dir=base_test_env["test_dir"], checkpoint_dir=base_test_env["test_dir"],
n_epoch=2, n_epoch=2,