"""Tests for the khaosz scheduler factory (cosine and SGDR schedules)."""
import numpy as np
import pytest
import torch

from khaosz.config import *
from khaosz.data.dataset import *
from khaosz.trainer.schedule import *


def test_schedule_factory_random_configs():
|
|
"""Test scheduler factory with random configurations"""
|
|
|
|
# Create a simple model and optimizer for testing
|
|
model = torch.nn.Linear(10, 2)
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
|
|
|
# Test multiple random configurations
|
|
for _ in range(5): # Test 5 random configurations
|
|
schedule_configs = [
|
|
CosineScheduleConfig(
|
|
warmup_steps=np.random.randint(50, 200),
|
|
total_steps=np.random.randint(1000, 5000),
|
|
min_rate=np.random.uniform(0.01, 0.1),
|
|
),
|
|
SGDRScheduleConfig(
|
|
warmup_steps=np.random.randint(50, 200),
|
|
cycle_length=np.random.randint(500, 2000),
|
|
t_mult=np.random.randint(1, 3),
|
|
min_rate=np.random.uniform(0.01, 0.1),
|
|
),
|
|
]
|
|
|
|
for config in schedule_configs:
|
|
# Validate configuration
|
|
config.validate()
|
|
|
|
# Create scheduler using factory
|
|
scheduler = SchedulerFactory.load(optimizer, config)
|
|
|
|
# Verify scheduler type
|
|
if isinstance(config, CosineScheduleConfig):
|
|
assert isinstance(scheduler, CosineScheduler)
|
|
assert scheduler.warmup_steps == config.warmup_steps
|
|
assert (
|
|
scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
|
)
|
|
assert scheduler.min_rate == config.min_rate
|
|
elif isinstance(config, SGDRScheduleConfig):
|
|
assert isinstance(scheduler, SGDRScheduler)
|
|
assert scheduler.warmup_steps == config.warmup_steps
|
|
assert scheduler.cycle_length == config.cycle_length
|
|
assert scheduler.t_mult == config.t_mult
|
|
assert scheduler.min_rate == config.min_rate
|
|
|
|
# Test scheduler state dict functionality
|
|
state_dict = scheduler.state_dict()
|
|
assert "warmup_steps" in state_dict
|
|
assert "min_rate" in state_dict
|
|
|
|
# Test scheduler step functionality
|
|
initial_lr = scheduler.get_last_lr()
|
|
scheduler.step()
|
|
new_lr = scheduler.get_last_lr()
|
|
|
|
# Learning rate should change after step, or if it's the first step,
|
|
# the epoch counter should increment
|
|
assert initial_lr != new_lr or scheduler.last_epoch > -1
|
|
|
|
|
|
def test_schedule_factory_edge_cases():
|
|
"""Test scheduler factory with edge cases and boundary conditions"""
|
|
|
|
model = torch.nn.Linear(10, 2)
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
|
|
|
# Test edge cases for CosineScheduleConfig
|
|
edge_cases = [
|
|
# Minimal warmup and steps
|
|
CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
|
|
# Large values
|
|
CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
|
|
# Zero min_rate (edge case)
|
|
CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
|
|
]
|
|
|
|
for config in edge_cases:
|
|
config.validate()
|
|
scheduler = SchedulerFactory.load(optimizer, config)
|
|
assert scheduler is not None
|
|
|
|
# Test multiple steps
|
|
for _ in range(10):
|
|
scheduler.step()
|
|
|
|
|
|
def test_schedule_factory_invalid_configs():
|
|
"""Test scheduler factory with invalid configurations"""
|
|
|
|
# Test invalid configurations that should raise errors
|
|
invalid_configs = [
|
|
# Negative warmup steps
|
|
{"warmup_steps": -10, "total_steps": 1000, "min_rate": 0.1},
|
|
# Total steps less than warmup steps
|
|
{"warmup_steps": 500, "total_steps": 400, "min_rate": 0.1},
|
|
# Invalid min_rate
|
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": -0.1},
|
|
{"warmup_steps": 100, "total_steps": 1000, "min_rate": 1.1},
|
|
]
|
|
|
|
for kwargs in invalid_configs:
|
|
with pytest.raises(ValueError):
|
|
config = CosineScheduleConfig(**kwargs)
|
|
config.validate()
|
|
|
|
|
|
def test_schedule_factory_state_persistence():
|
|
"""Test scheduler state persistence (save/load)"""
|
|
|
|
model = torch.nn.Linear(10, 2)
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
|
|
|
config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
|
|
scheduler = SchedulerFactory.load(optimizer, config)
|
|
|
|
# Take a few steps
|
|
for _ in range(5):
|
|
scheduler.step()
|
|
|
|
# Save state
|
|
state_dict = scheduler.state_dict()
|
|
|
|
# Create new scheduler and load state
|
|
new_scheduler = SchedulerFactory.load(optimizer, config)
|
|
new_scheduler.load_state_dict(state_dict)
|
|
|
|
# Verify states match
|
|
assert scheduler.last_epoch == new_scheduler.last_epoch
|
|
assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
|