diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 0d2f375..9109c29 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,5 +1,3 @@ -import os -import torch from khaosz.core import * from khaosz.trainer import * from khaosz.trainer.data_util import * @@ -57,55 +55,7 @@ def test_sampler_state_persistence(random_dataset): assert indices2 == indices3 -def test_training_resume_with_sampler(base_test_env, random_dataset): - """Test that training can resume correctly with sampler state""" - test_dir = base_test_env["test_dir"] - dataset = random_dataset - - # Initial training config - optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) - train_config = TrainConfig( - dataset=dataset, - optimizer=optimizer, - checkpoint_dir=test_dir, - n_epoch=1, - batch_size=2, - checkpoint_interval=5, - accumulation_steps=1, - max_grad_norm=1.0, - random_seed=42 - ) - - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") - model_parameter = ModelParameter( - base_test_env["model"], - base_test_env["tokenizer"], - base_test_env["transformer_config"] - ) - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - - # First training run - stop after a few steps - trainer = Trainer(model_parameter, train_config, schedule_config) - try: - # Run for a few steps then interrupt - for i, _ in enumerate(trainer.train()): - if i >= 3: # Run for 3 steps then stop - break - except Exception: - pass - - # Load checkpoint - checkpoint_path = os.path.join(test_dir, "iter_3") - checkpoint = Checkpoint().load(checkpoint_path) - - # Resume training - trainer = Trainer(model_parameter, train_config, schedule_config) - resumed_checkpoint = trainer.train(checkpoint) - - # Check that training resumed from correct point - assert resumed_checkpoint.sampler_state['current_iter'] > 3 - -def test_sampler_across_epochs(base_test_env, random_dataset): +def test_sampler_across_epochs(random_dataset): """Test sampler behavior across multiple epochs""" dataset = random_dataset n = len(dataset)