fix: 修复测试部分导入问题
This commit is contained in:
parent
c5560740b6
commit
3a7d98a950
|
|
@@ -10,7 +10,7 @@ from tokenizers import pre_tokenizers
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.data.tokenizer import BpeTokenizer
|
from astrai.data.tokenizer import BpeTokenizer, BpeTrainer
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -143,6 +143,7 @@ def early_stopping_dataset():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_env(request: pytest.FixtureRequest):
|
def test_env(request: pytest.FixtureRequest):
|
||||||
"""Create a test environment with saved model and tokenizer files."""
|
"""Create a test environment with saved model and tokenizer files."""
|
||||||
|
|
||||||
func_name = request.function.__name__
|
func_name = request.function.__name__
|
||||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
|
@@ -163,8 +164,9 @@ def test_env(request: pytest.FixtureRequest):
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
|
|
||||||
tokenizer = BpeTokenizer()
|
tokenizer = BpeTokenizer()
|
||||||
|
trainer = BpeTrainer(tokenizer)
|
||||||
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
|
||||||
tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
|
||||||
tokenizer.save(tokenizer_path)
|
tokenizer.save(tokenizer_path)
|
||||||
|
|
||||||
transformer_config = ModelConfig().load(config_path)
|
transformer_config = ModelConfig().load(config_path)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue