feat(tests): 重构测试环境, 便于pickle 序列化
This commit is contained in:
parent
2ccd7bd583
commit
9d5aa952e0
|
|
@ -15,6 +15,66 @@ from khaosz.trainer.data_util import *
|
|||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
    """Synthetic dataset that serves freshly sampled random token ids.

    The class lives at module level (rather than inside a fixture) so that
    instances can be pickled; pickle locates classes by their importable name.

    Args:
        length: Number of samples reported by ``len()``. When ``None``, a
            length is drawn once from ``np.random.randint(100, 200)``.
        max_length: Number of token ids in each returned tensor.
        vocab_size: Exclusive upper bound for the sampled token ids.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        # Test for None explicitly: ``length or ...`` would silently replace
        # an explicit ``length=0`` with a random length.
        if length is None:
            length = int(np.random.randint(100, 200))
        self.length = length
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return a new random sample; ``idx`` does not influence the result.

        Returns:
            dict with ``"input_ids"`` and ``"target_ids"``, each a 1-D integer
            tensor of ``max_length`` values in ``[0, vocab_size)``.
        """
        shape = (self.max_length,)
        return {
            "input_ids": torch.randint(0, self.vocab_size, shape),
            "target_ids": torch.randint(0, self.vocab_size, shape),
        }
|
||||
|
||||
|
||||
class MultiTurnDataset(Dataset):
    """Synthetic dataset yielding random token ids plus loss/attention masks.

    The class lives at module level (rather than inside a fixture) so that
    instances can be pickled; pickle locates classes by their importable name.

    Args:
        length: Number of samples reported by ``len()``. When ``None``, a
            length is drawn once from ``np.random.randint(100, 200)``.
        max_length: Number of token ids in each returned id tensor.
        vocab_size: Exclusive upper bound for the sampled token ids.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        # Test for None explicitly: ``length or ...`` would silently replace
        # an explicit ``length=0`` with a random length.
        if length is None:
            length = int(np.random.randint(100, 200))
        self.length = length
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return a new random sample; ``idx`` does not influence the result.

        Returns:
            dict with ``"input_ids"``, ``"target_ids"``, ``"loss_mask"`` and
            ``"attn_mask"``. Both masks are derived from ``input_ids``.
        """
        shape = (self.max_length,)
        input_ids = torch.randint(0, self.vocab_size, shape)
        target_ids = torch.randint(0, self.vocab_size, shape)

        # NOTE(review): the literal arguments (0, 1) and (2, True) are passed
        # through to the project's mask builders; they look like special
        # token ids and a flag -- confirm against khaosz.trainer.data_util.
        loss_mask = build_loss_mask(input_ids, 0, 1)
        attn_mask = build_attention_mask(input_ids, 2, True)

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "loss_mask": loss_mask,
            "attn_mask": attn_mask,
        }
|
||||
|
||||
|
||||
class EarlyStoppingDataset(Dataset):
    """Dataset that raises once after a fixed number of item accesses.

    Every ``__getitem__`` call increments an internal counter; the call on
    which the counter equals ``stop_after`` raises ``RuntimeError`` so tests
    can simulate an interrupted training run. Calls before and after that one
    return random samples.

    Args:
        length: Number of samples reported by ``len()``.
        stop_after: Access count (1-based) at which ``RuntimeError`` is raised.
        max_length: Number of token ids in each returned tensor.
        vocab_size: Exclusive upper bound for the sampled token ids.
    """

    def __init__(self, length=10, stop_after=5, max_length=64, vocab_size=1000):
        self.length = length
        self.stop_after = stop_after
        # Sample shape and vocabulary are configurable like the sibling
        # datasets; the defaults reproduce the previously hard-coded values.
        self.max_length = max_length
        self.vocab_size = vocab_size
        # Counts accesses on this instance only; a copy of the dataset
        # (e.g. one produced by pickling) keeps its own counter.
        self.count = 0

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return a random sample, or raise on the ``stop_after``-th access.

        Raises:
            RuntimeError: when this is exactly the ``stop_after``-th call.
        """
        self.count += 1
        if self.count == self.stop_after:
            raise RuntimeError("Simulated early stopping")

        shape = (self.max_length,)
        return {
            "input_ids": torch.randint(0, self.vocab_size, shape),
            "target_ids": torch.randint(0, self.vocab_size, shape),
        }
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_test_env(request: pytest.FixtureRequest):
|
||||
func_name = request.function.__name__
|
||||
|
|
@ -60,49 +120,15 @@ def base_test_env(request: pytest.FixtureRequest):
|
|||
|
||||
@pytest.fixture
def random_dataset():
    """Provide a ``RandomDataset`` with default settings.

    Uses the module-level ``RandomDataset`` instead of declaring a class
    inside the fixture: a function-local class cannot be pickled, and a local
    copy would shadow the shared definition.
    """
    yield RandomDataset()
|
||||
|
||||
@pytest.fixture
def multi_turn_dataset():
    """Provide a ``MultiTurnDataset`` with default settings.

    Uses the module-level ``MultiTurnDataset`` instead of declaring a class
    inside the fixture: a function-local class cannot be pickled, and a local
    copy would shadow the shared definition.
    """
    yield MultiTurnDataset()
|
||||
|
||||
@pytest.fixture
def early_stopping_dataset():
    """Provide an ``EarlyStoppingDataset`` built with its default arguments."""
    yield EarlyStoppingDataset()
|
||||
|
|
@ -27,7 +27,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
# Create custom callbacks to track calls
|
||||
callback_calls = []
|
||||
|
||||
class TrackingCallback(TrainerCallback):
|
||||
class TrackingCallback(TrainCallback):
|
||||
def on_train_begin(self, trainer, **kwargs):
|
||||
callback_calls.append('on_train_begin')
|
||||
|
||||
|
|
|
|||
|
|
@ -5,32 +5,12 @@ from khaosz.core import *
|
|||
from khaosz.trainer import *
|
||||
from khaosz.trainer.data_util import *
|
||||
|
||||
def test_early_stopping_simulation(base_test_env):
|
||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||
"""Simulate early stopping behavior"""
|
||||
class EarlyStoppingDataset(Dataset):
|
||||
def __init__(self, length=10, stop_after=5):
|
||||
self.length = length
|
||||
self.stop_after = stop_after
|
||||
self.count = 0
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
self.count += 1
|
||||
if self.count == self.stop_after:
|
||||
raise RuntimeError("Simulated early stopping")
|
||||
|
||||
return {
|
||||
"input_ids": torch.randint(0, 1000, (64,)),
|
||||
"target_ids": torch.randint(0, 1000, (64,))
|
||||
}
|
||||
|
||||
dataset = EarlyStoppingDataset()
|
||||
|
||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||
train_config = TrainConfig(
|
||||
dataset=dataset,
|
||||
dataset=early_stopping_dataset,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=base_test_env["test_dir"],
|
||||
n_epoch=2,
|
||||
|
|
|
|||
Loading…
Reference in New Issue