test: Unify and refactor the dataset and scheduler test modules

ViperEkura 2025-10-31 20:24:01 +08:00
parent e86328b753
commit cdb47a62dc
2 changed files with 126 additions and 55 deletions

View File

@@ -33,7 +33,7 @@ def test_dataset_loader_random_paths(base_test_env):
     loaded_dataset = DatasetLoader.load(
         train_type="seq",
         load_path=pkl_paths,
-        max_len=64,
+        window_size=64,
     )
     assert loaded_dataset is not None
     assert len(loaded_dataset) > 0
@@ -60,7 +60,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
         load_path=pkl_path,
-        max_len=64,
+        window_size=64,
     )
     assert dpo_dataset is not None
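
Aside: the rename from max_len to window_size in both hunks suggests the loader now slices long token sequences into fixed-size training windows rather than truncating at a maximum length. A minimal sketch of that chunking, using a hypothetical helper rather than khaosz's actual implementation:

    from typing import List

    def chunk_by_window(token_ids: List[int], window_size: int) -> List[List[int]]:
        # Hypothetical sketch: split a token sequence into full fixed-size
        # windows, dropping any trailing remainder shorter than window_size.
        return [
            token_ids[i:i + window_size]
            for i in range(0, len(token_ids), window_size)
            if i + window_size <= len(token_ids)
        ]

With window_size=64, a 200-token sequence would yield three full windows (tokens 0-63, 64-127, 128-191).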

View File

@@ -1,64 +1,135 @@
 import torch
 import numpy as np
+import pytest
 from khaosz.config import *
-from khaosz.trainer import *
+from khaosz.trainer.schedule import *
 from khaosz.data.data_util import *
-def test_multi_turn_training(base_test_env, multi_turn_dataset):
-    """Test training with multi-turn conversation data"""
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    train_config = TrainConfig(
-        dataset=multi_turn_dataset,
-        optimizer=optimizer,
-        checkpoint_dir=base_test_env["test_dir"],
-        n_epoch=2,
-        batch_size=2,
-        checkpoint_interval=3,
-        accumulation_steps=1,
-        max_grad_norm=1.0,
-        random_seed=int(np.random.randint(1000))
-    )
-    schedule_config = CosineScheduleConfig(
-        warmup_steps=50,
-        total_steps=100
-    )
-    train_config.strategy = StrategyFactory.load(
-        base_test_env["model"],
-        "sft",
-        base_test_env["device"],
-        bos_token_id=2,
-        eos_token_id=3,
-        multi_turn=True
-    )
-    model_parameter = ModelParameter(
-        base_test_env["model"],
-        base_test_env["tokenizer"],
-        base_test_env["transformer_config"]
-    )
-    trainer = Trainer(model_parameter, train_config, schedule_config)
-    checkpoint = trainer.train()
-    assert len(checkpoint.loss_list) > 0
 def test_schedule_factory_random_configs():
     """Test scheduler factory with random configurations"""
-    schedule_configs = [
-        CosineScheduleConfig(
-            warmup_steps=np.random.randint(50, 200),
-            total_steps=np.random.randint(1000, 5000),
-            min_rate=np.random.uniform(0.01, 0.1)
-        ),
-        SGDRScheduleConfig(
-            warmup_steps=np.random.randint(50, 200),
-            cycle_length=np.random.randint(500, 2000),
-            t_mult=np.random.randint(1, 3),
-            min_rate=np.random.uniform(0.01, 0.1)
-        )
-    ]
-    # todo
+    # Create a simple model and optimizer for testing
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test multiple random configurations
+    for _ in range(5):  # Test 5 random configurations
+        schedule_configs = [
+            CosineScheduleConfig(
+                warmup_steps=np.random.randint(50, 200),
+                total_steps=np.random.randint(1000, 5000),
+                min_rate=np.random.uniform(0.01, 0.1)
+            ),
+            SGDRScheduleConfig(
+                warmup_steps=np.random.randint(50, 200),
+                cycle_length=np.random.randint(500, 2000),
+                t_mult=np.random.randint(1, 3),
+                min_rate=np.random.uniform(0.01, 0.1)
+            )
+        ]
+
+        for config in schedule_configs:
+            # Validate configuration
+            config.validate()
+
+            # Create scheduler using factory
+            scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+            # Verify scheduler type
+            if isinstance(config, CosineScheduleConfig):
+                assert isinstance(scheduler, CosineScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
+                assert scheduler.min_rate == config.min_rate
+            elif isinstance(config, SGDRScheduleConfig):
+                assert isinstance(scheduler, SGDRScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.cycle_length == config.cycle_length
+                assert scheduler.t_mult == config.t_mult
+                assert scheduler.min_rate == config.min_rate
+
+            # Test scheduler state dict functionality
+            state_dict = scheduler.state_dict()
+            assert 'warmup_steps' in state_dict
+            assert 'min_rate' in state_dict
+
+            # Test scheduler step functionality
+            initial_lr = scheduler.get_last_lr()
+            scheduler.step()
+            new_lr = scheduler.get_last_lr()
+            # Learning rate should change after step, or if it's the first step,
+            # the epoch counter should increment
+            assert initial_lr != new_lr or scheduler.last_epoch > -1
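
Aside: the asserts above pin down the cosine scheduler's shape: lr_decay_steps equals total_steps - warmup_steps and the rate floors at min_rate. A sketch of the learning-rate multiplier those fields imply, assuming linear warmup followed by cosine decay (an assumption about khaosz's CosineScheduler, not its verified code):

    import math

    def cosine_rate(step: int, warmup_steps: int, total_steps: int, min_rate: float) -> float:
        # Assumed curve: linear warmup to 1.0, then cosine decay down to min_rate.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)  # mirrors lr_decay_steps above
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_rate + (1.0 - min_rate) * 0.5 * (1.0 + math.cos(math.pi * progress))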
+
+def test_schedule_factory_edge_cases():
+    """Test scheduler factory with edge cases and boundary conditions"""
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test edge cases for CosineScheduleConfig
+    edge_cases = [
+        # Minimal warmup and steps
+        CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
+        # Large values
+        CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
+        # Zero min_rate (edge case)
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
+    ]
+
+    for config in edge_cases:
+        config.validate()
+        scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+        assert scheduler is not None
+
+        # Test multiple steps
+        for _ in range(10):
+            scheduler.step()
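
Aside: the cycle_length and t_mult fields on SGDRScheduleConfig match SGDR warm restarts (Loshchilov & Hutter, 2017), where each cycle runs t_mult times longer than the last. A sketch of the restart boundaries those two fields imply, again an assumption rather than khaosz's verified behavior:

    def sgdr_restart_steps(cycle_length: int, t_mult: int, n_cycles: int) -> list:
        # Step indices at which warm restarts would occur under SGDR.
        steps, total = [], 0
        for _ in range(n_cycles):
            total += cycle_length
            steps.append(total)
            cycle_length *= t_mult
        return steps

    # cycle_length=500, t_mult=2 -> restarts at steps 500, 1500, 3500, 7500, ...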
+
+def test_schedule_factory_invalid_configs():
+    """Test scheduler factory with invalid configurations"""
+    # Test invalid configurations that should raise errors
+    invalid_configs = [
+        # Negative warmup steps
+        CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
+        # Total steps less than warmup steps
+        CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
+        # Invalid min_rate
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
+    ]
+
+    for config in invalid_configs:
+        with pytest.raises(ValueError):
+            config.validate()
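
Aside: a validate() consistent with these expectations would reject negative warmup, a total no longer than the warmup, and any min_rate outside [0, 1]. A stand-in sketch (the class below is hypothetical; the real checks live in khaosz.config):

    from dataclasses import dataclass

    @dataclass
    class CosineConfigSketch:  # hypothetical stand-in for CosineScheduleConfig
        warmup_steps: int
        total_steps: int
        min_rate: float

        def validate(self) -> None:
            if self.warmup_steps < 0:
                raise ValueError("warmup_steps must be non-negative")
            if self.total_steps <= self.warmup_steps:
                raise ValueError("total_steps must exceed warmup_steps")
            if not 0.0 <= self.min_rate <= 1.0:
                raise ValueError("min_rate must be within [0, 1]")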
+
+def test_schedule_factory_state_persistence():
+    """Test scheduler state persistence (save/load)"""
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
+    scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+    # Take a few steps
+    for _ in range(5):
+        scheduler.step()
+
+    # Save state
+    state_dict = scheduler.state_dict()
+
+    # Create new scheduler and load state
+    new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+    new_scheduler.load_state_dict(state_dict)
+
+    # Verify states match
+    assert scheduler.last_epoch == new_scheduler.last_epoch
+    assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
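
Aside: this save/load round trip mirrors the stock torch.optim.lr_scheduler contract, so the same pattern can be shown with a built-in scheduler (for comparison only; khaosz's schedulers appear to follow these semantics):

    import torch

    model = torch.nn.Linear(10, 2)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)
    for _ in range(5):
        opt.step()    # step the optimizer first to avoid the scheduler-order warning
        sched.step()
    state = sched.state_dict()          # plain dict, safe to torch.save()
    resumed = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)
    resumed.load_state_dict(state)      # restores last_epoch and internal counters
    assert resumed.last_epoch == sched.last_epoch
    assert resumed.get_last_lr() == sched.get_last_lr()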