AstrAI/tests/trainer/test_callbacks.py

51 lines
1.5 KiB
Python

import torch
from astrai.config import *
from astrai.trainer import *
def test_callback_integration(base_test_env, random_dataset):
    """Run a short training job and check that user callback hooks fire.

    A ``TrainCallback`` subclass records the name of each hook it receives
    into a list shared with the test. Once ``Trainer.train()`` returns, the
    test asserts that ``on_train_begin``, ``on_batch_end`` and
    ``on_epoch_end`` were each recorded at least once.

    Args:
        base_test_env: Fixture mapping; this test reads its ``"model"``,
            ``"test_dir"`` and ``"device"`` entries.
        random_dataset: Fixture passed through as the training dataset.
    """
    cosine_schedule = CosineScheduleConfig(warmup_steps=10, total_steps=20)

    def build_optimizer(model):
        # AdamW over all model parameters, no hyperparameters overridden.
        return torch.optim.AdamW(model.parameters())

    def build_scheduler(optim):
        # Wrap the optimizer with the schedule configured above.
        return SchedulerFactory.load(optim, cosine_schedule)

    config = TrainConfig(
        model=base_test_env["model"],
        strategy="seq",
        dataset=random_dataset,
        optimizer_fn=build_optimizer,
        scheduler_fn=build_scheduler,
        ckpt_dir=base_test_env["test_dir"],
        n_epoch=1,
        batch_size=2,
        ckpt_interval=3,
        accumulation_steps=1,
        max_grad_norm=1.0,
        random_seed=42,
        device_type=base_test_env["device"],
    )

    # Names of the hooks the trainer invoked, in call order.
    recorded_hooks = []

    class TrackingCallback(TrainCallback):
        """Callback that only logs which hooks were invoked."""

        def on_train_begin(self, context):
            recorded_hooks.append("on_train_begin")

        def on_batch_end(self, context):
            recorded_hooks.append("on_batch_end")

        def on_epoch_end(self, context):
            recorded_hooks.append("on_epoch_end")

    Trainer(config, callbacks=[TrackingCallback()]).train()

    # Every overridden hook must have been recorded at least once.
    for hook_name in ("on_train_begin", "on_batch_end", "on_epoch_end"):
        assert hook_name in recorded_hooks