diff --git a/tests/conftest.py b/tests/conftest.py index f6d2c7f..5280650 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from tokenizers import pre_tokenizers from torch.utils.data import Dataset from astrai.config.model_config import ModelConfig -from astrai.data.tokenizer import BpeTokenizer +from astrai.data.tokenizer import BpeTokenizer, BpeTrainer from astrai.model.transformer import Transformer @@ -143,6 +143,7 @@ def early_stopping_dataset(): @pytest.fixture def test_env(request: pytest.FixtureRequest): """Create a test environment with saved model and tokenizer files.""" + func_name = request.function.__name__ test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") @@ -163,8 +164,9 @@ def test_env(request: pytest.FixtureRequest): json.dump(config, f) tokenizer = BpeTokenizer() + trainer = BpeTrainer(tokenizer) sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) - tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) + trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) tokenizer.save(tokenizer_path) transformer_config = ModelConfig().load(config_path)