refactor: 实现 chat template 分派设置

This commit is contained in:
ViperEkura 2026-04-04 16:56:31 +08:00
parent 9c31d78a22
commit 2dc9545d7f
2 changed files with 259 additions and 40 deletions

View File

@ -7,7 +7,6 @@ from astrai.tokenizer.tokenizer import (
from astrai.tokenizer.chat_template import ( from astrai.tokenizer.chat_template import (
HistoryType, HistoryType,
MessageType, MessageType,
CHAT_TEMPLATES,
build_prompt, build_prompt,
) )

View File

@ -1,22 +1,137 @@
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Any
from jinja2 import Template from jinja2 import Template
from dataclasses import dataclass
from astrai.factory import Registry
HistoryType = List[Tuple[str, str]] HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str] MessageType = Dict[str, str]
# Predefined chat templates using jinja2
CHAT_TEMPLATES: Dict[str, str] = { @dataclass
"chatml": """{%- if system_prompt -%} class ChatTemplate:
<imstart>system """A chat template with Jinja2 rendering support.
{{ system_prompt }}<imend>
{%- endif -%} Attributes:
{%- for message in messages -%} name: Unique identifier for the template.
<imstart>{{ message['role'] }} template_str: Jinja2 template string.
{{ message['content'] }}<imend> description: Optional description.
{%- endfor -%} default_variables: Optional dictionary of default variable values
<imstart>assistant that will be passed to the template if not overridden during rendering.
""", special_tokens: Optional dictionary mapping token names to their string values.
} These tokens are automatically added to the template variables.
"""
name: str
template_str: str
description: str = ""
default_variables: Dict[str, Any] = None
special_tokens: Dict[str, str] = None
def __post_init__(self):
if self.default_variables is None:
self.default_variables = {}
if self.special_tokens is None:
self.special_tokens = {}
@classmethod
def from_string(
cls,
template_str: str,
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
) -> "ChatTemplate":
"""Create a ChatTemplate instance directly from a template string."""
return cls(
name="", # empty name for adhoc templates
template_str=template_str,
description=description,
default_variables=default_variables,
special_tokens=special_tokens,
)
def render(
self,
messages: List[MessageType],
system_prompt: Optional[str] = None,
**extra_variables: Any,
) -> str:
"""Render the template with given messages and variables.
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
**extra_variables: Additional variables to pass to the template.
These override default_variables and special_tokens.
Returns:
Rendered prompt string.
"""
# Merge default variables, special tokens, and extra variables
variables = {**self.default_variables, **self.special_tokens, **extra_variables}
variables["messages"] = messages
if system_prompt is not None:
variables["system_prompt"] = system_prompt
jinja_template = Template(self.template_str)
return jinja_template.render(**variables)
# Global registry instance
_default_registry = Registry()
# Default template name
_default_template_name = "chatml"
# Convenience functions
def register_chat_template(
name: str,
template_str: str,
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
) -> ChatTemplate:
"""Register a chat template in the global registry."""
template = ChatTemplate(
name=name,
template_str=template_str,
description=description,
default_variables=default_variables,
special_tokens=special_tokens,
)
_default_registry.register(name, template, category=None, priority=0)
return template
def set_default_chat_template(name: str) -> None:
"""Set the default chat template name globally."""
global _default_template_name
if not _default_registry.contains(name):
raise KeyError(
f"Chat template '{name}' not found. Available: {list(_default_registry.list_names())}"
)
_default_template_name = name
def get_default_chat_template_name() -> str:
"""Get the current default chat template name."""
return _default_template_name
def get_chat_template(name: str) -> ChatTemplate:
"""Get a chat template from the global registry."""
return _default_registry.get(name)
def list_chat_templates() -> List[str]:
"""List all registered chat template names."""
return _default_registry.list_names()
def chat_template_exists(name: str) -> bool:
"""Check if a chat template exists."""
return _default_registry.contains(name)
def build_prompt( def build_prompt(
@ -24,29 +139,27 @@ def build_prompt(
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None, history: Optional[HistoryType] = None,
template: Optional[str] = None, template: Optional[str] = None,
template_name: Optional[str] = None,
**extra_variables: Any,
) -> str: ) -> str:
"""Build prompt using jinja2 template for query and history. """Build prompt using a registered chat template or a custom template string.
This function maintains backward compatibility with the previous API.
Args: Args:
query (str): query string. query: The current user query.
system_prompt (Optional[str]): system prompt string. system_prompt: Optional system prompt.
history (Optional[HistoryType]): history list of query and response. history: Optional list of (user_msg, assistant_msg) pairs.
template (Optional[str]): jinja2 template string. If None, uses default chatml template. template: If provided, uses this exact Jinja2 template string (overrides template_name).
template_name: Name of a registered template to use (ignored if `template` is given).
If None, uses the globally set default template (see `set_default_chat_template`).
**extra_variables: Additional variables to pass to the template.
Returns: Returns:
str: prompt string formatted according to the template. Rendered prompt string.
Example: Raises:
# Use default template KeyError: If `template_name` is not registered.
prompt = build_prompt(query="Hello", history=[...])
# Use custom template
custom_template = '''
{%- for msg in messages -%}
{{ msg['role'] }}: {{ msg['content'] }}
{%- endfor -%}
'''
prompt = build_prompt(query="Hello", template=custom_template)
""" """
# Convert history to message format # Convert history to message format
messages: List[MessageType] = [] messages: List[MessageType] = []
@ -56,12 +169,119 @@ def build_prompt(
messages.append({"role": "assistant", "content": assistant_msg}) messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
# Use provided template or default chatml template if template is not None:
template_str = template if template is not None else CHAT_TEMPLATES["chatml"] # Use the provided template string directly
jinja_template = Template(template)
# Render template variables = {"messages": messages, **extra_variables}
jinja_template = Template(template_str) if system_prompt is not None:
return jinja_template.render( variables["system_prompt"] = system_prompt
return jinja_template.render(**variables)
else:
# Determine which template name to use
if template_name is None:
template_name = _default_template_name
# Use a registered template
chat_template = get_chat_template(template_name)
return chat_template.render(
messages=messages, messages=messages,
system_prompt=system_prompt, system_prompt=system_prompt,
**extra_variables,
) )
# Predefined templates
# ChatML template (original)
register_chat_template(
name="chatml",
template_str=(
"{%- if system_prompt -%}\n"
"{{ bos_token }}system\n"
"{{ system_prompt }}{{ eos_token }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ bos_token }}{{ message['role'] }}\n"
"{{ message['content'] }}{{ eos_token }}\n"
"{%- endfor -%}\n"
"{{ bos_token }}assistant\n"
),
description="ChatML format with configurable special tokens.",
special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Simplified template without special tokens (plain text)
register_chat_template(
name="plain",
template_str=(
"{%- if system_prompt -%}\n"
"System: {{ system_prompt }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ message['role']|capitalize }}: {{ message['content'] }}\n"
"{%- endfor -%}\n"
"Assistant:"
),
description="Plain text format with role labels.",
)
# Alpaca-style template
register_chat_template(
name="alpaca",
template_str=(
"{%- if system_prompt -%}\n"
"### Instruction:\n"
"{{ system_prompt }}\n"
"{%- endif -%}\n"
"### Input:\n"
"{{ messages[-1]['content'] }}\n"
"### Response:"
),
description="Alpaca instructionresponse format (singleturn).",
default_variables={},
)
# OpenAI chat format (approximation)
register_chat_template(
name="openai",
template_str=(
"{%- if system_prompt -%}\n"
"{{ bos_token }}system\n"
"{{ system_prompt }}{{ eos_token }}\n"
"{%- endif -%}\n"
"{%- for message in messages -%}\n"
"{{ bos_token }}{{ message['role'] }}\n"
"{{ message['content'] }}{{ eos_token }}\n"
"{%- endfor -%}\n"
"{{ bos_token }}assistant\n"
),
description="OpenAIcompatible chat format with configurable special tokens.",
special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Llama2 style with [INST] tags
register_chat_template(
name="llama2",
template_str=(
"{%- if system_prompt -%}\n"
"<<SYS>>\n"
"{{ system_prompt }}\n"
"<</SYS>>\n"
"{%- endif -%}\n"
"[INST] {{ messages[-1]['content'] }} [/INST]"
),
description="Llama2 style with [INST] tags (singleturn).",
default_variables={},
)
__all__ = [
"ChatTemplate",
"register_chat_template",
"get_chat_template",
"list_chat_templates",
"chat_template_exists",
"build_prompt",
"set_default_chat_template",
"get_default_chat_template_name",
"HistoryType",
"MessageType",
]