fix: 修复特殊token 的问题

This commit is contained in:
ViperEkura 2026-04-05 20:09:47 +08:00
parent fc278d17ab
commit d2b36cc85d
1 changed files with 16 additions and 6 deletions

View File

@ -129,14 +129,21 @@ class TextTokenizer:
Supports three forms:
- tokenizer.bos_token returns string
- tokenizer.bos_token_id returns corresponding integer ID
- tokenizer.stop_ids returns list of corresponding integer IDs
- tokenizer.stop_ids returns list of corresponding integer IDs for all special tokens
"""
# Handle stop_ids
# Handle stop_ids - return IDs for all special tokens
if key == "stop_ids":
return [
self._special_token_map.get(val)
for val in self._special_token_map.values()
]
stop_ids = []
if self._tokenizer is None:
return stop_ids
for val in self._special_token_map.values():
token_id = self._tokenizer.token_to_id(val)
if token_id is not None:
stop_ids.append(token_id)
return stop_ids
# Handle _id suffix (e.g., bos_token_id -> bos_token)
if key.endswith("_id"):
@ -190,6 +197,7 @@ class TextTokenizer:
messages: List[Dict[str, str]],
system_prompt: Optional[str] = None,
tokenize: bool = True,
add_generation_prompt: bool = True,
**kwargs,
) -> Union[str, List[int]]:
"""
@ -199,6 +207,7 @@ class TextTokenizer:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template.
Returns:
@ -216,6 +225,7 @@ class TextTokenizer:
rendered = self._chat_template.render(
messages=messages,
system_prompt=system_prompt,
add_generation_prompt=add_generation_prompt,
**kwargs,
)