From ace8f6ee68e78a42e9255cac8c8ecfce5715c722 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 6 Apr 2026 09:54:17 +0800
Subject: [PATCH] chore: clean up unused modules
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrai/__init__.py           |   7 ++-
 astrai/tokenize/__init__.py  |   9 +--
 astrai/tokenize/tokenizer.py |  39 -------------
 astrai/tokenize/trainer.py   | 108 -----------------------------------
 tests/conftest.py            |  26 ++++++---
 5 files changed, 24 insertions(+), 165 deletions(-)
 delete mode 100644 astrai/tokenize/trainer.py

diff --git a/astrai/__init__.py b/astrai/__init__.py
index 453eaa0..a89d85e 100644
--- a/astrai/__init__.py
+++ b/astrai/__init__.py
@@ -12,18 +12,19 @@ from astrai.inference import (
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
-from astrai.tokenize import BpeTokenizer
-from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
+from astrai.tokenize import AutoTokenizer
+from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
 
 __all__ = [
     "Transformer",
     "ModelConfig",
     "TrainConfig",
     "DatasetFactory",
-    "BpeTokenizer",
+    "AutoTokenizer",
     "GenerationRequest",
     "InferenceEngine",
     "Trainer",
+    "CallbackFactory",
     "StrategyFactory",
     "SchedulerFactory",
     "BaseFactory",
diff --git a/astrai/tokenize/__init__.py b/astrai/tokenize/__init__.py
index f0b10f7..3838b0d 100644
--- a/astrai/tokenize/__init__.py
+++ b/astrai/tokenize/__init__.py
@@ -1,15 +1,8 @@
 from astrai.tokenize.chat_template import ChatTemplate, MessageType
-from astrai.tokenize.tokenizer import (
-    AutoTokenizer,
-    BpeTokenizer,
-)
-from astrai.tokenize.trainer import BpeTrainer
+from astrai.tokenize.tokenizer import AutoTokenizer
 
 __all__ = [
     "AutoTokenizer",
-    "BpeTokenizer",
-    "BpeTrainer",
     "ChatTemplate",
     "MessageType",
-    "HistoryType",
 ]
diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py
index 35d634e..c3e02ef 100644
--- a/astrai/tokenize/tokenizer.py
+++ b/astrai/tokenize/tokenizer.py
@@ -238,42 +238,3 @@ class AutoTokenizer:
             return self.encode(rendered)
 
         return rendered
-
-
-class BpeTokenizer(AutoTokenizer):
-    """BPE tokenizer implementation."""
-
-    def __init__(
-        self,
-        special_token_map: Dict[str, str] = None,
-        path: Optional[str] = None,
-        chat_template: Optional[str] = None,
-    ):
-        special_token_map = special_token_map or {
-            "bos": "<|begin▁of▁sentence|>",
-            "eos": "<|end▁of▁sentence|>",
-            "pad": "<|▁pad▁|>",
-            "im_start": "<|im▁start|>",
-            "im_end": "<|im▁end|>",
-        }
-        self._tokenizer = None
-        self._init_tokenizer()
-        super().__init__(
-            path, special_token_map=special_token_map, chat_template=chat_template
-        )
-
-    def _init_tokenizer(self):
-        """Initialize a new BPE tokenizer with default settings."""
-        model = BPE()
-        self._tokenizer = Tokenizer(model)
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [normalizers.NFC(), normalizers.Strip()]
-        )
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.UnicodeScripts(),
-                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
diff --git a/astrai/tokenize/trainer.py b/astrai/tokenize/trainer.py
deleted file mode 100644
index 20e5536..0000000
--- a/astrai/tokenize/trainer.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""
-BPE Tokenizer Trainer module.
-
-Provides training functionality for BPE tokenizers.
-"""
-
-from typing import List, Union
-
-from tokenizers import pre_tokenizers
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
-
-
-class BpeTrainer:
-    """BPE tokenizer trainer."""
-
-    def __init__(self, tokenizer):
-        """Initialize trainer with a tokenizer instance.
-
-        Args:
-            tokenizer: A BpeTokenizer instance
-        """
-        self.tokenizer = tokenizer
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length: int = 18,
-    ):
-        """Prepare the BPE trainer with proper configuration."""
-        assert reserved_token_size > len(self.tokenizer._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self.tokenizer._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainerImpl(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self.tokenizer._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, reserved_tokens
-
-    def train(
-        self,
-        files: Union[str, List[str]],
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from files.
-
-        Args:
-            files: Path or list of paths to training files
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-    def train_from_iterator(
-        self,
-        iterator,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from iterator.
-
-        Args:
-            iterator: Iterator yielding training strings
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train_from_iterator(
-            iterator=iterator, trainer=trainer
-        )
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-
-__all__ = ["BpeTrainer"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 01b7f67..272e7e9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,12 +7,27 @@
 import numpy as np
 import pytest
 import safetensors.torch as st
 import torch
-from tokenizers import pre_tokenizers
+from tokenizers import Tokenizer, models, pre_tokenizers, trainers
 from torch.utils.data import Dataset
 from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
-from astrai.tokenize import BpeTokenizer, BpeTrainer
+from astrai.tokenize import AutoTokenizer
+
+
+def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
+    """Create a simple tokenizer for testing purposes."""
+    tokenizer = Tokenizer(models.BPE())
+    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
+    trainer = trainers.BpeTrainer(
+        vocab_size=vocab_size, min_frequency=1, special_tokens=["", ""]
+    )
+    # Train on an iterator of single-character strings so the vocab is populated
+    tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
+    auto_tokenizer = AutoTokenizer()
+    auto_tokenizer._tokenizer = tokenizer
+    auto_tokenizer._special_token_map = {"unk_token": "", "pad_token": ""}
+    return auto_tokenizer
 
 
 class RandomDataset(Dataset):
@@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
-    tokenizer = BpeTokenizer()
+    tokenizer = create_test_tokenizer()
 
     yield {
         "device": device,
@@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
     with open(config_path, "w") as f:
         json.dump(config, f)
 
-    tokenizer = BpeTokenizer()
-    trainer = BpeTrainer(tokenizer)
-    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
     tokenizer.save(tokenizer_path)
 
     transformer_config = ModelConfig().load(config_path)
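
Note on the removed trainer: with astrai/tokenize/trainer.py deleted, the package no longer
ships its own BPE training path; the tests now build a small throwaway tokenizer through
create_test_tokenizer() instead. Code that still needs to train a real vocabulary can call the
tokenizers library directly and wrap the result the same way the test helper does. The sketch
below only illustrates that route and is not part of this change: it assumes AutoTokenizer
accepts a no-argument constructor and a private _tokenizer attribute, exactly as
tests/conftest.py does, and train_bpe_tokenizer is a hypothetical helper name.

    from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

    from astrai.tokenize import AutoTokenizer


    def train_bpe_tokenizer(files, vocab_size=32000, min_frequency=2):
        """Train a byte-level BPE tokenizer on text files and wrap it in AutoTokenizer.

        Sketch under the assumptions noted above; special tokens would still need to
        be registered the way create_test_tokenizer() does (_special_token_map plus
        Tokenizer.add_special_tokens).
        """
        tokenizer = Tokenizer(models.BPE())
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        tokenizer.decoder = decoders.ByteLevel()
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            show_progress=True,
        )
        # Learn merges from the given text files, then inject the trained backend
        # into AutoTokenizer via the same private attribute the test fixture uses.
        tokenizer.train(files=files, trainer=trainer)
        wrapped = AutoTokenizer()
        wrapped._tokenizer = tokenizer
        return wrapped

Usage would mirror the test fixture: train_bpe_tokenizer(["corpus.txt"]) followed by
.save(tokenizer_path) on the wrapper ("corpus.txt" and the helper name are placeholders,
not files or APIs introduced by this patch).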