From d2b36cc85daf00792aba645c66b9cc29235acbcb Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 5 Apr 2026 20:09:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=89=B9=E6=AE=8Atoke?= =?UTF-8?q?n=20=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/tokenize/tokenizer.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index c6e2fc7..a41b847 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -129,14 +129,21 @@ class TextTokenizer: Supports three forms: - tokenizer.bos_token → returns string - tokenizer.bos_token_id → returns corresponding integer ID - - tokenizer.stop_ids → returns list of corresponding integer IDs + - tokenizer.stop_ids → returns list of corresponding integer IDs for all special tokens """ - # Handle stop_ids + # Handle stop_ids - return IDs for all special tokens if key == "stop_ids": - return [ - self._special_token_map.get(val) - for val in self._special_token_map.values() - ] + stop_ids = [] + + if self._tokenizer is None: + return stop_ids + + for val in self._special_token_map.values(): + token_id = self._tokenizer.token_to_id(val) + if token_id is not None: + stop_ids.append(token_id) + + return stop_ids # Handle _id suffix (e.g., bos_token_id -> bos_token) if key.endswith("_id"): @@ -190,6 +197,7 @@ class TextTokenizer: messages: List[Dict[str, str]], system_prompt: Optional[str] = None, tokenize: bool = True, + add_generation_prompt: bool = True, **kwargs, ) -> Union[str, List[int]]: """ @@ -199,6 +207,7 @@ class TextTokenizer: messages: List of message dicts with 'role' and 'content'. system_prompt: Optional system prompt string. tokenize: Whether to return token IDs (True) or raw string (False). + add_generation_prompt: Whether to add the generation prompt (default: True). 
**kwargs: Additional variables to pass to the template. Returns: @@ -216,6 +225,7 @@ class TextTokenizer: rendered = self._chat_template.render( messages=messages, system_prompt=system_prompt, + add_generation_prompt=add_generation_prompt, **kwargs, )