refactor: restructure parts of the tokenizer, update special tokens, etc.

ViperEkura 2026-04-03 14:52:35 +08:00
parent 94c6a015c8
commit c5560740b6
3 changed files with 158 additions and 78 deletions

@@ -61,7 +61,7 @@ flowchart LR
 #### 1.1 Tokenizer (`tokenizer.py`)
 - Implemented based on Byte-Level BPE (BBPE)
-- Supports special tokens: `<bos>`, `<eos>`, `<pad>`, `<|im_start|>`, `<|im_end|>`
+- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
 - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
 - Learns vocabulary from corpus during training, saved as `.json` files
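The bullets above correspond to the `BpeTokenizer` API touched by this commit. As a quick orientation, here is a minimal usage sketch; the import path and the `tokenizer.json` filename are assumptions for illustration, not part of the diff:

```python
from tokenizer import BpeTokenizer  # hypothetical module path

# Load a vocabulary previously saved as a .json file.
tok = BpeTokenizer(path="tokenizer.json")

ids = tok.encode("hello world", out_ids=True)       # token ids: List[int]
pieces = tok.encode("hello world", out_ids=False)   # token strings: List[str]
text = tok.decode(ids, skip_special_tokens=True)    # round-trip back to text

print(tok.bos_id, tok.eos_id, tok.pad_id)  # ids of the three control tokens
print(tok.stop_ids)                        # ids of all control + special tokens
```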

@@ -1,86 +1,110 @@
-from tokenizers import Tokenizer, Encoding
-from tokenizers import decoders, processors, normalizers, pre_tokenizers
+from abc import ABC, abstractmethod
+from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers
 from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
+from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
 from typing import List, Union
 
-class BpeTokenizer:
-    def __init__(self, path=None):
-        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
-        self._special_tokens = ["<|im_start|>", "<|im_end|>"]
+class BaseTokenizer(ABC):
+    @abstractmethod
+    def _init_tokenizer(self):
+        pass
+
+    @abstractmethod
+    def save(self, path):
+        pass
+
+    @abstractmethod
+    def load(self, path):
+        pass
+
+    @abstractmethod
+    def encode(
+        self,
+        tokens: Union[str, List[str]],
+        out_ids: bool = True,
+        add_special_tokens: bool = False,
+    ) -> List:
+        pass
+
+    @abstractmethod
+    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
+        pass
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def stop_ids(self) -> List[int]:
+        pass
+
+    @property
+    @abstractmethod
+    def bos_id(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def eos_id(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def pad_id(self) -> int:
+        pass
+
+
+class BaseTrainer(ABC):
+    def __init__(self, tokenizer: BaseTokenizer):
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def train(self, files, vocab_size, min_freq, **kwargs):
+        pass
+
+    @abstractmethod
+    def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
+        pass
+
+
+class BpeTokenizer(BaseTokenizer):
+    def __init__(
+        self,
+        control_tokens: List[str] = None,
+        special_tokens: List[str] = None,
+        path=None,
+    ):
+        self._control_tokens = control_tokens or [
+            "<begin▁of▁sentence>",
+            "<end▁of▁sentence>",
+            "<▁pad▁>",
+        ]
+        self._special_tokens = special_tokens or [
+            "<im▁start>",
+            "<im▁end>",
+        ]
+        self._tokenizer = None
+        self._init_tokenizer()
+        if path is not None:
+            self.load(path)
+
+    def _init_tokenizer(self):
         model = BPE()
         self._tokenizer = Tokenizer(model)
         self._tokenizer.normalizer = normalizers.Sequence(
             [normalizers.NFC(), normalizers.Strip()]
         )
         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.UnicodeScripts(),
                 pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
             ]
         )
         self._tokenizer.decoder = decoders.ByteLevel()
         self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
-        if path is not None:
-            self._tokenizer = Tokenizer.from_file(path)
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length=18,
-    ) -> tuple:
-        assert reserved_token_size > len(self._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self._control_tokens)
-        assert detail_vocab_size > min_size
-        trainer = BpeTrainer(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, detail_vocab_size, reserved_tokens
-
-    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train(files=files, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
-
-    def train_from_iterator(
-        self, iterator, vocab_size, min_freq, reserved_token_size=100
-    ):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
 
     def save(self, path):
         self._tokenizer.save(path)
@@ -94,12 +118,12 @@ class BpeTokenizer:
         add_special_tokens: bool = False,
     ) -> List:
         if isinstance(tokens, str):
-            encoded: Encoding = self._tokenizer.encode(
+            encoded = self._tokenizer.encode(
                 tokens, add_special_tokens=add_special_tokens
             )
             return encoded.ids if out_ids else encoded.tokens
-        elif isinstance(tokens, list):
-            encoded_list: List[Encoding] = self._tokenizer.encode_batch(
+        else:
+            encoded_list = self._tokenizer.encode_batch(
                 tokens, add_special_tokens=add_special_tokens
             )
             return [
@@ -115,17 +139,73 @@ class BpeTokenizer:
     @property
     def stop_ids(self) -> List[int]:
         stop_token = self._control_tokens + self._special_tokens
-        stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
-        return stop_ids
+        return [self._tokenizer.token_to_id(tok) for tok in stop_token]
 
     @property
     def bos_id(self) -> int:
-        return self._tokenizer.token_to_id("<bos>")
+        return self._tokenizer.token_to_id(self._control_tokens[0])
 
     @property
     def eos_id(self) -> int:
-        return self._tokenizer.token_to_id("<eos>")
+        return self._tokenizer.token_to_id(self._control_tokens[1])
 
     @property
     def pad_id(self) -> int:
-        return self._tokenizer.token_to_id("<pad>")
+        return self._tokenizer.token_to_id(self._control_tokens[2])
+
+
+class BpeTrainer(BaseTrainer):
+    def __init__(self, tokenizer: BaseTokenizer):
+        super().__init__(tokenizer)
+
+    def _prepare_trainer(
+        self,
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int,
+        max_token_length=18,
+    ):
+        assert reserved_token_size > len(self.tokenizer._special_tokens)
+        reserved_tokens = [
+            f"<|reserve{i:02d}|>"
+            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
+        ]
+        detail_vocab_size = vocab_size - (
+            len(reserved_tokens) + len(self.tokenizer._special_tokens)
+        )
+        alphabet = pre_tokenizers.ByteLevel.alphabet()
+        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
+        assert detail_vocab_size > min_size
+        trainer = BpeTrainerImpl(
+            vocab_size=detail_vocab_size,
+            min_frequency=min_freq,
+            limit_alphabet=detail_vocab_size // 6,
+            max_token_length=max_token_length,
+            special_tokens=self.tokenizer._control_tokens,
+            initial_alphabet=alphabet,
+            show_progress=True,
+        )
+        return trainer, reserved_tokens
+
+    def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
+
+    def train_from_iterator(
+        self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
+    ):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train_from_iterator(
+            iterator=iterator, trainer=trainer
+        )
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
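With this hunk, training moves off the tokenizer and into the new `BpeTrainer` wrapper. A sketch of the intended flow, where the corpus filename and the sizes are illustrative assumptions:

```python
tok = BpeTokenizer()      # defaults to the new <begin▁of▁sentence>-style tokens
trainer = BpeTrainer(tok)

# Vocabulary budget for vocab_size=32000, reserved_token_size=100,
# and the 2 default special tokens:
#   reserved placeholders: 100 - 2 = 98   (<|reserve00|> ... <|reserve97|>)
#   BPE-learned entries:   32000 - (98 + 2) = 31900  (passed to BpeTrainerImpl)
trainer.train(files=["corpus.txt"], vocab_size=32000, min_freq=2)

tok.save("tokenizer.json")
```

Note the two-stage registration: the control tokens are given to `BpeTrainerImpl` up front via `special_tokens`, while the chat and `<|reserveNN|>` placeholder tokens are appended after training through `add_special_tokens`.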

@@ -28,16 +28,16 @@ def build_prompt(
     result = ""
     if system_prompt:
-        result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+        result += f"<im▁start>system\n{system_prompt}<im▁end>\n"
     # (convert tuple format to ChatML)
     if history:
         for user_msg, assistant_msg in history:
-            result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
-            result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+            result += f"<im▁start>user\n{user_msg}<im▁end>\n"
+            result += f"<im▁start>assistant\n{assistant_msg}<im▁end>\n"
-    result += f"<|im_start|>user\n{query}<|im_end|>\n"
-    result += "<|im_start|>assistant\n"
+    result += f"<im▁start>user\n{query}<im▁end>\n"
+    result += "<im▁start>assistant\n"
     return result
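For reference, a call such as `build_prompt(query="How are you?", history=[("Hi", "Hello!")], system_prompt="You are helpful.")` (keyword names taken from the hunk body; the full signature lies outside this hunk) now renders:

```
<im▁start>system
You are helpful.<im▁end>
<im▁start>user
Hi<im▁end>
<im▁start>assistant
Hello!<im▁end>
<im▁start>user
How are you?<im▁end>
<im▁start>assistant
```

The prompt deliberately ends with an open `<im▁start>assistant\n`, so generation continues the assistant turn until an `<im▁end>` (one of the tokenizer's `stop_ids`) is produced.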