diff --git a/tests/conftest.py b/tests/conftest.py index ceba3cd..b03f98b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,9 @@ from khaosz.trainer.data_util import * matplotlib.use("Agg") @pytest.fixture -def base_test_env(): - test_dir = tempfile.mkdtemp() +def base_test_env(request: pytest.FixtureRequest): + func_name = request.function.__name__ + test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") n_dim_choices = [8, 16, 32] @@ -41,12 +42,13 @@ def base_test_env(): with open(config_path, 'w') as f: json.dump(config, f) - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") transformer_config = TransformerConfig().load(config_path) - model = Transformer(transformer_config) + model = Transformer(transformer_config).to(device=device) tokenizer = BpeTokenizer() yield { + "device": device, "test_dir": test_dir, "config_path": config_path, "transformer_config": transformer_config, diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 35b832e..ba4cc9b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -37,7 +37,7 @@ def test_callback_integration(base_test_env, random_dataset): def on_epoch_end(self, trainer, **kwargs): callback_calls.append('on_epoch_end') - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_dataset_loader.py b/tests/test_dataset_loader.py index d92d2da..3fe29ad 100644 --- a/tests/test_dataset_loader.py +++ b/tests/test_dataset_loader.py @@ -34,7 +34,6 @@ def test_dataset_loader_random_paths(base_test_env): train_type="seq", load_path=pkl_paths, max_len=64, - device="cpu" ) assert loaded_dataset is not None assert len(loaded_dataset) > 0 @@ -62,7 +61,6 @@ def test_dpo_strategy_with_random_data(base_test_env): train_type="dpo", load_path=pkl_path, max_len=64, - device="cpu" ) assert dpo_dataset is not None diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index e4f12f5..b9cf322 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -41,7 +41,7 @@ def test_early_stopping_simulation(base_test_env): random_seed=42 ) - train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq") + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_module.py b/tests/test_module.py index 025188e..8250570 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -10,8 +10,9 @@ from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore from tokenizers import pre_tokenizers @pytest.fixture -def test_env(): - test_dir = tempfile.mkdtemp() +def test_env(request: pytest.FixtureRequest): + func_name = request.function.__name__ + test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") tokenizer_path = os.path.join(test_dir, "tokenizer.json") model_path = os.path.join(test_dir, "model.safetensors") diff --git a/tests/test_train_config.py b/tests/test_train_config.py index 2236a41..60f2bd6 100644 --- a/tests/test_train_config.py +++ b/tests/test_train_config.py @@ -47,12 +47,7 @@ def test_gradient_accumulation(base_test_env, random_dataset): warmup_steps=10, total_steps=20 ) - - train_config.strategy = StrategyFactory.load( - base_test_env["model"], - "seq" - ) - + train_config.strategy = StrategyFactory.load(base_test_env["model"], "seq", base_test_env["device"]) model_parameter = ModelParameter( base_test_env["model"], base_test_env["tokenizer"], diff --git a/tests/test_train_strategy.py b/tests/test_train_strategy.py index a884bbb..100130b 100644 --- a/tests/test_train_strategy.py +++ b/tests/test_train_strategy.py @@ -28,6 +28,7 @@ def test_multi_turn_training(base_test_env, multi_turn_dataset): train_config.strategy = StrategyFactory.load( base_test_env["model"], "sft", + base_test_env["device"], bos_token_id=2, eos_token_id=3, user_token_id=1,