fix: 修复特殊token 的问题

This commit is contained in:
ViperEkura 2026-04-05 20:09:47 +08:00
parent fc278d17ab
commit d2b36cc85d
1 changed files with 16 additions and 6 deletions

View File

@ -129,14 +129,21 @@ class TextTokenizer:
Supports three forms: Supports three forms:
- tokenizer.bos_token returns string - tokenizer.bos_token returns string
- tokenizer.bos_token_id returns corresponding integer ID - tokenizer.bos_token_id returns corresponding integer ID
- tokenizer.stop_ids returns list of corresponding integer IDs - tokenizer.stop_ids returns list of corresponding integer IDs for all special tokens
""" """
# Handle stop_ids # Handle stop_ids - return IDs for all special tokens
if key == "stop_ids": if key == "stop_ids":
return [ stop_ids = []
self._special_token_map.get(val)
for val in self._special_token_map.values() if self._tokenizer is None:
] return stop_ids
for val in self._special_token_map.values():
token_id = self._tokenizer.token_to_id(val)
if token_id is not None:
stop_ids.append(token_id)
return stop_ids
# Handle _id suffix (e.g., bos_token_id -> bos_token) # Handle _id suffix (e.g., bos_token_id -> bos_token)
if key.endswith("_id"): if key.endswith("_id"):
@ -190,6 +197,7 @@ class TextTokenizer:
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
tokenize: bool = True, tokenize: bool = True,
add_generation_prompt: bool = True,
**kwargs, **kwargs,
) -> Union[str, List[int]]: ) -> Union[str, List[int]]:
""" """
@ -199,6 +207,7 @@ class TextTokenizer:
messages: List of message dicts with 'role' and 'content'. messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string. system_prompt: Optional system prompt string.
tokenize: Whether to return token IDs (True) or raw string (False). tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template. **kwargs: Additional variables to pass to the template.
Returns: Returns:
@ -216,6 +225,7 @@ class TextTokenizer:
rendered = self._chat_template.render( rendered = self._chat_template.render(
messages=messages, messages=messages,
system_prompt=system_prompt, system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt,
**kwargs, **kwargs,
) )