feat(tests): 重构测试环境, 便于pickle 序列化
This commit is contained in:
parent
2ccd7bd583
commit
9d5aa952e0
|
|
@ -15,6 +15,66 @@ from khaosz.trainer.data_util import *
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
|
||||||
|
class RandomDataset(Dataset):
    """Map-style dataset producing random token-id samples for trainer tests.

    Defined at module level (not inside a fixture) so instances can be
    pickled, e.g. by DataLoader worker processes.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        # A falsy length (None/0) falls back to a pseudo-random size in [100, 200).
        self.length = length if length else int(np.random.randint(100, 200))
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        """Number of samples reported to the DataLoader."""
        return self.length

    def __getitem__(self, idx):
        """Return a fresh random sample; ``idx`` is deliberately ignored."""
        return {
            key: torch.randint(0, self.vocab_size, (self.max_length,))
            for key in ("input_ids", "target_ids")
        }
class MultiTurnDataset(Dataset):
    """Random dataset that additionally carries loss and attention masks,
    mimicking multi-turn conversation training samples.

    Defined at module level (not inside a fixture) so instances can be
    pickled, e.g. by DataLoader worker processes.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        # A falsy length (None/0) falls back to a pseudo-random size in [100, 200).
        self.length = length if length else int(np.random.randint(100, 200))
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        """Number of samples reported to the DataLoader."""
        return self.length

    def __getitem__(self, idx):
        """Return a random sample with derived masks; ``idx`` is ignored."""
        shape = (self.max_length,)
        input_ids = torch.randint(0, self.vocab_size, shape)
        target_ids = torch.randint(0, self.vocab_size, shape)
        # NOTE(review): the literal arguments (0, 1) and (2, True) look like
        # special-token ids / flags — confirm against data_util's signatures.
        loss_mask = build_loss_mask(input_ids, 0, 1)
        attn_mask = build_attention_mask(input_ids, 2, True)
        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "loss_mask": loss_mask,
            "attn_mask": attn_mask,
        }
class EarlyStoppingDataset(Dataset):
    """Dataset that raises on the ``stop_after``-th item fetch, simulating a
    failure in the middle of a training run.

    Defined at module level (not inside a fixture) so instances can be
    pickled, e.g. by DataLoader worker processes.
    """

    def __init__(self, length=10, stop_after=5):
        self.length = length
        self.stop_after = stop_after
        # Running count of __getitem__ calls; the stop_after-th call raises.
        self.count = 0

    def __len__(self):
        """Number of samples reported to the DataLoader."""
        return self.length

    def __getitem__(self, idx):
        """Return a random sample, except on the ``stop_after``-th fetch.

        Raises:
            RuntimeError: exactly once, when ``count`` reaches ``stop_after``.
        """
        self.count += 1
        if self.count != self.stop_after:
            return {
                key: torch.randint(0, 1000, (64,))
                for key in ("input_ids", "target_ids")
            }
        raise RuntimeError("Simulated early stopping")
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_test_env(request: pytest.FixtureRequest):
|
def base_test_env(request: pytest.FixtureRequest):
|
||||||
func_name = request.function.__name__
|
func_name = request.function.__name__
|
||||||
|
|
@ -60,49 +120,15 @@ def base_test_env(request: pytest.FixtureRequest):
|
||||||
|
|
||||||
@pytest.fixture
def random_dataset():
    """Yield a fresh RandomDataset instance for each test."""
    yield RandomDataset()
@pytest.fixture
def multi_turn_dataset():
    """Yield a fresh MultiTurnDataset instance for each test."""
    yield MultiTurnDataset()
@pytest.fixture
def early_stopping_dataset():
    """Yield a fresh EarlyStoppingDataset instance for each test."""
    yield EarlyStoppingDataset()
@ -27,7 +27,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
# Create custom callbacks to track calls
|
# Create custom callbacks to track calls
|
||||||
callback_calls = []
|
callback_calls = []
|
||||||
|
|
||||||
class TrackingCallback(TrainerCallback):
|
class TrackingCallback(TrainCallback):
|
||||||
def on_train_begin(self, trainer, **kwargs):
|
def on_train_begin(self, trainer, **kwargs):
|
||||||
callback_calls.append('on_train_begin')
|
callback_calls.append('on_train_begin')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,32 +5,12 @@ from khaosz.core import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.trainer.data_util import *
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
class EarlyStoppingDataset(Dataset):
|
|
||||||
def __init__(self, length=10, stop_after=5):
|
|
||||||
self.length = length
|
|
||||||
self.stop_after = stop_after
|
|
||||||
self.count = 0
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.length
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
self.count += 1
|
|
||||||
if self.count == self.stop_after:
|
|
||||||
raise RuntimeError("Simulated early stopping")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": torch.randint(0, 1000, (64,)),
|
|
||||||
"target_ids": torch.randint(0, 1000, (64,))
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset = EarlyStoppingDataset()
|
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
dataset=dataset,
|
dataset=early_stopping_dataset,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
checkpoint_dir=base_test_env["test_dir"],
|
checkpoint_dir=base_test_env["test_dir"],
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue