test: 统一重构数据集和调度器测试模块
This commit is contained in:
parent
e86328b753
commit
cdb47a62dc
|
|
@ -33,7 +33,7 @@ def test_dataset_loader_random_paths(base_test_env):
|
|||
loaded_dataset = DatasetLoader.load(
|
||||
train_type="seq",
|
||||
load_path=pkl_paths,
|
||||
max_len=64,
|
||||
window_size=64,
|
||||
)
|
||||
assert loaded_dataset is not None
|
||||
assert len(loaded_dataset) > 0
|
||||
|
|
@ -60,7 +60,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
dpo_dataset = DatasetLoader.load(
|
||||
train_type="dpo",
|
||||
load_path=pkl_path,
|
||||
max_len=64,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
assert dpo_dataset is not None
|
||||
|
|
|
|||
|
|
@ -1,53 +1,21 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from khaosz.config import *
|
||||
from khaosz.trainer import *
|
||||
from khaosz.trainer.schedule import *
|
||||
from khaosz.data.data_util import *
|
||||
|
||||
def test_multi_turn_training(base_test_env, multi_turn_dataset):
    """Train for two short epochs on multi-turn conversation data.

    Builds an SFT strategy in multi-turn mode, runs the Trainer end to
    end, and checks that at least one loss value was recorded.
    """
    env = base_test_env
    adam = torch.optim.AdamW(env["model"].parameters())

    # Training configuration: a tiny run sized for a unit test.
    train_config = TrainConfig(
        dataset=multi_turn_dataset,
        optimizer=adam,
        checkpoint_dir=env["test_dir"],
        n_epoch=2,
        batch_size=2,
        checkpoint_interval=3,
        accumulation_steps=1,
        max_grad_norm=1.0,
        random_seed=int(np.random.randint(1000)),
    )

    # Multi-turn SFT strategy with fixed special-token ids.
    train_config.strategy = StrategyFactory.load(
        env["model"],
        "sft",
        env["device"],
        bos_token_id=2,
        eos_token_id=3,
        multi_turn=True,
    )

    schedule_config = CosineScheduleConfig(warmup_steps=50, total_steps=100)

    model_parameter = ModelParameter(
        env["model"],
        env["tokenizer"],
        env["transformer_config"],
    )

    checkpoint = Trainer(model_parameter, train_config, schedule_config).train()

    assert len(checkpoint.loss_list) > 0
|
||||
|
||||
def test_schedule_factory_random_configs():
|
||||
"""Test scheduler factory with random configurations"""
|
||||
|
||||
# Create a simple model and optimizer for testing
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
|
||||
|
||||
# Test multiple random configurations
|
||||
for _ in range(5): # Test 5 random configurations
|
||||
schedule_configs = [
|
||||
CosineScheduleConfig(
|
||||
warmup_steps=np.random.randint(50, 200),
|
||||
|
|
@ -61,4 +29,107 @@ def test_schedule_factory_random_configs():
|
|||
min_rate=np.random.uniform(0.01, 0.1)
|
||||
)
|
||||
]
|
||||
# todo
|
||||
|
||||
for config in schedule_configs:
|
||||
# Validate configuration
|
||||
config.validate()
|
||||
|
||||
# Create scheduler using factory
|
||||
scheduler = SchedulerFactory.load_scheduler(optimizer, config)
|
||||
|
||||
# Verify scheduler type
|
||||
if isinstance(config, CosineScheduleConfig):
|
||||
assert isinstance(scheduler, CosineScheduler)
|
||||
assert scheduler.warmup_steps == config.warmup_steps
|
||||
assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
|
||||
assert scheduler.min_rate == config.min_rate
|
||||
elif isinstance(config, SGDRScheduleConfig):
|
||||
assert isinstance(scheduler, SGDRScheduler)
|
||||
assert scheduler.warmup_steps == config.warmup_steps
|
||||
assert scheduler.cycle_length == config.cycle_length
|
||||
assert scheduler.t_mult == config.t_mult
|
||||
assert scheduler.min_rate == config.min_rate
|
||||
|
||||
# Test scheduler state dict functionality
|
||||
state_dict = scheduler.state_dict()
|
||||
assert 'warmup_steps' in state_dict
|
||||
assert 'min_rate' in state_dict
|
||||
|
||||
# Test scheduler step functionality
|
||||
initial_lr = scheduler.get_last_lr()
|
||||
scheduler.step()
|
||||
new_lr = scheduler.get_last_lr()
|
||||
|
||||
# Learning rate should change after step, or if it's the first step,
|
||||
# the epoch counter should increment
|
||||
assert initial_lr != new_lr or scheduler.last_epoch > -1
|
||||
|
||||
|
||||
def test_schedule_factory_edge_cases():
    """Exercise CosineScheduleConfig at boundary values.

    Each boundary configuration must validate, yield a scheduler from
    the factory, and survive several step() calls without raising.
    """
    net = torch.nn.Linear(10, 2)
    opt = torch.optim.AdamW(net.parameters(), lr=0.001)

    # (warmup_steps, total_steps, min_rate) boundary combinations:
    # minimal run, very large run, and a zero floor rate.
    boundary_params = [
        (1, 10, 0.01),
        (1000, 10000, 0.5),
        (100, 1000, 0.0),
    ]

    for warmup, total, floor in boundary_params:
        cfg = CosineScheduleConfig(
            warmup_steps=warmup, total_steps=total, min_rate=floor
        )
        cfg.validate()

        sched = SchedulerFactory.load_scheduler(opt, cfg)
        assert sched is not None

        # The scheduler must advance repeatedly without errors.
        for _ in range(10):
            sched.step()
|
||||
|
||||
|
||||
def test_schedule_factory_invalid_configs():
    """Invalid cosine-schedule parameters must fail validation.

    Covers a negative warmup, total_steps smaller than warmup_steps,
    and min_rate values outside the [0, 1] range.
    """
    # (warmup_steps, total_steps, min_rate) combinations that validate()
    # is required to reject with a ValueError.
    bad_params = [
        (-10, 1000, 0.1),   # negative warmup steps
        (500, 400, 0.1),    # total steps below warmup steps
        (100, 1000, -0.1),  # min_rate below 0
        (100, 1000, 1.1),   # min_rate above 1
    ]

    for warmup, total, floor in bad_params:
        cfg = CosineScheduleConfig(
            warmup_steps=warmup, total_steps=total, min_rate=floor
        )
        with pytest.raises(ValueError):
            cfg.validate()
|
||||
|
||||
|
||||
def test_schedule_factory_state_persistence():
    """Round-trip scheduler state through state_dict()/load_state_dict().

    After advancing one scheduler and copying its state into a freshly
    created one, both must agree on the epoch counter and learning rate.
    """
    layer = torch.nn.Linear(10, 2)
    opt = torch.optim.AdamW(layer.parameters(), lr=0.001)

    cfg = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
    source = SchedulerFactory.load_scheduler(opt, cfg)

    # Advance a few steps so the saved state is non-trivial.
    for _ in range(5):
        source.step()

    # Restore the advanced state into a brand-new scheduler.
    restored = SchedulerFactory.load_scheduler(opt, cfg)
    restored.load_state_dict(source.state_dict())

    assert source.last_epoch == restored.last_epoch
    assert source.get_last_lr() == restored.get_last_lr()
|
||||
Loading…
Reference in New Issue