# AstrAI/tests/trainer/test_early_stopping.py
#
# 49 lines
# 1.4 KiB
# Python

import os
import torch
import numpy as np
from khaosz.config import *
from khaosz.trainer import *
from khaosz.data.serialization import Checkpoint
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
    """Check that training can be resumed from a checkpoint after a first run.

    The first ``Trainer.train()`` call is allowed to fail: any ``Exception``
    it raises is tolerated. Training is then resumed from the
    ``epoch_0_iter_2`` checkpoint, and the ``epoch_1_iter_10`` checkpoint
    found in the checkpoint directory afterwards must report
    ``iteration == 10``.

    Args:
        base_test_env: Fixture mapping; this test reads its ``"model"``,
            ``"test_dir"`` and ``"device"`` entries.
        early_stopping_dataset: Fixture dataset handed to ``TrainConfig``.
            NOTE(review): presumably it interrupts training part-way
            through -- confirm against the fixture definition.
    """
    schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20)

    def optimizer_fn(model):
        return torch.optim.AdamW(model.parameters())

    def scheduler_fn(optim):
        return SchedulerFactory.load(optim, schedule_config)

    train_config = TrainConfig(
        strategy="seq",
        optimizer_fn=optimizer_fn,
        scheduler_fn=scheduler_fn,
        model=base_test_env["model"],
        dataset=early_stopping_dataset,
        ckpt_dir=base_test_env["test_dir"],
        n_epoch=2,
        batch_size=2,
        ckpt_interval=1,
        accumulation_steps=2,
        # The seed is drawn at random, so every test run uses a different one.
        random_seed=np.random.randint(1e4),
        device_type=base_test_env["device"],
    )
    trainer = Trainer(train_config)

    # The first run may abort. Any exception is deliberately tolerated, but it
    # is remembered so a later failure can show what interrupted training.
    first_run_error = None
    try:
        trainer.train()
    except Exception as exc:
        first_run_error = exc

    resume_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
    try:
        checkpoint = Checkpoint.load(resume_dir)
    except Exception as load_error:
        if first_run_error is None:
            raise
        # Attach the swallowed first-run error as the cause so the traceback
        # explains why the expected checkpoint could not be loaded.
        raise load_error from first_run_error

    # Resume training from the loaded checkpoint.
    trainer.train(checkpoint)

    final_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
    final_checkpoint = Checkpoint.load(final_dir)
    assert final_checkpoint.iteration == 10