diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index f1e6bb7..29ef4ea 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -61,7 +61,7 @@ flowchart LR
 
 #### 1.1 Tokenizer (`tokenizer.py`)
 - Implemented based on Byte-Level BPE (BBPE)
-- Supports special tokens: ``, ``, ``, `<|im_start|>`, `<|im_end|>`
+- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
 - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
 - Learns vocabulary from corpus during training, saved as `.json` files
 
diff --git a/astrai/data/tokenizer.py b/astrai/data/tokenizer.py
index ff850e1..1b1ef61 100644
--- a/astrai/data/tokenizer.py
+++ b/astrai/data/tokenizer.py
@@ -1,86 +1,110 @@
-from tokenizers import Tokenizer, Encoding
-from tokenizers import decoders, processors, normalizers, pre_tokenizers
+from abc import ABC, abstractmethod
+from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers
 from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
+from tokenizers.trainers import BpeTrainer as BpeTrainerImpl
 from typing import List, Union
 
 
-class BpeTokenizer:
-    def __init__(self, path=None):
-        self._control_tokens = ["", "", ""]
-        self._special_tokens = ["<|im_start|>", "<|im_end|>"]
+class BaseTokenizer(ABC):
+    @abstractmethod
+    def _init_tokenizer(self):
+        pass
 
+    @abstractmethod
+    def save(self, path):
+        pass
+
+    @abstractmethod
+    def load(self, path):
+        pass
+
+    @abstractmethod
+    def encode(
+        self,
+        tokens: Union[str, List[str]],
+        out_ids: bool = True,
+        add_special_tokens: bool = False,
+    ) -> List:
+        pass
+
+    @abstractmethod
+    def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
+        pass
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def stop_ids(self) -> List[int]:
+        pass
+
+    @property
+    @abstractmethod
+    def bos_id(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def eos_id(self) -> int:
+        pass
+
+    @property
+    @abstractmethod
+    def pad_id(self) -> int:
+        pass
+
+
+class BaseTrainer(ABC):
+    def __init__(self, tokenizer: BaseTokenizer):
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def train(self, files, vocab_size, min_freq, **kwargs):
+        pass
+
+    @abstractmethod
+    def train_from_iterator(self, iterator, vocab_size, min_freq, **kwargs):
+        pass
+
+
+class BpeTokenizer(BaseTokenizer):
+    def __init__(
+        self,
+        control_tokens: List[str] = None,
+        special_tokens: List[str] = None,
+        path=None,
+    ):
+        self._control_tokens = control_tokens or [
+            "<|begin▁of▁sentence|>",
+            "<|end▁of▁sentence|>",
+            "<|▁pad▁|>",
+        ]
+        self._special_tokens = special_tokens or [
+            "<|im▁start|>",
+            "<|im▁end|>",
+        ]
+        self._tokenizer = None
+        self._init_tokenizer()
+        if path is not None:
+            self.load(path)
+
+    def _init_tokenizer(self):
         model = BPE()
         self._tokenizer = Tokenizer(model)
         self._tokenizer.normalizer = normalizers.Sequence(
             [normalizers.NFC(), normalizers.Strip()]
         )
-
         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.UnicodeScripts(),
                 pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
             ]
         )
-
         self._tokenizer.decoder = decoders.ByteLevel()
         self._tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
 
-        if path is not None:
-            self._tokenizer = Tokenizer.from_file(path)
-
-    def _prepare_trainer(
-        self,
-        vocab_size: int,
-        min_freq: int,
-        reserved_token_size: int,
-        max_token_length=18,
-    ) -> tuple:
-        assert reserved_token_size > len(self._special_tokens)
-        reserved_tokens = [
-            f"<|reserve{i:02d}|>"
-            for i in range(reserved_token_size - len(self._special_tokens))
-        ]
-        detail_vocab_size = vocab_size - (
-            len(reserved_tokens) + len(self._special_tokens)
-        )
-
-        alphabet = pre_tokenizers.ByteLevel.alphabet()
-        min_size = len(alphabet) + len(self._control_tokens)
-        assert detail_vocab_size > min_size
-
-        trainer = BpeTrainer(
-            vocab_size=detail_vocab_size,
-            min_frequency=min_freq,
-            limit_alphabet=detail_vocab_size // 6,
-            max_token_length=max_token_length,
-            special_tokens=self._control_tokens,
-            initial_alphabet=alphabet,
-            show_progress=True,
-        )
-
-        return trainer, detail_vocab_size, reserved_tokens
-
-    def train(self, files, vocab_size, min_freq, reserved_token_size=100):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train(files=files, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
-
-    def train_from_iterator(
-        self, iterator, vocab_size, min_freq, reserved_token_size=100
-    ):
-        trainer, _, reserved_tokens = self._prepare_trainer(
-            vocab_size=vocab_size,
-            min_freq=min_freq,
-            reserved_token_size=reserved_token_size,
-        )
-        self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
-        self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
-
     def save(self, path):
         self._tokenizer.save(path)
 
@@ -94,12 +118,12 @@ class BpeTokenizer:
         add_special_tokens: bool = False,
     ) -> List:
         if isinstance(tokens, str):
-            encoded: Encoding = self._tokenizer.encode(
+            encoded = self._tokenizer.encode(
                 tokens, add_special_tokens=add_special_tokens
             )
             return encoded.ids if out_ids else encoded.tokens
-        elif isinstance(tokens, list):
-            encoded_list: List[Encoding] = self._tokenizer.encode_batch(
+        else:
+            encoded_list = self._tokenizer.encode_batch(
                 tokens, add_special_tokens=add_special_tokens
             )
             return [
@@ -115,17 +139,73 @@ class BpeTokenizer:
     @property
     def stop_ids(self) -> List[int]:
         stop_token = self._control_tokens + self._special_tokens
-        stop_ids = [self._tokenizer.token_to_id(token) for token in stop_token]
-        return stop_ids
+        return [self._tokenizer.token_to_id(tok) for tok in stop_token]
 
     @property
     def bos_id(self) -> int:
-        return self._tokenizer.token_to_id("")
+        return self._tokenizer.token_to_id(self._control_tokens[0])
 
     @property
    def eos_id(self) -> int:
-        return self._tokenizer.token_to_id("")
+        return self._tokenizer.token_to_id(self._control_tokens[1])
 
     @property
     def pad_id(self) -> int:
-        return self._tokenizer.token_to_id("")
+        return self._tokenizer.token_to_id(self._control_tokens[2])
+
+
+class BpeTrainer(BaseTrainer):
+    def __init__(self, tokenizer: BaseTokenizer):
+        super().__init__(tokenizer)
+
+    def _prepare_trainer(
+        self,
+        vocab_size: int,
+        min_freq: int,
+        reserved_token_size: int,
+        max_token_length=18,
+    ):
+        assert reserved_token_size > len(self.tokenizer._special_tokens)
+        reserved_tokens = [
+            f"<|reserve{i:02d}|>"
+            for i in range(reserved_token_size - len(self.tokenizer._special_tokens))
+        ]
+        detail_vocab_size = vocab_size - (
+            len(reserved_tokens) + len(self.tokenizer._special_tokens)
+        )
+        alphabet = pre_tokenizers.ByteLevel.alphabet()
+        min_size = len(alphabet) + len(self.tokenizer._control_tokens)
+        assert detail_vocab_size > min_size
+
+        trainer = BpeTrainerImpl(
+            vocab_size=detail_vocab_size,
+            min_frequency=min_freq,
+            limit_alphabet=detail_vocab_size // 6,
+            max_token_length=max_token_length,
+            special_tokens=self.tokenizer._control_tokens,
+            initial_alphabet=alphabet,
+            show_progress=True,
+        )
+        return trainer, reserved_tokens
+
+    def train(self, files, vocab_size, min_freq, reserved_token_size=100, **kwargs):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train(files=files, trainer=trainer)
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
+
+    def train_from_iterator(
+        self, iterator, vocab_size, min_freq, reserved_token_size=100, **kwargs
+    ):
+        trainer, reserved_tokens = self._prepare_trainer(
+            vocab_size, min_freq, reserved_token_size, **kwargs
+        )
+        self.tokenizer._tokenizer.train_from_iterator(
+            iterator=iterator, trainer=trainer
+        )
+        self.tokenizer._tokenizer.add_special_tokens(
+            self.tokenizer._special_tokens + reserved_tokens
+        )
diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py
index 89d10f3..2974c59 100644
--- a/astrai/inference/generator.py
+++ b/astrai/inference/generator.py
@@ -28,16 +28,16 @@ def build_prompt(
 
     result = ""
     if system_prompt:
-        result += f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+        result += f"<|im▁start|>system\n{system_prompt}<|im▁end|>\n"
 
     # (convert tuple format to ChatML)
     if history:
         for user_msg, assistant_msg in history:
-            result += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
-            result += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+            result += f"<|im▁start|>user\n{user_msg}<|im▁end|>\n"
+            result += f"<|im▁start|>assistant\n{assistant_msg}<|im▁end|>\n"
 
-    result += f"<|im_start|>user\n{query}<|im_end|>\n"
-    result += "<|im_start|>assistant\n"
+    result += f"<|im▁start|>user\n{query}<|im▁end|>\n"
+    result += "<|im▁start|>assistant\n"
 
     return result
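For reference, a minimal usage sketch of the split BpeTokenizer/BpeTrainer API introduced in this patch. It is illustrative only, not part of the change; the corpus path, output path, and vocabulary numbers below are placeholder assumptions.

from astrai.data.tokenizer import BpeTokenizer, BpeTrainer

# Build a tokenizer with the default control/special tokens; training now goes
# through the separate BpeTrainer wrapper ("corpus.txt" is a placeholder path).
tokenizer = BpeTokenizer()
trainer = BpeTrainer(tokenizer)
trainer.train(files=["corpus.txt"], vocab_size=32000, min_freq=2)
tokenizer.save("tokenizer.json")

# Reload the trained vocabulary and round-trip a ChatML-style string.
tok = BpeTokenizer(path="tokenizer.json")
ids = tok.encode("<|im▁start|>user\nhello<|im▁end|>")
print(tok.decode(ids, skip_special_tokens=False))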