chore: clean up unused modules

ViperEkura 2026-04-06 09:54:17 +08:00
parent a57a16430d
commit ace8f6ee68
5 changed files with 24 additions and 165 deletions
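
In short: the in-repo BpeTokenizer and BpeTrainer wrappers are removed, the package now exposes AutoTokenizer (and the trainer package additionally exports CallbackFactory), and the tests build a throwaway tokenizer with the tokenizers library directly. For downstream callers the swap is small; a minimal sketch, assuming a tokenizer file previously saved via AutoTokenizer.save() (the file name here is hypothetical):

    # Hypothetical migration sketch; "tokenizer.json" is an assumed artifact name.
    # AutoTokenizer takes the path positionally, matching the argument the removed
    # BpeTokenizer.__init__ forwarded to this base class.
    from astrai.tokenize import AutoTokenizer

    tokenizer = AutoTokenizer("tokenizer.json")
    ids = tokenizer.encode("hello world")  # encode() is shown on AutoTokenizer in the diff below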

astrai/__init__.py

@@ -12,18 +12,19 @@ from astrai.inference import (
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
-from astrai.tokenize import BpeTokenizer
-from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
+from astrai.tokenize import AutoTokenizer
+from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
 
 __all__ = [
     "Transformer",
     "ModelConfig",
     "TrainConfig",
     "DatasetFactory",
-    "BpeTokenizer",
+    "AutoTokenizer",
     "GenerationRequest",
     "InferenceEngine",
     "Trainer",
+    "CallbackFactory",
     "StrategyFactory",
     "SchedulerFactory",
     "BaseFactory",

astrai/tokenize/__init__.py

@@ -1,15 +1,8 @@
 from astrai.tokenize.chat_template import ChatTemplate, MessageType
-from astrai.tokenize.tokenizer import (
-    AutoTokenizer,
-    BpeTokenizer,
-)
-from astrai.tokenize.trainer import BpeTrainer
-
+from astrai.tokenize.tokenizer import AutoTokenizer
 __all__ = [
     "AutoTokenizer",
-    "BpeTokenizer",
-    "BpeTrainer",
     "ChatTemplate",
     "MessageType",
     "HistoryType",
 ]

astrai/tokenize/tokenizer.py

@@ -238,42 +238,3 @@ class AutoTokenizer:
             return self.encode(rendered)
         return rendered
 
-
-class BpeTokenizer(AutoTokenizer):
-    """BPE tokenizer implementation."""
-
-    def __init__(
-        self,
-        special_token_map: Dict[str, str] = None,
-        path: Optional[str] = None,
-        chat_template: Optional[str] = None,
-    ):
-        special_token_map = special_token_map or {
-            "bos": "<begin▁of▁sentence>",
-            "eos": "<end▁of▁sentence>",
-            "pad": "<▁pad▁>",
-            "im_start": "<im▁start>",
-            "im_end": "<im▁end>",
-        }
-        self._tokenizer = None
-        self._init_tokenizer()
-        super().__init__(
-            path, special_token_map=special_token_map, chat_template=chat_template
-        )
-
-    def _init_tokenizer(self):
-        """Initialize a new BPE tokenizer with default settings."""
-        model = BPE()
-        self._tokenizer = Tokenizer(model)
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [normalizers.NFC(), normalizers.Strip()]
-        )
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.UnicodeScripts(),
-                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)

astrai/tokenize/trainer.py

@@ -1,108 +0,0 @@
-"""
-BPE Tokenizer Trainer module.
-
-Provides training functionality for BPE tokenizers.
-"""
-
-from typing import List, Union
-
-from tokenizers import pre_tokenizers
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
-
-
-class BpeTrainer:
-    """BPE tokenizer trainer."""
-
-    def __init__(self, tokenizer):
-        """Initialize trainer with a tokenizer instance.
-
-        Args:
-            tokenizer: A BpeTokenizer instance
-        """
-        self.tokenizer = tokenizer
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length: int = 18,
-    ):
-        """Prepare the BPE trainer with proper configuration."""
-        assert reserved_token_size > len(self.tokenizer._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self.tokenizer._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainerImpl(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self.tokenizer._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, reserved_tokens
-
-    def train(
-        self,
-        files: Union[str, List[str]],
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from files.
-
-        Args:
-            files: Path or list of paths to training files
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-    def train_from_iterator(
-        self,
-        iterator,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from iterator.
-
-        Args:
-            iterator: Iterator yielding training strings
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train_from_iterator(
-            iterator=iterator, trainer=trainer
-        )
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-
-__all__ = ["BpeTrainer"]

View File

@@ -7,12 +7,27 @@ import numpy as np
 import pytest
 import safetensors.torch as st
 import torch
-from tokenizers import pre_tokenizers
+from tokenizers import Tokenizer, models, pre_tokenizers, trainers
 from torch.utils.data import Dataset
 
 from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
-from astrai.tokenize import BpeTokenizer, BpeTrainer
+from astrai.tokenize import AutoTokenizer
 
 
+def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
+    """Create a simple tokenizer for testing purposes."""
+    tokenizer = Tokenizer(models.BPE())
+    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
+    trainer = trainers.BpeTrainer(
+        vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
+    )
+    # Train on all 256 single-character strings so the byte-level alphabet is covered
+    tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
+    auto_tokenizer = AutoTokenizer()
+    auto_tokenizer._tokenizer = tokenizer
+    auto_tokenizer._special_token_map = {"unk_token": "<unk>", "pad_token": "<pad>"}
+    return auto_tokenizer
+
+
 class RandomDataset(Dataset):
@@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
-    tokenizer = BpeTokenizer()
+    tokenizer = create_test_tokenizer()
 
     yield {
         "device": device,
@@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
     with open(config_path, "w") as f:
         json.dump(config, f)
 
-    tokenizer = BpeTokenizer()
-    trainer = BpeTrainer(tokenizer)
-    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
     tokenizer.save(tokenizer_path)
 
     transformer_config = ModelConfig().load(config_path)
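
For completeness, the save/load round trip this fixture depends on, under the assumption that AutoTokenizer can re-load the file written by save() through its positional path argument (as the removed BpeTokenizer.__init__ suggests); the path and size here are illustrative:

    tokenizer = create_test_tokenizer(vocab_size=512)  # illustrative size
    tokenizer.save("/tmp/tokenizer.json")              # hypothetical path
    reloaded = AutoTokenizer("/tmp/tokenizer.json")    # path is positional, per the old super().__init__ call
    assert reloaded.encode("abc") == tokenizer.encode("abc")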