test: unify and refactor the dataset and scheduler test modules
parent e86328b753
commit cdb47a62dc
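
Rename the max_len argument of DatasetLoader.load to window_size in the
dataset tests, remove the multi-turn training test from the scheduler test
module, and expand the scheduler tests to cover factory dispatch, edge-case
and invalid configurations, and state-dict persistence.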
@@ -33,7 +33,7 @@ def test_dataset_loader_random_paths(base_test_env):
     loaded_dataset = DatasetLoader.load(
         train_type="seq",
         load_path=pkl_paths,
-        max_len=64,
+        window_size=64,
     )
     assert loaded_dataset is not None
     assert len(loaded_dataset) > 0
@@ -60,7 +60,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
         load_path=pkl_path,
-        max_len=64,
+        window_size=64,
     )

     assert dpo_dataset is not None
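
For reference, a minimal usage sketch of the renamed parameter, not part of
the commit. The direct import (the tests star-import khaosz.data.data_util)
and the literal "sample.pkl" path are assumptions; the tests pass
fixture-generated pkl files through the same call:

    from khaosz.data.data_util import DatasetLoader  # direct import assumed

    dataset = DatasetLoader.load(
        train_type="seq",          # "seq" or "dpo", matching the two tests above
        load_path=["sample.pkl"],  # hypothetical path; tests use fixture-built pkl files
        window_size=64,            # renamed from max_len in this commit
    )
    assert dataset is not None and len(dataset) > 0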
@@ -1,53 +1,21 @@
 import torch
 import numpy as np
+import pytest

 from khaosz.config import *
-from khaosz.trainer import *
+from khaosz.trainer.schedule import *
 from khaosz.data.data_util import *

-def test_multi_turn_training(base_test_env, multi_turn_dataset):
-    """Test training with multi-turn conversation data"""
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    train_config = TrainConfig(
-        dataset=multi_turn_dataset,
-        optimizer=optimizer,
-        checkpoint_dir=base_test_env["test_dir"],
-        n_epoch=2,
-        batch_size=2,
-        checkpoint_interval=3,
-        accumulation_steps=1,
-        max_grad_norm=1.0,
-        random_seed=int(np.random.randint(1000))
-    )
-
-    schedule_config = CosineScheduleConfig(
-        warmup_steps=50,
-        total_steps=100
-    )
-
-    train_config.strategy = StrategyFactory.load(
-        base_test_env["model"],
-        "sft",
-        base_test_env["device"],
-        bos_token_id=2,
-        eos_token_id=3,
-        multi_turn=True
-    )
-
-    model_parameter = ModelParameter(
-        base_test_env["model"],
-        base_test_env["tokenizer"],
-        base_test_env["transformer_config"]
-    )
-
-    trainer = Trainer(model_parameter, train_config, schedule_config)
-    checkpoint = trainer.train()
-
-    assert len(checkpoint.loss_list) > 0
-
 def test_schedule_factory_random_configs():
     """Test scheduler factory with random configurations"""

+    # Create a simple model and optimizer for testing
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test multiple random configurations
+    for _ in range(5):  # Test 5 random configurations
         schedule_configs = [
             CosineScheduleConfig(
                 warmup_steps=np.random.randint(50, 200),
@@ -61,4 +29,107 @@ def test_schedule_factory_random_configs():
                 min_rate=np.random.uniform(0.01, 0.1)
             )
         ]
-        # todo
+
+        for config in schedule_configs:
+            # Validate configuration
+            config.validate()
+
+            # Create scheduler using factory
+            scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+            # Verify scheduler type
+            if isinstance(config, CosineScheduleConfig):
+                assert isinstance(scheduler, CosineScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
+                assert scheduler.min_rate == config.min_rate
+            elif isinstance(config, SGDRScheduleConfig):
+                assert isinstance(scheduler, SGDRScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.cycle_length == config.cycle_length
+                assert scheduler.t_mult == config.t_mult
+                assert scheduler.min_rate == config.min_rate
+
+            # Test scheduler state dict functionality
+            state_dict = scheduler.state_dict()
+            assert 'warmup_steps' in state_dict
+            assert 'min_rate' in state_dict
+
+            # Test scheduler step functionality
+            initial_lr = scheduler.get_last_lr()
+            scheduler.step()
+            new_lr = scheduler.get_last_lr()
+
+            # Learning rate should change after step, or if it's the first step,
+            # the epoch counter should increment
+            assert initial_lr != new_lr or scheduler.last_epoch > -1
+
+
+def test_schedule_factory_edge_cases():
+    """Test scheduler factory with edge cases and boundary conditions"""
+
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test edge cases for CosineScheduleConfig
+    edge_cases = [
+        # Minimal warmup and steps
+        CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
+        # Large values
+        CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
+        # Zero min_rate (edge case)
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
+    ]
+
+    for config in edge_cases:
+        config.validate()
+        scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+        assert scheduler is not None
+
+        # Test multiple steps
+        for _ in range(10):
+            scheduler.step()
+
+
+def test_schedule_factory_invalid_configs():
+    """Test scheduler factory with invalid configurations"""
+
+    # Test invalid configurations that should raise errors
+    invalid_configs = [
+        # Negative warmup steps
+        CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
+        # Total steps less than warmup steps
+        CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
+        # Invalid min_rate
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
+    ]
+
+    for config in invalid_configs:
+        with pytest.raises(ValueError):
+            config.validate()
+
+
+def test_schedule_factory_state_persistence():
+    """Test scheduler state persistence (save/load)"""
+
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
+    scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+    # Take a few steps
+    for _ in range(5):
+        scheduler.step()
+
+    # Save state
+    state_dict = scheduler.state_dict()
+
+    # Create new scheduler and load state
+    new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+    new_scheduler.load_state_dict(state_dict)
+
+    # Verify states match
+    assert scheduler.last_epoch == new_scheduler.last_epoch
+    assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
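
For reference, a standalone sketch of the factory and state-dict round trip
that the new tests assert, outside the pytest harness. It assumes
SchedulerFactory and CosineScheduleConfig are importable by name from
khaosz.trainer.schedule (the tests star-import that module):

    import torch
    from khaosz.trainer.schedule import SchedulerFactory, CosineScheduleConfig  # direct import assumed

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
    config.validate()  # raises ValueError for negative warmup, total_steps < warmup_steps, or min_rate outside [0, 1]
    scheduler = SchedulerFactory.load_scheduler(optimizer, config)

    for _ in range(5):
        scheduler.step()

    # State round-trips through state_dict / load_state_dict,
    # as test_schedule_factory_state_persistence verifies.
    clone = SchedulerFactory.load_scheduler(optimizer, config)
    clone.load_state_dict(scheduler.state_dict())
    assert clone.last_epoch == scheduler.last_epoch
    assert clone.get_last_lr() == scheduler.get_last_lr()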