"""
Tokenizer module with implementation and auto-loading support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union

from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE

from astrai.tokenize.chat_template import ChatTemplate

||
class AutoTokenizer:
    """Base tokenizer class with automatic loading support.

    Wraps a ``tokenizers.Tokenizer`` and adds:

    - a special-token map exposed through dynamic attribute access
      (see ``__getattr__``),
    - optional chat-template rendering,
    - save/load of a ``tokenizer_config.json`` sidecar next to
      ``tokenizer.json``.
    """

    # Registry of named tokenizer classes, populated via register_tokenizer().
    TOKENIZER_CLASSES = {}

    def __init__(
        self,
        path: Optional[Union[str, Path]] = None,
        special_token_map: Optional[Dict[str, str]] = None,
        chat_template: Optional[str] = None,
    ):
        """
        Args:
            path: Optional directory to load the tokenizer from (see ``load``).
            special_token_map: Mapping of logical names (e.g. ``"bos"``) to
                literal token strings.
            chat_template: Optional Jinja2 template string for chat rendering.
        """
        # Assign the core attributes before anything else so __getattr__
        # is never consulted for them mid-construction (see the guard there).
        self._tokenizer: Optional["Tokenizer"] = None
        self._chat_template: Optional["ChatTemplate"] = None
        self._special_token_map: Dict[str, str] = special_token_map or {}

        if chat_template:
            self.set_chat_template(chat_template)

        if path:
            self.load(path)

    def load(self, path: Union[str, Path]):
        """Load tokenizer from directory.

        Reads ``tokenizer.json`` (required) and, if present,
        ``tokenizer_config.json`` to merge special tokens and the chat
        template into this instance.

        Args:
            path: Directory containing the tokenizer files.
        """
        path = Path(path)
        tokenizer_file = path / "tokenizer.json"
        config_file = path / "tokenizer_config.json"
        self._tokenizer = Tokenizer.from_file(str(tokenizer_file))

        if config_file.exists():
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            # Merge (not replace) so constructor-supplied tokens survive,
            # with config entries taking precedence on key collisions.
            if "special_tokens" in config:
                self._special_token_map.update(config["special_tokens"])

            # Load chat template from config
            if "chat_template" in config:
                self.set_chat_template(config["chat_template"])

    @classmethod
    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
        """Load tokenizer from pretrained directory.

        Args:
            path: Directory containing the tokenizer files.
            **kwargs: Forwarded to the constructor (e.g. ``chat_template``).
                Bug fix: these were previously accepted but silently dropped.
        """
        return cls(path, **kwargs)

    def save_pretrained(self, save_path: str):
        """
        Save tokenizer to pretrained directory.

        Writes ``tokenizer.json`` plus a ``tokenizer_config.json`` sidecar
        holding the special-token map and chat template (when set).

        Args:
            save_path: Path to save the tokenizer
        """

        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save tokenizer
        self._tokenizer.save(str(save_path / "tokenizer.json"))

        # Save tokenizer config
        config = {}
        if self._special_token_map is not None:
            config["special_tokens"] = self._special_token_map
        if self._chat_template is not None:
            config["chat_template"] = self._chat_template.template_str

        with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

    @classmethod
    def register_tokenizer(cls, name: str, tokenizer_class: type):
        """
        Register a new tokenizer class.

        Args:
            name: Name to register the tokenizer class under
            tokenizer_class: The tokenizer class to register
        """
        cls.TOKENIZER_CLASSES[name] = tokenizer_class

    def encode(
        self,
        tokens: Union[str, List[str]],
        out_ids: bool = True,
        is_pretokenized: bool = False,
        add_special_tokens: bool = True,
    ) -> List:
        """Encode text to tokens or token IDs.

        Args:
            tokens: A single string, or a list of strings (batch encode).
            out_ids: Return integer IDs when True, token strings otherwise.
            is_pretokenized: Whether the input is already pre-tokenized.
            add_special_tokens: Whether the post-processor adds special tokens.

        Returns:
            A flat list for a single string, or a list of lists for a batch.

        Raises:
            RuntimeError: If no tokenizer has been loaded or created.
        """
        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer not initialized. Load or create a tokenizer first."
            )

        if isinstance(tokens, str):
            encoded = self._tokenizer.encode(
                tokens,
                is_pretokenized=is_pretokenized,
                add_special_tokens=add_special_tokens,
            )
            return encoded.ids if out_ids else encoded.tokens
        else:
            encoded_list = self._tokenizer.encode_batch(
                tokens,
                is_pretokenized=is_pretokenized,
                add_special_tokens=add_special_tokens,
            )
            return [
                encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
            ]

    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text.

        Raises:
            RuntimeError: If no tokenizer has been loaded or created.
        """
        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer not initialized. Load or create a tokenizer first."
            )

        return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def __len__(self) -> int:
        # Vocabulary size; 0 when no tokenizer is loaded yet.
        if self._tokenizer is None:
            return 0
        return self._tokenizer.get_vocab_size()

    def __getattr__(self, key: str):
        """
        Dynamically intercept special token attribute access.

        Supports three forms (``<key>`` is any key of the special-token map):
        - tokenizer.<key> → returns the token string
        - tokenizer.<key>_id → returns the corresponding integer ID
        - tokenizer.stop_ids → returns integer IDs of all special tokens
        """
        # Guard: if a core attribute is itself missing (e.g. during
        # unpickling, before __init__ runs), looking it up below would
        # re-enter __getattr__ forever. Fail fast instead.
        if key in ("_tokenizer", "_special_token_map", "_chat_template"):
            raise AttributeError(key)

        # Handle stop_ids - return IDs for all special tokens
        if key == "stop_ids":
            stop_ids = []

            if self._tokenizer is None:
                return stop_ids

            # Tokens absent from the vocabulary are silently skipped.
            for val in self._special_token_map.values():
                token_id = self._tokenizer.token_to_id(val)
                if token_id is not None:
                    stop_ids.append(token_id)

            return stop_ids

        # Handle _id suffix (e.g., bos_id -> bos)
        if key.endswith("_id"):
            base_attr = key[:-3]  # Remove "_id"
            token_str = self._special_token_map.get(base_attr)
            if token_str is None:
                return None
            if self._tokenizer is None:
                raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
            return self._tokenizer.token_to_id(token_str)

        # Handle regular string attributes
        if key in self._special_token_map:
            return self._special_token_map.get(key)

        # Other attributes trigger default AttributeError
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

    @property
    def vocab_size(self) -> int:
        # Alias for len(self); 0 when no tokenizer is loaded.
        return len(self)

    def set_chat_template(self, template: Union[str, "ChatTemplate"]):
        """
        Set the chat template for the tokenizer.

        Args:
            template: Either a Jinja2 template string (passed to
                ``ChatTemplate.from_string``) or a ChatTemplate instance.

        Raises:
            ValueError: If template is neither a str nor a ChatTemplate.
        """
        if isinstance(template, str):
            self._chat_template = ChatTemplate.from_string(template)
        elif isinstance(template, ChatTemplate):
            self._chat_template = template
        else:
            raise ValueError("Invalid template type, must be str or ChatTemplate.")

    def apply_chat_template(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        tokenize: bool = True,
        add_generation_prompt: bool = True,
        **kwargs,
    ) -> Union[str, List[int]]:
        """
        Apply the chat template to messages and optionally tokenize the result.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            system_prompt: Optional system prompt string.
            tokenize: Whether to return token IDs (True) or raw string (False).
            add_generation_prompt: Whether to add the generation prompt
                (default: True).
            **kwargs: Additional variables to pass to the template.

        Returns:
            Either the rendered string or list of token IDs.

        Raises:
            RuntimeError: If chat template is not set.
        """
        if self._chat_template is None:
            raise RuntimeError(
                "Chat template not set. Use set_chat_template() to set a template first."
            )

        # Render the template
        rendered = self._chat_template.render(
            messages=messages,
            system_prompt=system_prompt,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )

        if tokenize:
            return self.encode(rendered)

        return rendered
