chore: clean up unused modules

ViperEkura 2026-04-06 09:54:17 +08:00
parent a57a16430d
commit ace8f6ee68
5 changed files with 24 additions and 165 deletions


@@ -12,18 +12,19 @@ from astrai.inference import (
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
-from astrai.tokenize import BpeTokenizer
-from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
+from astrai.tokenize import AutoTokenizer
+from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
 
 __all__ = [
     "Transformer",
     "ModelConfig",
     "TrainConfig",
     "DatasetFactory",
-    "BpeTokenizer",
+    "AutoTokenizer",
     "GenerationRequest",
     "InferenceEngine",
     "Trainer",
+    "CallbackFactory",
     "StrategyFactory",
     "SchedulerFactory",
     "BaseFactory",


@@ -1,15 +1,8 @@
 from astrai.tokenize.chat_template import ChatTemplate, MessageType
-from astrai.tokenize.tokenizer import (
-    AutoTokenizer,
-    BpeTokenizer,
-)
-from astrai.tokenize.trainer import BpeTrainer
+from astrai.tokenize.tokenizer import AutoTokenizer
 
 __all__ = [
     "AutoTokenizer",
-    "BpeTokenizer",
-    "BpeTrainer",
     "ChatTemplate",
     "MessageType",
-    "HistoryType",
 ]
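After this hunk the public surface of astrai.tokenize is exactly the three names left in __all__; a minimal import sketch, using only what the list above exports:

# Still available after this commit:
from astrai.tokenize import AutoTokenizer, ChatTemplate, MessageType

# No longer importable (removed by this commit):
# from astrai.tokenize import BpeTokenizer, BpeTrainer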


@@ -238,42 +238,3 @@ class AutoTokenizer:
             return self.encode(rendered)
         return rendered
-
-
-class BpeTokenizer(AutoTokenizer):
-    """BPE tokenizer implementation."""
-
-    def __init__(
-        self,
-        special_token_map: Dict[str, str] = None,
-        path: Optional[str] = None,
-        chat_template: Optional[str] = None,
-    ):
-        special_token_map = special_token_map or {
-            "bos": "<begin▁of▁sentence>",
-            "eos": "<end▁of▁sentence>",
-            "pad": "<▁pad▁>",
-            "im_start": "<im▁start>",
-            "im_end": "<im▁end>",
-        }
-        self._tokenizer = None
-        self._init_tokenizer()
-        super().__init__(
-            path, special_token_map=special_token_map, chat_template=chat_template
-        )
-
-    def _init_tokenizer(self):
-        """Initialize a new BPE tokenizer with default settings."""
-        model = BPE()
-        self._tokenizer = Tokenizer(model)
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [normalizers.NFC(), normalizers.Strip()]
-        )
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.UnicodeScripts(),
-                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
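The removed BpeTokenizer only wired together components from the Hugging Face tokenizers library. A hedged standalone sketch of the same pipeline, mirroring the settings of the deleted _init_tokenizer above (not part of the commit itself):

from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE

# Equivalent standalone setup to the deleted BpeTokenizer._init_tokenizer
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = normalizers.Sequence([normalizers.NFC(), normalizers.Strip()])
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.UnicodeScripts(),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
    ]
)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)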


@@ -1,108 +0,0 @@
-"""
-BPE Tokenizer Trainer module.
-
-Provides training functionality for BPE tokenizers.
-"""
-
-from typing import List, Union
-
-from tokenizers import pre_tokenizers
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
-
-
-class BpeTrainer:
-    """BPE tokenizer trainer."""
-
-    def __init__(self, tokenizer):
-        """Initialize trainer with a tokenizer instance.
-
-        Args:
-            tokenizer: A BpeTokenizer instance
-        """
-        self.tokenizer = tokenizer
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length: int = 18,
-    ):
-        """Prepare the BPE trainer with proper configuration."""
-        assert reserved_token_size > len(self.tokenizer._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self.tokenizer._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainerImpl(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self.tokenizer._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, reserved_tokens
-
-    def train(
-        self,
-        files: Union[str, List[str]],
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from files.
-
-        Args:
-            files: Path or list of paths to training files
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-    def train_from_iterator(
-        self,
-        iterator,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from iterator.
-
-        Args:
-            iterator: Iterator yielding training strings
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train_from_iterator(
-            iterator=iterator, trainer=trainer
-        )
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-
-__all__ = ["BpeTrainer"]
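With this wrapper module gone, BPE training can be driven through tokenizers.trainers.BpeTrainer directly, which is essentially what the new test helper in the next file does. A hedged sketch; the corpus path, vocab size, and frequency cutoff are placeholder assumptions, and the special tokens are taken from the removed BpeTokenizer defaults:

from tokenizers import Tokenizer, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

trainer = BpeTrainer(
    vocab_size=8000,      # placeholder target size
    min_frequency=2,      # placeholder frequency cutoff
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    special_tokens=["<begin▁of▁sentence>", "<end▁of▁sentence>", "<▁pad▁>"],
    show_progress=True,
)

tokenizer.train(files=["corpus.txt"], trainer=trainer)  # or train_from_iterator(...)
tokenizer.save("tokenizer.json")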


@@ -7,12 +7,27 @@ import numpy as np
 import pytest
 import safetensors.torch as st
 import torch
-from tokenizers import pre_tokenizers
+from tokenizers import Tokenizer, models, pre_tokenizers, trainers
 from torch.utils.data import Dataset
 
 from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
-from astrai.tokenize import BpeTokenizer, BpeTrainer
+from astrai.tokenize import AutoTokenizer
+
+
+def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
+    """Create a simple tokenizer for testing purposes."""
+    tokenizer = Tokenizer(models.BPE())
+    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
+    trainer = trainers.BpeTrainer(
+        vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
+    )
+    # Train on the 256 single-byte characters so the vocabulary is non-empty
+    tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
+
+    auto_tokenizer = AutoTokenizer()
+    auto_tokenizer._tokenizer = tokenizer
+    auto_tokenizer._special_token_map = {"unk_token": "<unk>", "pad_token": "<pad>"}
+    return auto_tokenizer
 
 
 class RandomDataset(Dataset):
@@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
-    tokenizer = BpeTokenizer()
+    tokenizer = create_test_tokenizer()
 
     yield {
         "device": device,
@@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
     with open(config_path, "w") as f:
         json.dump(config, f)
 
-    tokenizer = BpeTokenizer()
-    trainer = BpeTrainer(tokenizer)
-    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
     tokenizer.save(tokenizer_path)
     transformer_config = ModelConfig().load(config_path)
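For reference, a rough usage sketch of the new helper; the zero-argument AutoTokenizer() constructor and the encode method both appear in the hunks above, but the return type of encode is an assumption here, not something this commit shows:

# Hedged usage sketch (not part of the commit)
tok = create_test_tokenizer(vocab_size=512)
ids = tok.encode("hello world")  # assumed to return a sequence of token ids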