# AstrAI/astrai/tokenize/tokenizer.py
# NOTE: the special-token strings below intentionally contain U+2581 ("▁",
# LOWER ONE EIGHTH BLOCK), a convention borrowed from SentencePiece-style
# vocabularies — they are not mis-encoded underscores.

"""
Tokenizer module with implementation and auto-loading support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE
from astrai.tokenize.chat_template import ChatTemplate
class TextTokenizer:
"""Base tokenizer class with automatic loading support"""
TOKENIZER_CLASSES = {} # Registry for auto-loading
def __init__(
self,
path: Optional[Union[str, Path]] = None,
special_token_map: Optional[Dict[str, str]] = None,
chat_template: Optional[str] = None,
):
self._tokenizer: Tokenizer = None
self._chat_template: Optional[ChatTemplate] = None
self._special_token_map: Optional[Dict] = special_token_map or {}
if chat_template:
self.set_chat_template(chat_template)
if path:
self.load(path)
def load(self, path: Union[str, Path]):
"""Load tokenizer from directory."""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
config_file = path / "tokenizer_config.json"
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
if config_file.exists():
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
if "special_tokens" in config:
self._special_token_map.update(config["special_tokens"])
# Load chat template from config
if "chat_template" in config:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
"""Load tokenizer from pretrained directory."""
instance = cls(path)
return instance
def save_pretrained(self, tokenizer, save_path: str):
"""
Save tokenizer to pretrained directory.
Args:
tokenizer: Tokenizer instance to save
save_path: Path to save the tokenizer
"""
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
self._tokenizer.save(tokenizer, save_path)
@classmethod
def register_tokenizer(cls, name: str, tokenizer_class: type):
"""
Register a new tokenizer class.
Args:
name: Name to register the tokenizer class under
tokenizer_class: The tokenizer class to register
"""
cls.TOKENIZER_CLASSES[name] = tokenizer_class
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
is_pretokenized: bool = False,
add_special_tokens: bool = True,
) -> List:
"""Encode text to tokens or token IDs."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
if isinstance(tokens, str):
encoded = self._tokenizer.encode(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return encoded.ids if out_ids else encoded.tokens
else:
encoded_list = self._tokenizer.encode_batch(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
"""Decode token IDs to text."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
if self._tokenizer is None:
return 0
return self._tokenizer.get_vocab_size()
def __getattr__(self, key: str):
"""
Dynamically intercept special token attribute access.
Supports three forms:
- tokenizer.bos_token → returns string
- tokenizer.bos_token_id → returns corresponding integer ID
- tokenizer.stop_ids → returns list of corresponding integer IDs for all special tokens
"""
# Handle stop_ids - return IDs for all special tokens
if key == "stop_ids":
stop_ids = []
if self._tokenizer is None:
return stop_ids
for val in self._special_token_map.values():
token_id = self._tokenizer.token_to_id(val)
if token_id is not None:
stop_ids.append(token_id)
return stop_ids
# Handle _id suffix (e.g., bos_token_id -> bos_token)
if key.endswith("_id"):
base_attr = key[:-3] # Remove "_id"
token_str = self._special_token_map.get(base_attr)
if token_str is None:
return None
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
return self._tokenizer.token_to_id(token_str)
# Handle regular string attributes
if key in self._special_token_map:
return self._special_token_map.get(key)
# Other attributes trigger default AttributeError
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
@property
def vocab_size(self) -> int:
return len(self)
@property
def pad_id(self) -> Optional[int]:
"""Return the pad token ID if available."""
pad_token = self._special_token_map.get("pad")
if pad_token is None or self._tokenizer is None:
return None
return self._tokenizer.token_to_id(pad_token)
def set_chat_template(self, template: Union[str, ChatTemplate]):
"""
Set the chat template for the tokenizer.
Args:
template: Either a template name (str) registered in the global registry,
or a ChatTemplate instance, or a Jinja2 template string.
Raises:
KeyError: If template name is not registered.
"""
if isinstance(template, str):
self._chat_template = ChatTemplate.from_string(template)
elif isinstance(template, ChatTemplate):
self._chat_template = template
else:
raise ValueError("Invalid template type, must be str or ChatTemplate.")
def apply_chat_template(
self,
messages: List[Dict[str, str]],
system_prompt: Optional[str] = None,
tokenize: bool = True,
add_generation_prompt: bool = True,
**kwargs,
) -> Union[str, List[int]]:
"""
Apply the chat template to messages and optionally tokenize the result.
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: False).
**kwargs: Additional variables to pass to the template.
Returns:
Either the rendered string or list of token IDs.
Raises:
RuntimeError: If chat template is not set.
"""
if self._chat_template is None:
raise RuntimeError(
"Chat template not set. Use set_chat_template() to set a template first."
)
# Render the template
rendered = self._chat_template.render(
messages=messages,
system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt,
**kwargs,
)
if tokenize:
return self.encode(rendered)
return rendered
class BpeTokenizer(TextTokenizer):
    """BPE tokenizer implementation with byte-level pre-tokenization defaults."""

    def __init__(
        self,
        special_token_map: Dict[str, str] = None,
        path: Optional[str] = None,
        chat_template: Optional[str] = None,
    ):
        """
        Args:
            special_token_map: Logical-name -> token-string mapping; falls back
                to the project's standard special tokens (which intentionally
                contain U+2581 "▁" characters).
            path: Optional pretrained directory to load from.
            chat_template: Optional chat template (name or Jinja2 string).
        """
        special_token_map = special_token_map or {
            "bos": "<begin▁of▁sentence>",
            "eos": "<end▁of▁sentence>",
            "pad": "<▁pad▁>",
            "im_start": "<im▁start>",
            "im_end": "<im▁end>",
        }
        # Bug fix: previously _init_tokenizer() ran *before* super().__init__(),
        # whose first statement resets self._tokenizer to None — so when no
        # `path` was given the freshly built BPE pipeline was silently
        # discarded and the instance ended up with no tokenizer at all.
        # Initialize the parent without a path, then either load pretrained
        # files or build the default pipeline.
        super().__init__(
            special_token_map=special_token_map, chat_template=chat_template
        )
        if path:
            self.load(path)
        else:
            self._init_tokenizer()

    def _init_tokenizer(self):
        """Initialize a new BPE tokenizer with default settings."""
        model = BPE()
        self._tokenizer = Tokenizer(model)
        # NFC-normalize and strip surrounding whitespace before tokenizing.
        self._tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Strip()]
        )
        # Split on Unicode-script boundaries first, then apply byte-level
        # pre-tokenization (no prefix space, regex-based splitting).
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.UnicodeScripts(),
                pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
            ]
        )
        self._tokenizer.decoder = decoders.ByteLevel()
        self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)