fix: 修复测试部分导入问题
This commit is contained in:
parent
c5560740b6
commit
3a7d98a950
|
|
@@ -10,7 +10,7 @@ from tokenizers import pre_tokenizers
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
-from astrai.data.tokenizer import BpeTokenizer
|
||||
+from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
@@ -143,6 +143,7 @@ def early_stopping_dataset():
|
|||
@pytest.fixture
|
||||
def test_env(request: pytest.FixtureRequest):
|
||||
"""Create a test environment with saved model and tokenizer files."""
|
||||
|
||||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
|
|
@@ -163,8 +164,9 @@ def test_env(request: pytest.FixtureRequest):
|
|||
json.dump(config, f)
|
||||
|
||||
tokenizer = BpeTokenizer()
|
||||
+trainer = BpeTrainer(tokenizer)
|
||||
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
||||
-tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||
+trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||
tokenizer.save(tokenizer_path)
|
||||
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
|
|
|
|||
Loading…
Reference in New Issue