chore: clean up unused modules
commit ace8f6ee68
parent a57a16430d
@@ -12,18 +12,19 @@ from astrai.inference import (
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
-from astrai.tokenize import BpeTokenizer
-from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
+from astrai.tokenize import AutoTokenizer
+from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
 
 __all__ = [
     "Transformer",
     "ModelConfig",
     "TrainConfig",
     "DatasetFactory",
-    "BpeTokenizer",
+    "AutoTokenizer",
     "GenerationRequest",
     "InferenceEngine",
     "Trainer",
+    "CallbackFactory",
     "StrategyFactory",
     "SchedulerFactory",
     "BaseFactory",
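This first hunk trims what is evidently the package's top-level `__init__`: `BpeTokenizer` leaves the public surface, `AutoTokenizer` and `CallbackFactory` enter it. For downstream code the migration is a rename at the import site. A minimal sketch, assuming these names are re-exported at the package root as the `__all__` list above indicates:

```python
# Before this commit (hypothetical caller):
#     from astrai import BpeTokenizer
#     tokenizer = BpeTokenizer()

# After this commit, the equivalent import pulls in AutoTokenizer instead.
from astrai import AutoTokenizer, CallbackFactory, Trainer

# Default construction is an assumption, mirroring the test fixture later
# in this diff (AutoTokenizer() with internals assigned afterwards).
tokenizer = AutoTokenizer()
```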
@@ -1,15 +1,8 @@
 from astrai.tokenize.chat_template import ChatTemplate, MessageType
-from astrai.tokenize.tokenizer import (
-    AutoTokenizer,
-    BpeTokenizer,
-)
-from astrai.tokenize.trainer import BpeTrainer
+from astrai.tokenize.tokenizer import AutoTokenizer
 
 __all__ = [
     "AutoTokenizer",
-    "BpeTokenizer",
-    "BpeTrainer",
     "ChatTemplate",
     "MessageType",
-    "HistoryType",
 ]
@@ -238,42 +238,3 @@ class AutoTokenizer:
         return self.encode(rendered)
 
         return rendered
-
-
-class BpeTokenizer(AutoTokenizer):
-    """BPE tokenizer implementation."""
-
-    def __init__(
-        self,
-        special_token_map: Dict[str, str] = None,
-        path: Optional[str] = None,
-        chat_template: Optional[str] = None,
-    ):
-        special_token_map = special_token_map or {
-            "bos": "<|begin▁of▁sentence|>",
-            "eos": "<|end▁of▁sentence|>",
-            "pad": "<|▁pad▁|>",
-            "im_start": "<|im▁start|>",
-            "im_end": "<|im▁end|>",
-        }
-        self._tokenizer = None
-        self._init_tokenizer()
-        super().__init__(
-            path, special_token_map=special_token_map, chat_template=chat_template
-        )
-
-    def _init_tokenizer(self):
-        """Initialize a new BPE tokenizer with default settings."""
-        model = BPE()
-        self._tokenizer = Tokenizer(model)
-        self._tokenizer.normalizer = normalizers.Sequence(
-            [normalizers.NFC(), normalizers.Strip()]
-        )
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.UnicodeScripts(),
-                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel()
-        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
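The hunk above deletes the `BpeTokenizer` subclass, which was a thin wrapper around the Hugging Face `tokenizers` package. Code that relied on its defaults can rebuild the identical pipeline directly; this sketch reproduces the exact configuration from the removed `_init_tokenizer`:

```python
# Minimal sketch: rebuild the byte-level BPE pipeline that the removed
# BpeTokenizer._init_tokenizer() used to set up, using the Hugging Face
# `tokenizers` library directly.
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
# NFC-normalize and strip surrounding whitespace before tokenizing.
tokenizer.normalizer = normalizers.Sequence([normalizers.NFC(), normalizers.Strip()])
# Split on Unicode script boundaries, then apply byte-level pre-tokenization.
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.UnicodeScripts(),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
    ]
)
# Byte-level decoding and offset trimming mirror the removed defaults.
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
```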
@@ -1,108 +0,0 @@
-"""
-BPE Tokenizer Trainer module.
-
-Provides training functionality for BPE tokenizers.
-"""
-
-from typing import List, Union
-
-from tokenizers import pre_tokenizers
-from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
-
-
-class BpeTrainer:
-    """BPE tokenizer trainer."""
-
-    def __init__(self, tokenizer):
-        """Initialize trainer with a tokenizer instance.
-
-        Args:
-            tokenizer: A BpeTokenizer instance
-        """
-        self.tokenizer = tokenizer
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length: int = 18,
-    ):
-        """Prepare the BPE trainer with proper configuration."""
-        assert reserved_token_size > len(self.tokenizer._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self.tokenizer._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainerImpl(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self.tokenizer._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, reserved_tokens
-
-    def train(
-        self,
-        files: Union[str, List[str]],
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from files.
-
-        Args:
-            files: Path or list of paths to training files
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-    def train_from_iterator(
-        self,
-        iterator,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int = 100,
-        **kwargs,
-    ):
-        """Train tokenizer from iterator.
-
-        Args:
-            iterator: Iterator yielding training strings
-            vocab_size: Target vocabulary size
-            min_freq: Minimum frequency for tokens
-            reserved_token_size: Number of reserved tokens
-            **kwargs: Additional arguments
-        """
-        trainer, reserved_tokens = self._prepare_trainer(
-            vocab_size, min_freq, reserved_token_size, **kwargs
-        )
-        self.tokenizer._tokenizer.train_from_iterator(
-            iterator=iterator, trainer=trainer
-        )
-        self.tokenizer._tokenizer.add_special_tokens(
-            self.tokenizer._special_tokens + reserved_tokens
-        )
-
-
-__all__ = ["BpeTrainer"]
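The deletion above removes the `BpeTrainer` wrapper along with its reserved-token bookkeeping. Training now goes through `tokenizers.trainers.BpeTrainer` directly, as the new test fixture in the next hunk also does. A minimal sketch, continuing from the tokenizer built in the previous example; the corpus path, hyperparameters, and the trimmed-down special-token list are placeholders, not values from this repository:

```python
# Minimal sketch: train the tokenizer built above without the deleted
# wrapper, calling tokenizers.trainers.BpeTrainer directly.
# "corpus.txt" and the hyperparameters below are placeholder values.
from tokenizers import pre_tokenizers, trainers

trainer = trainers.BpeTrainer(
    vocab_size=32_000,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    show_progress=True,
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)

# The deleted wrapper also appended special and reserved tokens afterwards;
# a caller now does that explicitly (token list here is illustrative only).
tokenizer.add_special_tokens(["<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"])
```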
@@ -7,12 +7,27 @@ import numpy as np
 import pytest
 import safetensors.torch as st
 import torch
-from tokenizers import pre_tokenizers
+from tokenizers import Tokenizer, models, pre_tokenizers, trainers
 from torch.utils.data import Dataset
 
 from astrai.config.model_config import ModelConfig
 from astrai.model.transformer import Transformer
-from astrai.tokenize import BpeTokenizer, BpeTrainer
+from astrai.tokenize import AutoTokenizer
 
 
+def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
+    """Create a simple tokenizer for testing purposes."""
+    tokenizer = Tokenizer(models.BPE())
+    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
+    trainer = trainers.BpeTrainer(
+        vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
+    )
+    # Train on an iterator of the 256 single-character strings to seed the vocab.
+    tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
+    auto_tokenizer = AutoTokenizer()
+    auto_tokenizer._tokenizer = tokenizer
+    auto_tokenizer._special_token_map = {"unk_token": "<unk>", "pad_token": "<pad>"}
+    return auto_tokenizer
+
+
 class RandomDataset(Dataset):
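The new `create_test_tokenizer` helper replaces the `BpeTokenizer`/`BpeTrainer` pair in the fixtures that follow. In a test body it would be used roughly like this (a sketch; `encode` is assumed from the `return self.encode(rendered)` context in the `AutoTokenizer` hunk above):

```python
# Hypothetical use of the new helper inside a test, assuming the conftest
# imports and definitions above are in scope.
tokenizer = create_test_tokenizer(vocab_size=512)
ids = tokenizer.encode("hello")  # assumption: AutoTokenizer exposes encode()
```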
@@ -109,7 +124,7 @@ def base_test_env(request: pytest.FixtureRequest):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     transformer_config = ModelConfig().load(config_path)
     model = Transformer(transformer_config).to(device=device)
-    tokenizer = BpeTokenizer()
+    tokenizer = create_test_tokenizer()
 
     yield {
         "device": device,
@@ -164,10 +179,7 @@ def test_env(request: pytest.FixtureRequest):
     with open(config_path, "w") as f:
         json.dump(config, f)
 
-    tokenizer = BpeTokenizer()
-    trainer = BpeTrainer(tokenizer)
-    sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
-    trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
+    tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
     tokenizer.save(tokenizer_path)
 
     transformer_config = ModelConfig().load(config_path)
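`test_env` still persists the tokenizer with `tokenizer.save(tokenizer_path)`. Round-tripping it back through `AutoTokenizer` would look roughly like this; the path-taking constructor is an assumption inferred from the `super().__init__(path, ...)` call in the removed `BpeTokenizer`, not an API shown directly in this diff:

```python
# Sketch only: reloading the saved test tokenizer, assuming the conftest
# definitions above are in scope.
tokenizer = create_test_tokenizer(vocab_size=512)
tokenizer.save("tokenizer.json")  # save() is used by test_env above

# Assumed constructor signature, inferred from the removed subclass's
# super().__init__(path, special_token_map=..., chat_template=...) call.
reloaded = AutoTokenizer("tokenizer.json")
```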