refactor: restructure parts of the tokenizer and update special tokens

This commit is contained in:
ViperEkura 2026-04-03 14:52:35 +08:00
parent 94c6a015c8
commit c5560740b6
3 changed files with 158 additions and 78 deletions

View File

@@ -61,7 +61,7 @@ flowchart LR
 #### 1.1 Tokenizer (`tokenizer.py`)
 - Implemented based on Byte-Level BPE (BBPE)
-- Supports special tokens: `<bos>`, `<eos>`, `<pad>`, `<|im_start|>`, `<|im_end|>`
+- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
 - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
 - Learns vocabulary from corpus during training, saved as `.json` files
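For orientation, a minimal usage sketch of the API described above, assuming a vocabulary has already been trained and saved (the file name and import path are assumptions, not part of this commit):

```python
# Hypothetical file name and import path; the BpeTokenizer API follows the diff below.
from tokenizer import BpeTokenizer

tok = BpeTokenizer(path="tokenizer.json")                   # load a previously saved vocabulary
ids = tok.encode("hello world", add_special_tokens=False)   # text -> token IDs
text = tok.decode(ids)                                      # token IDs -> text
print(len(tok), tok.bos_id, tok.eos_id, tok.pad_id)
```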

View File

@@ -1,86 +1,110 @@
-from tokenizers import Tokenizer, Encoding
-from tokenizers import decoders, processors, normalizers, pre_tokenizers
+from abc import ABC, abstractmethod
+from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers
 from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
+from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
 from typing import List, Union
-class BpeTokenizer:
-    def __init__(self, path=None):
-        self._control_tokens = ["<bos>", "<eos>", "<pad>"]
-        self._special_tokens = ["<|im_start|>", "<|im_end|>"]
+class BaseTokenizer(ABC):
+    @abstractmethod
+    def _init_tokenizer(self):
+        pass
+    @abstractmethod
+    def save(self, path):
+        pass
+    @abstractmethod
+    def load(self, path):
+        pass
+    @abstractmethod
+    def encode(
+        self,
+        tokens: Union[str, List[str]],
+        out_ids: bool = True,
+        add_special_tokens: bool = False,
+    ) -> List:
+        pass
+    @abstractmethod
+    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
+        pass
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+    @property
+    @abstractmethod
+    def stop_ids(self) -> List[int]:
+        pass
+    @property
+    @abstractmethod
+    def bos_id(self) -> int:
+        pass
+    @property
+    @abstractmethod
+    def eos_id(self) -> int:
+        pass
+    @property
+    @abstractmethod
+    def pad_id(self) -> int:
+        pass
+class BaseTrainer(ABC):
+    def __init__(self, tokenizer: BaseTokenizer):
+        self.tokenizer = tokenizer
+    @abstractmethod
+    def train(self, files, vocab_size, min_freq, **kwargs):
+        pass
+    @abstractmethod
+    def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
+        pass
+class BpeTokenizer(BaseTokenizer):
+    def __init__(
+        self,
+        control_tokens: List[str] = None,
+        special_tokens: List[str] = None,
+        path=None,
+    ):
+        self._control_tokens = control_tokens or [
+            "<begin▁of▁sentence>",
+            "<end▁of▁sentence>",
+            "<▁pad▁>",
+        ]
+        self._special_tokens = special_tokens or [
+            "<im▁start>",
+            "<im▁end>",
+        ]
+        self._tokenizer = None
+        self._init_tokenizer()
+        if path is not None:
+            self.load(path)
+    def _init_tokenizer(self):
         model = BPE()
         self._tokenizer = Tokenizer(model)
         self._tokenizer.normalizer = normalizers.Sequence(
             [normalizers.NFC(), normalizers.Strip()]
         )
         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.UnicodeScripts(),
                 pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
             ]
         )
         self._tokenizer.decoder = decoders.ByteLevel()
         self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
-        if path is not None:
-            self._tokenizer = Tokenizer.from_file(path)
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length=18,
-    ) -> tuple:
-        assert reserved_token_size > len(self._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self._special_tokens)
-        )
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self._control_tokens)
-        assert detail_vocab_size > min_size
-        trainer = BpeTrainer(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-        return trainer, detail_vocab_size, reserved_tokens
-    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train(files=files, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
-    def train_from_iterator(
-        self, iterator, vocab_size, min_freq, reserved_token_size=100
-    ):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
     def save(self, path):
         self._tokenizer.save(path)
@@ -94,12 +118,12 @@ class BpeTokenizer:
         add_special_tokens: bool = False,
     ) -> List:
         if isinstance(tokens, str):
-            encoded: Encoding = self._tokenizer.encode(
+            encoded = self._tokenizer.encode(
                 tokens, add_special_tokens=add_special_tokens
             )
             return encoded.ids if out_ids else encoded.tokens
-        elif isinstance(tokens, list):
-            encoded_list: List[Encoding] = self._tokenizer.encode_batch(
+        else:
+            encoded_list = self._tokenizer.encode_batch(
                 tokens, add_special_tokens=add_special_tokens
             )
             return [
@@ -115,17 +139,73 @@ class BpeTokenizer:
     @property
     def stop_ids(self) -> List[int]:
         stop_token = self._control_tokens + self._special_tokens
-        stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
-        return stop_ids
+        return [self._tokenizer.token_to_id(tok) for tok in stop_token]
     @property
     def bos_id(self) -> int:
-        return self._tokenizer.token_to_id("<bos>")
+        return self._tokenizer.token_to_id(self._control_tokens[0])
     @property
     def eos_id(self) -> int:
-        return self._tokenizer.token_to_id("<eos>")
+        return self._tokenizer.token_to_id(self._control_tokens[1])
     @property
     def pad_id(self) -> int:
-        return self._tokenizer.token_to_id("<pad>")
+        return self._tokenizer.token_to_id(self._control_tokens[2])
+class BpeTrainer(BaseTrainer):
+    def __init__(self, tokenizer: BaseTokenizer):
+        super().__init__(tokenizer)
+    def _prepare_trainer(
+        self,
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int,
+        max_token_length=18,
+    ):
+        assert reserved_token_size > len(self.tokenizer._special_tokens)
+        reserved_tokens = [
+            f"<|reserve{i:02d}|>"
+            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
+        ]
+        detail_vocab_size = vocab_size - (
+            len(reserved_tokens) + len(self.tokenizer._special_tokens)
+        )
+        alphabet = pre_tokenizers.ByteLevel.alphabet()
+        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
+        assert detail_vocab_size > min_size
+        trainer = BpeTrainerImpl(
+            vocab_size=detail_vocab_size,
+            min_frequency=min_freq,
+            limit_alphabet=detail_vocab_size // 6,
+            max_token_length=max_token_length,
+            special_tokens=self.tokenizer._control_tokens,
+            initial_alphabet=alphabet,
+            show_progress=True,
+        )
+        return trainer, reserved_tokens
+    def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
+    def train_from_iterator(
+        self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
+    ):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train_from_iterator(
+            iterator=iterator, trainer=trainer
+        )
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
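After this refactor, training is driven by `BpeTrainer` while `BpeTokenizer` only handles encoding, decoding, and (de)serialization. A minimal sketch of the intended workflow, assuming the interfaces shown in this diff (the corpus and output file names are invented):

```python
# File names are invented; the calls mirror the BpeTokenizer/BpeTrainer split above.
tok = BpeTokenizer()                    # uses the default control/special tokens
trainer = BpeTrainer(tok)
trainer.train(files=["corpus.txt"], vocab_size=32000, min_freq=2)
tok.save("tokenizer.json")
```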

View File

@@ -28,16 +28,16 @@ def build_prompt(
     result = ""
     if system_prompt:
-        result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+        result += f"<im▁start>system\n{system_prompt}<im▁end>\n"
     # (convert tuple format to ChatML)
     if history:
         for user_msg, assistant_msg in history:
-            result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
-            result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
-    result += f"<|im_start|>user\n{query}<|im_end|>\n"
-    result += "<|im_start|>assistant\n"
+            result += f"<im▁start>user\n{user_msg}<im▁end>\n"
+            result += f"<im▁start>assistant\n{assistant_msg}<im▁end>\n"
+    result += f"<im▁start>user\n{query}<im▁end>\n"
+    result += "<im▁start>assistant\n"
     return result
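For reference, a prompt assembled with the updated delimiters takes roughly the shape below. The example messages and keyword-argument usage are assumptions; only the `<im▁start>`/`<im▁end>` layout comes from the code above.

```python
prompt = build_prompt(
    query="What is BPE?",
    history=[("Hi", "Hello!")],
    system_prompt="You are a helpful assistant.",
)
# prompt is expected to look like:
# <im▁start>system
# You are a helpful assistant.<im▁end>
# <im▁start>user
# Hi<im▁end>
# <im▁start>assistant
# Hello!<im▁end>
# <im▁start>user
# What is BPE?<im▁end>
# <im▁start>assistant
```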