157 lines
5.5 KiB
Python
157 lines
5.5 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
|
|
|
|
|
|
def test_schedule_factory_random_configs():
    """Test scheduler factory with randomized, reproducible configurations.

    Draws five cosine and five SGDR parameter sets from a seeded NumPy
    generator, builds each scheduler through ``SchedulerFactory.create`` and
    checks that:

    * the factory returns the expected scheduler class,
    * the constructor arguments are exposed unchanged as attributes,
    * ``state_dict()`` contains ``warmup_steps`` and ``min_rate``,
    * a single ``step()`` advances ``last_epoch`` by exactly one.
    """
    # Seeded local generator: a failing draw can be replayed exactly and the
    # global NumPy random state is left untouched.
    rng = np.random.default_rng(0)

    # Create a simple model and optimizer for testing
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    for _ in range(5):  # Test 5 random configurations
        # Generator.integers excludes the upper bound by default, matching the
        # ranges of the np.random.randint calls it replaces. int()/float()
        # keep the values plain Python scalars.
        cosine_params = {
            "schedule_type": "cosine",
            "warmup_steps": int(rng.integers(50, 200)),
            "total_steps": int(rng.integers(1000, 5000)),
            "min_rate": float(rng.uniform(0.01, 0.1)),
        }
        sgdr_params = {
            "schedule_type": "sgdr",
            "warmup_steps": int(rng.integers(50, 200)),
            "cycle_length": int(rng.integers(500, 2000)),
            "t_mult": int(rng.integers(1, 3)),
            "min_rate": float(rng.uniform(0.01, 0.1)),
        }

        for params in [cosine_params, sgdr_params]:
            schedule_type = params["schedule_type"]

            # Convert parameters for scheduler constructor
            if schedule_type == "cosine":
                warmup_steps = params["warmup_steps"]
                total_steps = params["total_steps"]
                min_rate = params["min_rate"]
                # The factory takes the decay length, not the total length.
                lr_decay_steps = total_steps - warmup_steps

                scheduler = SchedulerFactory.create(
                    optimizer,
                    schedule_type,
                    warmup_steps=warmup_steps,
                    lr_decay_steps=lr_decay_steps,
                    min_rate=min_rate,
                )

                assert isinstance(scheduler, CosineScheduler)
                assert scheduler.warmup_steps == warmup_steps
                assert scheduler.lr_decay_steps == lr_decay_steps
                assert scheduler.min_rate == min_rate
            elif schedule_type == "sgdr":
                warmup_steps = params["warmup_steps"]
                cycle_length = params["cycle_length"]
                t_mult = params["t_mult"]
                min_rate = params["min_rate"]

                scheduler = SchedulerFactory.create(
                    optimizer,
                    schedule_type,
                    warmup_steps=warmup_steps,
                    cycle_length=cycle_length,
                    t_mult=t_mult,
                    min_rate=min_rate,
                )

                assert isinstance(scheduler, SGDRScheduler)
                assert scheduler.warmup_steps == warmup_steps
                assert scheduler.cycle_length == cycle_length
                assert scheduler.t_mult == t_mult
                assert scheduler.min_rate == min_rate

            # Test scheduler state dict functionality
            state_dict = scheduler.state_dict()
            assert "warmup_steps" in state_dict
            assert "min_rate" in state_dict

            # Test scheduler step functionality. get_last_lr() is called on
            # both sides of the step so the accessor itself is exercised.
            epoch_before = scheduler.last_epoch
            scheduler.get_last_lr()
            scheduler.step()
            scheduler.get_last_lr()

            # The previous check (`lr changed or last_epoch > -1`) could not
            # fail once the counter had left -1; require that the step was
            # actually counted instead.
            # NOTE(review): assumes step() increments last_epoch by one, as
            # torch's LRScheduler does - confirm against the scheduler classes.
            assert scheduler.last_epoch == epoch_before + 1
|
|
|
|
|
|
def test_schedule_factory_edge_cases():
    """Build cosine schedulers at boundary parameter values and step them.

    Each configuration is created through ``SchedulerFactory.create``, checked
    to be non-``None`` and then stepped ten times; the test passes when none
    of those calls raises.
    """
    net = torch.nn.Linear(10, 2)
    optim = torch.optim.AdamW(net.parameters(), lr=0.001)

    # (warmup_steps, total_steps, min_rate) triples for the cosine schedule
    boundary_configs = (
        (1, 10, 0.01),  # minimal warmup and steps
        (1000, 10000, 0.5),  # large values
        (100, 1000, 0.0),  # zero min_rate
    )

    for warmup, total, floor_rate in boundary_configs:
        # The factory is given the decay length rather than the total length.
        decay = total - warmup
        scheduler = SchedulerFactory.create(
            optim,
            "cosine",
            warmup_steps=warmup,
            lr_decay_steps=decay,
            min_rate=floor_rate,
        )
        assert scheduler is not None

        # Advance the schedule ten times
        remaining = 10
        while remaining > 0:
            scheduler.step()
            remaining -= 1
|
|
|
|
|
|
def test_schedule_factory_state_persistence():
    """Round-trip a cosine scheduler's state through state_dict/load_state_dict.

    A scheduler is stepped five times and its ``state_dict()`` is loaded into
    a second scheduler built with the same arguments on the same optimizer.
    Both must then report the same ``last_epoch`` and the same
    ``get_last_lr()``.
    """
    net = torch.nn.Linear(10, 2)
    adamw = torch.optim.AdamW(net.parameters(), lr=0.001)

    # Shared constructor arguments so both schedulers are built identically
    warmup, total = 100, 1000
    cosine_kwargs = {
        "warmup_steps": warmup,
        "lr_decay_steps": total - warmup,
        "min_rate": 0.1,
    }

    # Advance the source scheduler, then capture its state
    source = SchedulerFactory.create(adamw, "cosine", **cosine_kwargs)
    for _ in range(5):
        source.step()
    snapshot = source.state_dict()

    # Fresh scheduler with the same parameters, restored from the snapshot
    restored = SchedulerFactory.create(adamw, "cosine", **cosine_kwargs)
    restored.load_state_dict(snapshot)

    # Verify states match
    assert source.last_epoch == restored.last_epoch
    assert source.get_last_lr() == restored.get_last_lr()
|