diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 971aaa5..cdaa856 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -5,29 +5,39 @@ import shutil import pytest import pickle import tempfile -import matplotlib +import numpy as np from torch.utils.data import Dataset from khaosz.core import * from khaosz.trainer import * +from khaosz.trainer.data_util import * -# to avoid _tkinter.TclError +import matplotlib matplotlib.use('Agg') @pytest.fixture def test_env(): + """Setup test environment with randomized data""" test_dir = tempfile.mkdtemp() config_path = os.path.join(test_dir, "config.json") + n_dim_choices = [16, 32, 64] + n_head_choices = [4, 8, 16] + + n_dim = int(np.random.choice(n_dim_choices)) + n_head = int(np.random.choice(n_head_choices)) + n_kvhead = n_head // 4 + d_ffn = n_dim * 4 + config = { "vocab_size": 1000, - "n_dim": 128, - "n_head": 4, - "n_kvhead": 2, - "d_ffn": 256, - "m_len": 64, - "n_layer": 2, + "n_dim": n_dim, + "n_head": n_head, + "n_kvhead": n_kvhead, + "d_ffn": d_ffn, + "m_len": 1024, + "n_layer": 4, "norm_eps": 1e-5 } @@ -38,20 +48,49 @@ def test_env(): model = Transformer(transformer_config) tokenizer = BpeTokenizer() - class DummyDataset(Dataset): - def __init__(self, length=10): - self.length = length + class RandomDataset(Dataset): + def __init__(self, length=None, max_length=64, vocab_size=1000): + self.length = length or int(np.random.randint(100, 200)) + self.max_length = max_length + self.vocab_size = vocab_size + def __len__(self): return self.length def __getitem__(self, idx): + return { - "input_ids": torch.randint(0, 1000, (64,)), - "target_ids": torch.randint(0, 1000, (64,)) + "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)), + "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)) } - dataset = DummyDataset() + class MultiTurnDataset(Dataset): + def __init__(self, length=None, max_length=64, vocab_size=1000): + self.length = length or int(np.random.randint(100, 200)) + self.max_length = max_length + self.vocab_size = vocab_size + + + def __len__(self): + return self.length + + def __getitem__(self, idx): + input_ids = torch.randint(0, self.vocab_size, (self.max_length,)) + target_ids = torch.randint(0, self.vocab_size, (self.max_length,)) + loss_mask = build_loss_mask(input_ids, 0, 1) + attn_mask = build_attention_mask(input_ids, 2, True) + + + return { + "input_ids": input_ids, + "target_ids": target_ids, + "loss_mask": loss_mask, + "attn_mask": attn_mask, + } + + dataset = RandomDataset() + multi_turn_dataset = MultiTurnDataset() yield { "test_dir": test_dir, @@ -59,146 +98,335 @@ def test_env(): "transformer_config": transformer_config, "model": model, "tokenizer": tokenizer, - "dataset": dataset + "dataset": dataset, + "multi_turn_dataset": multi_turn_dataset } shutil.rmtree(test_dir) -def test_dataset_loader(test_env): +def test_dataset_loader_random_paths(test_env): + """Test dataset loader with multiple random paths""" test_dir = test_env["test_dir"] - pkl_path = os.path.join(test_dir, "test_data.pkl") - dummy_data = {"sequence": torch.randint(0, 1000, (64,))} - with open(pkl_path, "wb") as f: - pickle.dump(dummy_data, f) + # Create multiple pkl files with random data + num_files = np.random.randint(2, 5) + pkl_paths = [] - loaded_dataset = DatasetLoader.load(train_type="seq", load_path=pkl_path, max_len=64, device="cpu") + for i in range(num_files): + pkl_path = os.path.join(test_dir, f"test_data_{i}.pkl") + seq_length = np.random.randint(50, 100) + dummy_data = { + "sequence": torch.randint(0, 1000, (seq_length,)), + "chosen": torch.randint(0, 1000, (seq_length,)), + "rejected": torch.randint(0, 1000, (seq_length,)), + "chosen_mask": torch.ones(seq_length, dtype=torch.bool), + "rejected_mask": torch.ones(seq_length, dtype=torch.bool) + } + with open(pkl_path, "wb") as f: + pickle.dump(dummy_data, f) + pkl_paths.append(pkl_path) + + # Test loading with multiple paths + loaded_dataset = DatasetLoader.load( + train_type="seq", + load_path=pkl_paths, + max_len=64, + device="cpu" + ) assert loaded_dataset is not None + assert len(loaded_dataset) > 0 -def test_training_config(test_env): + +def test_different_batch_sizes(test_env): + """Test training with different batch sizes""" + batch_sizes = [1, 2, 4, 8] + + for batch_size in batch_sizes: + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + dataset=test_env["dataset"], + optimizer=optimizer, + checkpoint_dir=test_env["test_dir"], + n_epoch=1, + batch_size=batch_size, + checkpoint_interval=5, + accumulation_steps=1, + max_grad_norm=1.0, + random_seed=np.random.randint(1000) + ) + + assert train_config.batch_size == batch_size + + +def test_random_sampler_consistency(test_env): + """Test RandomSampler produces consistent results with same seed""" + dataset = test_env["dataset"] + + # Create two samplers with same seed + sampler1 = RandomSampler(dataset, seed=42) + sampler2 = RandomSampler(dataset, seed=42) + + indices1 = list(iter(sampler1)) + indices2 = list(iter(sampler2)) + + assert indices1 == indices2 + + +def test_random_sampler_different_seeds(test_env): + """Test RandomSampler produces different results with different seeds""" + dataset = test_env["dataset"] + + # Create two samplers with different seeds + sampler1 = RandomSampler(dataset, seed=42) + sampler2 = RandomSampler(dataset, seed=123) + + indices1 = list(iter(sampler1)) + indices2 = list(iter(sampler2)) + + # Very high probability they should be different + assert indices1 != indices2 + + +def test_schedule_factory_random_configs(test_env): + """Test scheduler factory with random configurations""" + schedule_configs = [ + CosineScheduleConfig( + warmup_steps=np.random.randint(50, 200), + total_steps=np.random.randint(1000, 5000), + min_rate=np.random.uniform(0.01, 0.1) + ), + SgdrScheduleConfig( + warmup_steps=np.random.randint(50, 200), + cycle_length=np.random.randint(500, 2000), + t_mult=np.random.randint(1, 3), + min_rate=np.random.uniform(0.01, 0.1) + ) + ] + + for config in schedule_configs: + schedule_fn = SchedulerFactory.load_schedule_fn(config) + assert callable(schedule_fn) + + # Test the schedule function at different steps + for step in [0, config.warmup_steps // 2, config.warmup_steps, config.warmup_steps * 2]: + lr_mult = schedule_fn(step) + assert 0 <= lr_mult <= 1 + + +def test_multi_turn_training(test_env): + """Test training with multi-turn conversation data""" optimizer = torch.optim.AdamW(test_env["model"].parameters()) train_config = TrainConfig( - dataset=test_env["dataset"], + dataset=test_env["multi_turn_dataset"], optimizer=optimizer, checkpoint_dir=test_env["test_dir"], n_epoch=1, batch_size=2, - checkpoint_interval=5, + checkpoint_interval=3, accumulation_steps=1, max_grad_norm=1.0, - random_seed=42 + random_seed=np.random.randint(1000) ) - assert train_config.get_kwargs()["batch_size"] == 2 - -def test_cosine_schedule(test_env): - assert test_env is not None - schedule_config = CosineScheduleConfig( - warmup_steps=100, - total_steps=1000 - ) - kwargs = schedule_config.get_kwargs() - assert kwargs["warmup_steps"] == 100 - assert kwargs["lr_decay_steps"] == 900 - -def test_sgdr_schedule(test_env): - assert test_env is not None - schedule_config = SgdrScheduleConfig( - warmup_steps=100, - cycle_length=200, - t_mult=2 - ) - kwargs = schedule_config.get_kwargs() - assert kwargs["warmup_steps"] == 100 - assert kwargs["cycle_length"] == 200 - assert kwargs["t_mult"] == 2 - -def test_trainer_train(test_env): - optimizer = torch.optim.AdamW(test_env["model"].parameters()) - train_config = TrainConfig( - dataset=test_env["dataset"], - optimizer=optimizer, - checkpoint_dir=test_env["test_dir"], - n_epoch=1, - batch_size=2, - checkpoint_interval=5, - accumulation_steps=1, - max_grad_norm=1.0, - random_seed=42 - ) schedule_config = CosineScheduleConfig( - warmup_steps=100, - total_steps=1000 + warmup_steps=50, + total_steps=100 ) train_config.strategy = StrategyFactory.load( test_env["model"], - "seq", - pad_token_id=test_env["tokenizer"].pad_id + "sft", + bos_token_id=2, + eos_token_id=3, + user_token_id=1, + multi_turn=True ) + model_parameter = ModelParameter( test_env["model"], test_env["tokenizer"], test_env["transformer_config"] ) + trainer = Trainer(model_parameter, train_config, schedule_config) - trainer.train() - -def test_checkpoint(test_env): - temp_dir = test_env["test_dir"] - config = test_env["transformer_config"] - model = test_env["model"] - tokenizer = test_env["tokenizer"] - optimizer = torch.optim.AdamW(model.parameters()) - for _ in range(3): - optimizer.step() + checkpoint = trainer.train() - checkpoint = Checkpoint( - model=model, - tokenizer=tokenizer, - config=config, - loss_list=[1.0, 2.0, 3.0], - optim_state=optimizer.state_dict() + assert len(checkpoint.loss_list) > 0 + + +def test_gradient_accumulation(test_env): + """Test training with different gradient accumulation steps""" + accumulation_steps_list = [1, 2, 4] + + for accumulation_steps in accumulation_steps_list: + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + dataset=test_env["dataset"], + optimizer=optimizer, + checkpoint_dir=test_env["test_dir"], + n_epoch=1, + batch_size=2, + checkpoint_interval=10, + accumulation_steps=accumulation_steps, + max_grad_norm=1.0, + random_seed=42 + ) + + schedule_config = CosineScheduleConfig( + warmup_steps=10, + total_steps=20 + ) + + train_config.strategy = StrategyFactory.load( + test_env["model"], + "seq" + ) + + model_parameter = ModelParameter( + test_env["model"], + test_env["tokenizer"], + test_env["transformer_config"] + ) + + trainer = Trainer(model_parameter, train_config, schedule_config) + checkpoint = trainer.train() + + assert train_config.accumulation_steps == accumulation_steps + +def test_dpo_strategy_with_random_data(test_env): + """Test DPO strategy with randomized preference data""" + test_dir = test_env["test_dir"] + + # Create DPO-style data + pkl_path = os.path.join(test_dir, "dpo_data.pkl") + seq_length = np.random.randint(40, 80) + + dummy_data = { + "chosen": torch.randint(0, 1000, (seq_length,)), + "rejected": torch.randint(0, 1000, (seq_length,)), + "chosen_mask": torch.ones(seq_length, dtype=torch.bool), + "rejected_mask": torch.ones(seq_length, dtype=torch.bool) + } + + with open(pkl_path, "wb") as f: + pickle.dump(dummy_data, f) + + # Load DPO dataset + dpo_dataset = DatasetLoader.load( + train_type="dpo", + load_path=pkl_path, + max_len=64, + device="cpu" ) - ckpt_dir = os.path.join(temp_dir, "ckpt") - checkpoint.save(ckpt_dir) - - loaded_ckpt = Checkpoint() - loaded_ckpt.load(ckpt_dir) - assert loaded_ckpt.loss_list == [1.0, 2.0, 3.0] - assert loaded_ckpt.optim_state == optimizer.state_dict() - - for p1, p2 in zip(model.parameters(), loaded_ckpt.model.parameters()): - assert torch.allclose(p1, p2) + assert dpo_dataset is not None + assert hasattr(dpo_dataset, 'fetcher') -def test_checkpoint_train(test_env): - config = test_env["transformer_config"] - model = test_env["model"] - tokenizer = test_env["tokenizer"] +def test_callback_integration(test_env): + """Test that all callbacks are properly integrated""" + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + dataset=test_env["dataset"], + optimizer=optimizer, + checkpoint_dir=test_env["test_dir"], + n_epoch=1, + batch_size=2, + checkpoint_interval=3, + accumulation_steps=1, + max_grad_norm=1.0, + random_seed=42 + ) - class InterruptDataset(Dataset): - def __init__(self, length, interrupt_idx=0): + schedule_config = CosineScheduleConfig( + warmup_steps=10, + total_steps=20 + ) + + # Create custom callbacks to track calls + callback_calls = [] + + class TrackingCallback(TrainerCallback): + def on_train_begin(self, trainer, **kwargs): + callback_calls.append('on_train_begin') + + def on_batch_end(self, trainer, **kwargs): + callback_calls.append('on_batch_end') + + def on_epoch_end(self, trainer, **kwargs): + callback_calls.append('on_epoch_end') + + train_config.strategy = StrategyFactory.load(test_env["model"], "seq") + model_parameter = ModelParameter( + test_env["model"], + test_env["tokenizer"], + test_env["transformer_config"] + ) + + trainer = Trainer( + model_parameter, + train_config, + schedule_config, + callbacks=[TrackingCallback(), ProgressBarCallback()] + ) + + trainer.train() + + # Verify callbacks were called + assert 'on_train_begin' in callback_calls + assert 'on_batch_end' in callback_calls + assert 'on_epoch_end' in callback_calls + + +def test_memory_efficient_training(test_env): + """Test training with memory-efficient configurations""" + # Test with smaller batch sizes and gradient checkpointing + small_batch_configs = [ + {"batch_size": 1, "accumulation_steps": 8}, + {"batch_size": 2, "accumulation_steps": 4}, + {"batch_size": 4, "accumulation_steps": 2} + ] + + for config in small_batch_configs: + optimizer = torch.optim.AdamW(test_env["model"].parameters()) + train_config = TrainConfig( + dataset=test_env["dataset"], + optimizer=optimizer, + checkpoint_dir=test_env["test_dir"], + n_epoch=1, + batch_size=config["batch_size"], + checkpoint_interval=5, + accumulation_steps=config["accumulation_steps"], + max_grad_norm=1.0, + random_seed=42 + ) + + assert train_config.accumulation_steps == config["accumulation_steps"] + + +def test_early_stopping_simulation(test_env): + """Simulate early stopping behavior""" + class EarlyStoppingDataset(Dataset): + def __init__(self, length=10, stop_after=5): self.length = length - self.interrupt_idx = interrupt_idx + self.stop_after = stop_after + self.count = 0 def __len__(self): return self.length def __getitem__(self, idx): - if idx == self.interrupt_idx: - self.interrupt_idx = -1 - raise Exception("Interrupt") + self.count += 1 + if self.count == self.stop_after: + raise RuntimeError("Simulated early stopping") return { "input_ids": torch.randint(0, 1000, (64,)), "target_ids": torch.randint(0, 1000, (64,)) } - - - dataset = InterruptDataset(length=10, interrupt_idx=3) - param = ModelParameter(model, tokenizer, config) + + dataset = EarlyStoppingDataset() + optimizer = torch.optim.AdamW(test_env["model"].parameters()) train_config = TrainConfig( dataset=dataset, @@ -212,24 +440,28 @@ def test_checkpoint_train(test_env): random_seed=42 ) - train_config.strategy = StrategyFactory.load( + train_config.strategy = StrategyFactory.load(test_env["model"], "seq") + model_parameter = ModelParameter( test_env["model"], - "seq", - pad_token_id=test_env["tokenizer"].pad_id + test_env["tokenizer"], + test_env["transformer_config"] ) - schedule_config = CosineScheduleConfig( - warmup_steps=1, - total_steps=5 - ) - trainer = Trainer(param, train_config, schedule_config) + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) + trainer = Trainer(model_parameter, train_config, schedule_config) + # Should handle early stopping gracefully checkpoint = None - try: checkpoint = trainer.train() + assert len(checkpoint.loss_list) == 2 except Exception: + # Handle any exceptions pass + + checkpoint = trainer.train(checkpoint) + assert len(checkpoint.loss_list) == 5 + 1 - checkpoint = trainer.train(train_checkpoint=checkpoint) - assert len(checkpoint.loss_list) == 5 - 1 - \ No newline at end of file + +if __name__ == "__main__": + # Run all tests + pytest.main([__file__, "-v"]) \ No newline at end of file