AstrAI/astrai/tokenize/chat_template.py

288 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Dict, List, Optional, Tuple, Any
from jinja2 import Template
from dataclasses import dataclass
from astrai.factory import Registry
# A single chat message: a string-to-string mapping. The templates in this
# module read its 'role' and 'content' keys.
MessageType = Dict[str, str]

# Conversation history expressed as (user_msg, assistant_msg) pairs.
HistoryType = List[Tuple[str, str]]
@dataclass
class ChatTemplate:
    """A chat template with Jinja2 rendering support.

    Attributes:
        name: Unique identifier for the template.
        template_str: Jinja2 template string.
        description: Optional description.
        default_variables: Optional dictionary of default variable values
            that will be passed to the template if not overridden during
            rendering. ``None`` is normalized to an empty dict.
        special_tokens: Optional dictionary mapping token names to their
            string values. These tokens are automatically added to the
            template variables. ``None`` is normalized to an empty dict.
    """

    name: str
    template_str: str
    description: str = ""
    default_variables: Optional[Dict[str, Any]] = None
    special_tokens: Optional[Dict[str, str]] = None

    def __post_init__(self):
        """Normalize the optional dict fields after dataclass initialization.

        ``None`` becomes an empty dict. Supplied mappings are shallow-copied
        so that a caller mutating its own dict afterwards cannot silently
        change an already constructed template.
        """
        self.default_variables = (
            {} if self.default_variables is None else dict(self.default_variables)
        )
        self.special_tokens = (
            {} if self.special_tokens is None else dict(self.special_tokens)
        )

    @classmethod
    def from_string(
        cls,
        template_str: str,
        description: str = "",
        default_variables: Optional[Dict[str, Any]] = None,
        special_tokens: Optional[Dict[str, str]] = None,
    ) -> "ChatTemplate":
        """Create a ChatTemplate instance directly from a template string.

        Args:
            template_str: Jinja2 template string.
            description: Optional description.
            default_variables: Optional default template variables.
            special_tokens: Optional token-name to token-string mapping.

        Returns:
            A ChatTemplate whose ``name`` is the empty string.
        """
        return cls(
            name="",  # empty name for ad-hoc templates
            template_str=template_str,
            description=description,
            default_variables=default_variables,
            special_tokens=special_tokens,
        )

    def render(
        self,
        messages: List["MessageType"],
        system_prompt: Optional[str] = None,
        **extra_variables: Any,
    ) -> str:
        """Render the template with given messages and variables.

        Variable precedence, lowest to highest: ``default_variables``,
        ``special_tokens``, ``extra_variables``, then the explicit
        ``messages`` and ``system_prompt`` arguments.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            system_prompt: Optional system prompt string. When ``None``,
                this argument does not set a ``system_prompt`` variable.
            **extra_variables: Additional variables to pass to the template.
                These override default_variables and special_tokens.

        Returns:
            Rendered prompt string.
        """
        # Later mappings win on key collisions, which yields the precedence
        # described in the docstring.
        variables = {
            **self.default_variables,
            **self.special_tokens,
            **extra_variables,
        }
        variables["messages"] = messages
        if system_prompt is not None:
            variables["system_prompt"] = system_prompt
        # The template is compiled from template_str on every call.
        jinja_template = Template(self.template_str)
        return jinja_template.render(**variables)
# Name of the template that build_prompt() falls back to when the caller
# gives neither a template string nor a template name.
_default_template_name = "chatml"

# Module-wide registry holding every registered ChatTemplate.
_default_registry = Registry()
# Convenience functions
def register_chat_template(
    name: str,
    template_str: str,
    description: str = "",
    default_variables: Optional[Dict[str, Any]] = None,
    special_tokens: Optional[Dict[str, str]] = None,
) -> ChatTemplate:
    """Build a ChatTemplate and store it in the global registry under `name`.

    Args:
        name: Key under which the template is registered.
        template_str: Jinja2 template string.
        description: Optional description.
        default_variables: Optional default template variables.
        special_tokens: Optional token-name to token-string mapping.

    Returns:
        The newly created ChatTemplate.
    """
    fields = {
        "name": name,
        "template_str": template_str,
        "description": description,
        "default_variables": default_variables,
        "special_tokens": special_tokens,
    }
    chat_template = ChatTemplate(**fields)
    _default_registry.register(name, chat_template, category=None, priority=0)
    return chat_template
