diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py
index b45f51a..a5b54e8 100644
--- a/tests/test_dataset_loader.py
+++ b/tests/test_dataset_loader.py
@@ -33,7 +33,7 @@ def test_dataset_loader_random_paths(base_test_env):
     loaded_dataset = DatasetLoader.load(
         train_type="seq",
         load_path=pkl_paths,
-        max_len=64,
+        window_size=64,
     )
     assert loaded_dataset is not None
     assert len(loaded_dataset) > 0
@@ -60,7 +60,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
     dpo_dataset = DatasetLoader.load(
         train_type="dpo",
         load_path=pkl_path,
-        max_len=64,
+        window_size=64,
     )
 
     assert dpo_dataset is not None
diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py
index 61ce506..102a9da 100644
--- a/tests/test_train_strategy.py
+++ b/tests/test_train_strategy.py
@@ -1,64 +1,135 @@
 import torch
 import numpy as np
+import pytest
 
 from khaosz.config import *
-from khaosz.trainer import *
+from khaosz.trainer.schedule import *
 from khaosz.data.data_util import *
 
-def test_multi_turn_training(base_test_env, multi_turn_dataset):
-    """Test training with multi-turn conversation data"""
-    optimizer = torch.optim.AdamW(base_test_env["model"].parameters())
-    train_config = TrainConfig(
-        dataset=multi_turn_dataset,
-        optimizer=optimizer,
-        checkpoint_dir=base_test_env["test_dir"],
-        n_epoch=2,
-        batch_size=2,
-        checkpoint_interval=3,
-        accumulation_steps=1,
-        max_grad_norm=1.0,
-        random_seed=int(np.random.randint(1000))
-    )
-
-    schedule_config = CosineScheduleConfig(
-        warmup_steps=50,
-        total_steps=100
-    )
-
-    train_config.strategy = StrategyFactory.load(
-        base_test_env["model"],
-        "sft",
-        base_test_env["device"],
-        bos_token_id=2,
-        eos_token_id=3,
-        multi_turn=True
-    )
-
-    model_parameter = ModelParameter(
-        base_test_env["model"],
-        base_test_env["tokenizer"],
-        base_test_env["transformer_config"]
-    )
-
-    trainer = Trainer(model_parameter, train_config, schedule_config)
-    checkpoint = trainer.train()
-
-    assert len(checkpoint.loss_list) > 0
 
 def test_schedule_factory_random_configs():
     """Test scheduler factory with random configurations"""
-    schedule_configs = [
-        CosineScheduleConfig(
-            warmup_steps=np.random.randint(50, 200),
-            total_steps=np.random.randint(1000, 5000),
-            min_rate=np.random.uniform(0.01, 0.1)
-        ),
-        SGDRScheduleConfig(
-            warmup_steps=np.random.randint(50, 200),
-            cycle_length=np.random.randint(500, 2000),
-            t_mult=np.random.randint(1, 3),
-            min_rate=np.random.uniform(0.01, 0.1)
-        )
+    # Create a simple model and optimizer for testing
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test multiple random configurations
+    for _ in range(5):  # Test 5 random configurations
+        schedule_configs = [
+            CosineScheduleConfig(
+                warmup_steps=np.random.randint(50, 200),
+                total_steps=np.random.randint(1000, 5000),
+                min_rate=np.random.uniform(0.01, 0.1)
+            ),
+            SGDRScheduleConfig(
+                warmup_steps=np.random.randint(50, 200),
+                cycle_length=np.random.randint(500, 2000),
+                t_mult=np.random.randint(1, 3),
+                min_rate=np.random.uniform(0.01, 0.1)
+            )
+        ]
+
+        for config in schedule_configs:
+            # Validate configuration
+            config.validate()
+
+            # Create scheduler using factory
+            scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+            # Verify scheduler type
+            if isinstance(config, CosineScheduleConfig):
+                assert isinstance(scheduler, CosineScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.lr_decay_steps == config.total_steps - config.warmup_steps
+                assert scheduler.min_rate == config.min_rate
+            elif isinstance(config, SGDRScheduleConfig):
+                assert isinstance(scheduler, SGDRScheduler)
+                assert scheduler.warmup_steps == config.warmup_steps
+                assert scheduler.cycle_length == config.cycle_length
+                assert scheduler.t_mult == config.t_mult
+                assert scheduler.min_rate == config.min_rate
+
+            # Test scheduler state dict functionality
+            state_dict = scheduler.state_dict()
+            assert 'warmup_steps' in state_dict
+            assert 'min_rate' in state_dict
+
+            # Test scheduler step functionality
+            initial_lr = scheduler.get_last_lr()
+            scheduler.step()
+            new_lr = scheduler.get_last_lr()
+
+            # Learning rate should change after step, or if it's the first step,
+            # the epoch counter should increment
+            assert initial_lr != new_lr or scheduler.last_epoch > -1
+
+
+def test_schedule_factory_edge_cases():
+    """Test scheduler factory with edge cases and boundary conditions"""
+
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    # Test edge cases for CosineScheduleConfig
+    edge_cases = [
+        # Minimal warmup and steps
+        CosineScheduleConfig(warmup_steps=1, total_steps=10, min_rate=0.01),
+        # Large values
+        CosineScheduleConfig(warmup_steps=1000, total_steps=10000, min_rate=0.5),
+        # Zero min_rate (edge case)
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.0),
     ]
-    # todo
\ No newline at end of file
+
+    for config in edge_cases:
+        config.validate()
+        scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+        assert scheduler is not None
+
+        # Test multiple steps
+        for _ in range(10):
+            scheduler.step()
+
+
+def test_schedule_factory_invalid_configs():
+    """Test scheduler factory with invalid configurations"""
+
+    # Test invalid configurations that should raise errors
+    invalid_configs = [
+        # Negative warmup steps
+        CosineScheduleConfig(warmup_steps=-10, total_steps=1000, min_rate=0.1),
+        # Total steps less than warmup steps
+        CosineScheduleConfig(warmup_steps=500, total_steps=400, min_rate=0.1),
+        # Invalid min_rate
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=-0.1),
+        CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=1.1),
+    ]
+
+    for config in invalid_configs:
+        with pytest.raises(ValueError):
+            config.validate()
+
+
+def test_schedule_factory_state_persistence():
+    """Test scheduler state persistence (save/load)"""
+
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+
+    config = CosineScheduleConfig(warmup_steps=100, total_steps=1000, min_rate=0.1)
+    scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+
+    # Take a few steps
+    for _ in range(5):
+        scheduler.step()
+
+    # Save state
+    state_dict = scheduler.state_dict()
+
+    # Create new scheduler and load state
+    new_scheduler = SchedulerFactory.load_scheduler(optimizer, config)
+    new_scheduler.load_state_dict(state_dict)
+
+    # Verify states match
+    assert scheduler.last_epoch == new_scheduler.last_epoch
+    assert scheduler.get_last_lr() == new_scheduler.get_last_lr()
\ No newline at end of file
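
Reviewer note: the new `test_schedule_factory_invalid_configs` cases pin down a `validate()` contract that is not itself visible in this diff. The sketch below is one plausible reading of that contract derived only from the test cases above; the class name `CosineScheduleConfigSketch` and the error messages are hypothetical, and the real `CosineScheduleConfig` in `khaosz.config` may differ.

```python
# Hypothetical sketch of the validate() contract the new tests exercise;
# not the actual khaosz.config implementation.
from dataclasses import dataclass


@dataclass
class CosineScheduleConfigSketch:
    warmup_steps: int
    total_steps: int
    min_rate: float

    def validate(self) -> None:
        # Negative warmup is rejected (first invalid case in the test).
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")
        # The decay phase requires total_steps >= warmup_steps (second case).
        if self.total_steps < self.warmup_steps:
            raise ValueError("total_steps must be >= warmup_steps")
        # min_rate scales the base LR, so it must lie in [0, 1]
        # (last two invalid cases; min_rate=0.0 is accepted as an edge case).
        if not 0.0 <= self.min_rate <= 1.0:
            raise ValueError("min_rate must be in [0, 1]")
```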