From 3a7d98a950fd3efd9b65661dbf247236d46094f6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 3 Apr 2026 15:01:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E9=83=A8=E5=88=86=E5=AF=BC=E5=85=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f6d2c7f..5280650 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from tokenizers import pre_tokenizers from torch.utils.data import Dataset from astrai.config.model_config import ModelConfig -from astrai.data.tokenizer import BpeTokenizer +from astrai.data.tokenizer import BpeTokenizer, BpeTrainer from astrai.model.transformer import Transformer @@ -143,6 +143,7 @@ def early_stopping_dataset(): @pytest.fixture def test_env(request: pytest.FixtureRequest): """Create a test environment with saved model and tokenizer files.""" + func_name = request.function.__name__ test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") @@ -163,8 +164,9 @@ def test_env(request: pytest.FixtureRequest): json.dump(config, f) tokenizer = BpeTokenizer() + trainer = BpeTrainer(tokenizer) sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) - tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) + trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) tokenizer.save(tokenizer_path) transformer_config = ModelConfig().load(config_path)