def set_default_chat_template(name: str) -> None:
    """Make `name` the globally used default chat template.

    Args:
        name: Name of an already registered template.

    Raises:
        KeyError: If no template is registered under `name`.
    """
    global _default_template_name
    if _default_registry.contains(name):
        _default_template_name = name
        return
    available = list(_default_registry.list_names())
    raise KeyError(f"Chat template '{name}' not found. Available: {available}")
def get_default_chat_template_name() -> str:
    """Return the name currently configured as the default chat template."""
    current_default = _default_template_name
    return current_default
def get_chat_template(name: str) -> ChatTemplate:
    """Look up `name` in the global registry and return the stored template.

    Lookup failures are reported by the registry's ``get`` itself; this
    function adds no handling of its own.
    """
    chat_template = _default_registry.get(name)
    return chat_template
def list_chat_templates() -> List[str]:
    """Return the names of all templates held by the global registry."""
    registered_names = _default_registry.list_names()
    return registered_names
def chat_template_exists(name: str) -> bool:
    """Report whether the global registry contains a template named `name`."""
    is_registered = _default_registry.contains(name)
    return is_registered
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
    template_name: Optional[str] = None,
    **extra_variables: Any,
) -> str:
    """Render a prompt from a query, optional history and a chat template.

    The template is either a raw Jinja2 string passed via `template` or a
    registered template selected by name. This function maintains backward
    compatibility with the previous API.

    Args:
        query: The current user query; always appended as the last user
            message.
        system_prompt: Optional system prompt.
        history: Optional list of (user_msg, assistant_msg) pairs.
        template: If provided, uses this exact Jinja2 template string
            (overrides template_name).
        template_name: Name of a registered template to use (ignored if
            `template` is given). If None, uses the globally set default
            template (see `set_default_chat_template`).
        **extra_variables: Additional variables to pass to the template.

    Returns:
        Rendered prompt string.

    Raises:
        KeyError: If `template_name` is not registered.
    """
    # Flatten the (user, assistant) pairs into alternating role messages,
    # then finish with the current query as the final user turn.
    messages: List[MessageType] = []
    for user_msg, assistant_msg in history or ():
        messages.extend(
            (
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            )
        )
    messages.append({"role": "user", "content": query})

    if template is None:
        # Registered-template path: fall back to the global default name.
        chosen_name = (
            _default_template_name if template_name is None else template_name
        )
        return get_chat_template(chosen_name).render(
            messages=messages,
            system_prompt=system_prompt,
            **extra_variables,
        )

    # Raw-template path: render the supplied Jinja2 string directly.
    compiled = Template(template)
    context = {"messages": messages, **extra_variables}
    if system_prompt is not None:
        context["system_prompt"] = system_prompt
    return compiled.render(**context)
# Predefined templates

