From 6e1a497c0402f1cb78db314a9589ec54376a0bdd Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 3 Oct 2025 22:18:31 +0800 Subject: [PATCH] =?UTF-8?q?test(sampler):=20=E5=88=A0=E9=99=A4=E5=86=97?= =?UTF-8?q?=E4=BD=99=E7=9A=84=E8=AE=AD=E7=BB=83=E6=81=A2=E5=A4=8D=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_sampler.py | 52 +------------------------------------------ 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 0d2f375..9109c29 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,5 +1,3 @@ -import os -import torch from khaosz.core import * from khaosz.trainer import * from khaosz.trainer.data_util import * @@ -57,55 +55,7 @@ def test_sampler_state_persistence(random_dataset): assert indices2 == indices3 -def test_training_resume_with_sampler(base_test_env, random_dataset): - """Test that training can resume correctly with sampler state""" - test_dir = base_test_env["test_dir"] - dataset = random_dataset - - # Initial training config - optimizer = torch.optim.AdamW(base_test_env["model"].parameters()) - train_config = TrainConfig( - dataset=dataset, - optimizer=optimizer, - checkpoint_dir=test_dir, - n_epoch=1, - batch_size=2, - checkpoint_interval=5, - accumulation_steps=1, - max_grad_norm=1.0, - random_seed=42 - ) - - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") - model_parameter = ModelParameter( - base_test_env["model"], - base_test_env["tokenizer"], - base_test_env["transformer_config"] - ) - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - - # First training run - stop after a few steps - trainer = Trainer(model_parameter, train_config, schedule_config) - try: - # Run for a few steps then interrupt - for i, _ in enumerate(trainer.train()): - if i >= 3: # Run for 3 steps then stop - break - except Exception: - pass - - # Load checkpoint - checkpoint_path = os.path.join(test_dir, "iter_3") - checkpoint = Checkpoint().load(checkpoint_path) - - # Resume training - trainer = Trainer(model_parameter, train_config, schedule_config) - resumed_checkpoint = trainer.train(checkpoint) - - # Check that training resumed from correct point - assert resumed_checkpoint.sampler_state['current_iter'] > 3 - -def test_sampler_across_epochs(base_test_env, random_dataset): +def test_sampler_across_epochs(random_dataset): """Test sampler behavior across multiple epochs""" dataset = random_dataset n = len(dataset)