fix: resolve import issue in tests

ViperEkura 2026-04-03 15:01:39 +08:00
parent c5560740b6
commit 3a7d98a950
1 changed file with 4 additions and 2 deletions


@@ -10,7 +10,7 @@ from tokenizers import pre_tokenizers
 from torch.utils.data import Dataset
 from astrai.config.model_config import ModelConfig
-from astrai.data.tokenizer import BpeTokenizer
+from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
 from astrai.model.transformer import Transformer
@@ -143,6 +143,7 @@ def early_stopping_dataset():
 @pytest.fixture
 def test_env(request: pytest.FixtureRequest):
     """Create a test environment with saved model and tokenizer files."""
     func_name = request.function.__name__
     test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
     config_path = os.path.join(test_dir, "config.json")
@@ -163,8 +164,9 @@ def test_env(request: pytest.FixtureRequest):
         json.dump(config, f)
 
     tokenizer = BpeTokenizer()
+    trainer = BpeTrainer(tokenizer)
     sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
     tokenizer.save(tokenizer_path)
 
     transformer_config = ModelConfig().load(config_path)
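
For reference, a minimal sketch of the fixed tokenizer setup in the fixture after this commit. Only the BpeTokenizer/BpeTrainer calls visible in the diff are taken as given; the literal vocab size, the tokenizer file name, and reading the trailing 1 as a minimum frequency are illustrative assumptions, not confirmed parts of the astrai API.

    import os
    import tempfile
    from tokenizers import pre_tokenizers
    from astrai.data.tokenizer import BpeTokenizer, BpeTrainer  # BpeTrainer was the missing import

    test_dir = tempfile.mkdtemp(prefix="example_")
    tokenizer_path = os.path.join(test_dir, "tokenizer.json")  # hypothetical file name

    # Training now goes through a BpeTrainer wrapping the tokenizer,
    # instead of calling train_from_iterator on the tokenizer itself.
    tokenizer = BpeTokenizer()
    trainer = BpeTrainer(tokenizer)
    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())  # byte-level alphabet as a tiny training corpus
    trainer.train_from_iterator(sp_token_iter, 512, 1)  # 512 = illustrative vocab size; 1 assumed to be a min frequency
    tokenizer.save(tokenizer_path)  # the tokenizer object is still what gets saved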