From e7d29ca2d5e25684ee65fdc5a7383328cf774202 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Oct 2025 12:12:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(tests):=20=E6=94=B9=E8=BF=9B=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=8E=AF=E5=A2=83=E9=85=8D=E7=BD=AE=E4=B8=8E=E8=AE=BE?= =?UTF-8?q?=E5=A4=87=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 10 ++++++---- tests/test_callbacks.py | 2 +- tests/test_dataset_loader.py | 2 -- tests/test_early_stopping.py | 2 +- tests/test_module.py | 5 +++-- tests/test_train_config.py | 7 +------ tests/test_train_strategy.py | 1 + 7 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ceba3cd..b03f98b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,9 @@ from khaosz.trainer.data_util import * matplotlib.use("Agg") @pytest.fixture -def base_test_env(): - test_dir = tempfile.mkdtemp() +def base_test_env(request: pytest.FixtureRequest): + func_name = request.function.__name__ + test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") n_dim_choices = [8, 16, 32] @@ -41,12 +42,13 @@ def base_test_env(): with open(config_path, 'w') as f: json.dump(config, f) - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") transformer_config = TransformerConfig().load(config_path) - model = Transformer(transformer_config) + model = Transformer(transformer_config).to(device=device) tokenizer = BpeTokenizer() yield { + "device": device, "test_dir": test_dir, "config_path": config_path, "transformer_config": transformer_config, diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 35b832e..ba4cc9b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset): def on_epoch_end(self, trainer, **kwargs): callback_calls.append('on_epoch_end') - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py index d92d2da..3fe29ad 100644 --- a/tests/test_dataset_loader.py +++ b/tests/test_dataset_loader.py @@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env): train_type="seq", load_path=pkl_paths, max_len=64, - device="cpu" ) assert loaded_dataset is not None assert len(loaded_dataset) > 0 @@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env): train_type="dpo", load_path=pkl_path, max_len=64, - device="cpu" ) assert dpo_dataset is not None diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index e4f12f5..b9cf322 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env): random_seed=42 ) - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_module.py b/tests/test_module.py index 025188e..8250570 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore from tokenizers import pre_tokenizers @pytest.fixture -def test_env(): - test_dir = tempfile.mkdtemp() +def test_env(request: pytest.FixtureRequest): + func_name = request.function.__name__ + test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") tokenizer_path = os.path.join(test_dir, "tokenizer.json") model_path = os.path.join(test_dir, "model.safetensors") diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 2236a41..60f2bd6 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset): warmup_steps=10, total_steps=20 ) - - train_config.strategy = StrategyFactory.load( - base_test_env["model"], - "seq" - ) - + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py index a884bbb..100130b 100644 --- a/tests/test_train_strategy.py +++ b/tests/test_train_strategy.py @@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset): train_config.strategy = StrategyFactory.load( base_test_env["model"], "sft", + base_test_env["device"], bos_token_id=2, eos_token_id=3, user_token_id=1,