# AstrAI/tests/test_early_stopping.py
import torch
from torch.utils.data import Dataset
from khaosz.core import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
def test_early_stopping_simulation(base_test_env):
    """Simulate a mid-run failure and verify training can resume to completion.

    A dataset deliberately raises ``RuntimeError`` on its 5th item access to
    mimic an interrupted training run.  The first ``trainer.train()`` call is
    therefore expected to abort; the test then restarts training with the last
    checkpoint (or ``None`` if none was produced) and asserts the full loss
    history is recorded.

    Args:
        base_test_env: pytest fixture providing "model", "tokenizer",
            "transformer_config", and "test_dir" entries
            (assumed from usage below — defined outside this file).
    """

    class EarlyStoppingDataset(Dataset):
        """Dataset that raises RuntimeError on the ``stop_after``-th access."""

        def __init__(self, length=10, stop_after=5):
            self.length = length
            self.stop_after = stop_after
            self.count = 0  # number of __getitem__ calls made so far

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            self.count += 1
            if self.count == self.stop_after:
                raise RuntimeError("Simulated early stopping")
            return {
                "input_ids": torch.randint(0, 1000, (64,)),
                "target_ids": torch.randint(0, 1000, (64,)),
            }

    dataset = EarlyStoppingDataset()
    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
    train_config = TrainConfig(
        dataset=dataset,
        optimizer=optimizer,
        checkpoint_dir=base_test_env["test_dir"],
        n_epoch=2,
        batch_size=2,
        checkpoint_interval=1,
        accumulation_steps=1,
        max_grad_norm=1.0,
        random_seed=42,
    )
    train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq")
    model_parameter = ModelParameter(
        base_test_env["model"],
        base_test_env["tokenizer"],
        base_test_env["transformer_config"],
    )
    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)
    trainer = Trainer(model_parameter, train_config, schedule_config)

    # First run: the dataset's simulated failure is expected to abort training.
    # NOTE(review): we catch Exception rather than RuntimeError because the
    # DataLoader may re-wrap worker errors — confirm and narrow if possible.
    checkpoint = None
    try:
        checkpoint = trainer.train()
    except Exception:
        # Expected path: "Simulated early stopping" raised mid-run.
        pass
    else:
        # Fix: this assertion used to live inside the try block, where the
        # broad except silently swallowed any AssertionError. It now only
        # runs — and can fail the test — when train() completes normally.
        assert len(checkpoint.loss_list) == 2

    # Second run: resume from the checkpoint (None if the first run produced
    # none) and verify the full loss history is present afterwards.
    checkpoint = trainer.train(checkpoint)
    assert len(checkpoint.loss_list) == 10