class BpeTokenizer(AutoTokenizer):
    """BPE tokenizer implementation.

    Builds a byte-level BPE ``tokenizers.Tokenizer`` with NFC normalization
    when no pretrained files are loaded from ``path``.
    """

    def __init__(
        self,
        special_token_map: Dict[str, str] = None,
        path: Optional[str] = None,
        chat_template: Optional[str] = None,
    ):
        """
        Args:
            special_token_map: Logical-name → token-string mapping; falls back
                to the default set below when None/empty.
            path: Optional directory to load a pretrained tokenizer from.
            chat_template: Optional Jinja2 chat template string.
        """
        special_token_map = special_token_map or {
            "bos": "<|begin▁of▁sentence|>",
            "eos": "<|end▁of▁sentence|>",
            "pad": "<|▁pad▁|>",
            "im_start": "<|im▁start|>",
            "im_end": "<|im▁end|>",
        }
        # Bug fix: run the base initializer FIRST. It resets self._tokenizer
        # to None (and loads from `path` when given); the old code built the
        # fresh BPE tokenizer before calling it, so the base initializer
        # silently discarded it and a pathless BpeTokenizer() ended up with
        # no tokenizer at all.
        super().__init__(
            path, special_token_map=special_token_map, chat_template=chat_template
        )
        # Only build a default tokenizer when nothing was loaded from `path`.
        if self._tokenizer is None:
            self._init_tokenizer()

    def _init_tokenizer(self):
        """Initialize a new BPE tokenizer with default settings."""
        model = BPE()
        self._tokenizer = Tokenizer(model)
        # NFC keeps Unicode forms canonical before byte-level splitting.
        self._tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Strip()]
        )
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.UnicodeScripts(),
                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
            ]
        )
        self._tokenizer.decoder = decoders.ByteLevel()
        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)