# ChatML template (original).
# Emits an optional system turn, then one turn per message in the form
# "<bos_token><role>\n<content><eos_token>", and finally opens an assistant
# turn with "<bos_token>assistant".
# NOTE(review): every block tag uses Jinja2 "-" whitespace control, which
# strips the newlines adjacent to the tags, so consecutive turns appear to be
# emitted with no newline between eos_token and the next bos_token. Confirm
# this matches the format the model expects.
# NOTE(review): the token strings contain U+2581 (LOWER ONE EIGHTH BLOCK),
# not an ASCII underscore; presumably they must match the tokenizer
# vocabulary exactly. Verify before editing them.
register_chat_template(
    name="chatml",
    template_str=(
        "{%- if system_prompt -%}\n"
        "{{ bos_token }}system\n"
        "{{ system_prompt }}{{ eos_token }}\n"
        "{%- endif -%}\n"
        "{%- for message in messages -%}\n"
        "{{ bos_token }}{{ message['role'] }}\n"
        "{{ message['content'] }}{{ eos_token }}\n"
        "{%- endfor -%}\n"
        "{{ bos_token }}assistant\n"
    ),
    description="ChatML format with configurable special tokens.",
    special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Simplified template without special tokens (plain text).
# Emits an optional "System: ..." line, then "<Role>: <content>" for every
# message (role passed through the Jinja2 `capitalize` filter), and ends with
# "Assistant:".
# NOTE(review): the "-" whitespace control on every block tag strips the
# newlines next to the tags, so the labelled lines appear to be concatenated
# without separators (e.g. "System: ...User: ...Assistant:"). Confirm whether
# that is intended.
register_chat_template(
    name="plain",
    template_str=(
        "{%- if system_prompt -%}\n"
        "System: {{ system_prompt }}\n"
        "{%- endif -%}\n"
        "{%- for message in messages -%}\n"
        "{{ message['role']|capitalize }}: {{ message['content'] }}\n"
        "{%- endfor -%}\n"
        "Assistant:"
    ),
    description="Plain text format with role labels.",
)
# Alpaca-style template.
# Single-turn: only the last message (messages[-1]) is rendered; any earlier
# history is ignored. The system prompt, when given, fills the
# "### Instruction:" section.
# NOTE(review): the "-" whitespace control on the endif tag strips the
# newline between the system prompt and "### Input:". Confirm this is
# intended before changing the template text.
register_chat_template(
    name="alpaca",
    template_str=(
        "{%- if system_prompt -%}\n"
        "### Instruction:\n"
        "{{ system_prompt }}\n"
        "{%- endif -%}\n"
        "### Input:\n"
        "{{ messages[-1]['content'] }}\n"
        "### Response:"
    ),
    # Hyphens restored: the description previously read
    # "instructionresponse" and "singleturn".
    description="Alpaca instruction-response format (single-turn).",
)
# OpenAI chat format (approximation).
# The template text and special tokens are the same as those registered for
# "chatml"; only the name and description differ.
# NOTE(review): the token strings contain U+2581 (LOWER ONE EIGHTH BLOCK),
# not an ASCII underscore; presumably they must match the tokenizer
# vocabulary exactly. Verify before editing them.
register_chat_template(
    name="openai",
    template_str=(
        "{%- if system_prompt -%}\n"
        "{{ bos_token }}system\n"
        "{{ system_prompt }}{{ eos_token }}\n"
        "{%- endif -%}\n"
        "{%- for message in messages -%}\n"
        "{{ bos_token }}{{ message['role'] }}\n"
        "{{ message['content'] }}{{ eos_token }}\n"
        "{%- endfor -%}\n"
        "{{ bos_token }}assistant\n"
    ),
    # Hyphen restored: the description previously read "OpenAIcompatible".
    description="OpenAI-compatible chat format with configurable special tokens.",
    special_tokens={"bos_token": "<im▁start>", "eos_token": "<im▁end>"},
)
# Llama2 style with [INST] tags.
# Single-turn: only the last message (messages[-1]) is wrapped in
# [INST] ... [/INST]; any earlier history is ignored. The system prompt, when
# given, is placed inside a <<SYS>> ... <</SYS>> section before it.
# NOTE(review): the "-" whitespace control on the endif tag strips the
# newline between "<</SYS>>" and "[INST]". Confirm this is intended before
# changing the template text.
register_chat_template(
    name="llama2",
    template_str=(
        "{%- if system_prompt -%}\n"
        "<<SYS>>\n"
        "{{ system_prompt }}\n"
        "<</SYS>>\n"
        "{%- endif -%}\n"
        "[INST] {{ messages[-1]['content'] }} [/INST]"
    ),
    # Hyphen restored: the description previously read "singleturn".
    description="Llama2 style with [INST] tags (single-turn).",
)
# Public API of this module: the names exported by `from ... import *`.
# The registry instance and the default-name variable are module-private and
# are reached through the functions listed here.
__all__ = [
    "ChatTemplate",
    "register_chat_template",
    "get_chat_template",
    "list_chat_templates",
    "chat_template_exists",
    "build_prompt",
    "set_default_chat_template",
    "get_default_chat_template_name",
    "HistoryType",
    "MessageType",
]