diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index 6c449af..35d634e 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -64,10 +64,23 @@ class AutoTokenizer: Args: save_path: Path to save the tokenizer """ + save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) + + # Save tokenizer self._tokenizer.save(str(save_path / "tokenizer.json")) + # Save tokenizer config + config = {} + if self._special_token_map is not None: + config["special_tokens"] = self._special_token_map + if self._chat_template is not None: + config["chat_template"] = self._chat_template.template_str + + with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=2) + @classmethod def register_tokenizer(cls, name: str, tokenizer_class: type): """ @@ -166,14 +179,6 @@ class AutoTokenizer: def vocab_size(self) -> int: return len(self) - @property - def pad_id(self) -> Optional[int]: - """Return the pad token ID if available.""" - pad_token = self._special_token_map.get("pad") - if pad_token is None or self._tokenizer is None: - return None - return self._tokenizer.token_to_id(pad_token) - def set_chat_template(self, template: Union[str, ChatTemplate]): """ Set the chat template for the tokenizer.