From aa5e03d7f69cafe1addef86c88c45b2755a7e6f7 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 4 Apr 2026 12:05:05 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=B7=A5=E5=8E=82?=
 =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E9=97=AE=E9=A2=98=E5=B9=B6=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?chat-template=E8=AE=BE=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 astrai/inference/generator.py | 74 +++++++++++++++++++++--------------
 pyproject.toml                |  1 +
 2 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/astrai/inference/generator.py b/astrai/inference/generator.py
index a99f7e3..67d76cf 100644
--- a/astrai/inference/generator.py
+++ b/astrai/inference/generator.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from jinja2 import Template
 from torch import Tensor
 
 from astrai.config.param_config import ModelParameter
@@ -9,39 +10,69 @@ from astrai.core.factory import BaseFactory
 from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
 
 HistoryType = List[Tuple[str, str]]
+MessageType = Dict[str, str]
+
+# Predefined jinja2 chat templates; closing tags keep the newline that ends each turn
+CHAT_TEMPLATES: Dict[str, str] = {
+    "chatml": """{%- if system_prompt -%}
+<|im▁start|>system
+{{ system_prompt }}<|im▁end|>
+{% endif -%}
+{%- for message in messages -%}
+<|im▁start|>{{ message['role'] }}
+{{ message['content'] }}<|im▁end|>
+{% endfor -%}
+<|im▁start|>assistant
+""",
+}
 
 
 def build_prompt(
     query: str,
     system_prompt: Optional[str] = None,
     history: Optional[HistoryType] = None,
+    template: Optional[str] = None,
 ) -> str:
-    """
-    Build prompt in ChatML format for query and history.
+    """Build prompt using jinja2 template for query and history.
 
     Args:
         query (str): query string.
         system_prompt (Optional[str]): system prompt string.
         history (Optional[HistoryType]): history list of query and response.
+        template (Optional[str]): jinja2 template string. If None, uses default chatml template.
 
     Returns:
-        str: prompt string in ChatML format.
+        str: prompt string formatted according to the template.
+
+    Example:
+        # Use default template
+        prompt = build_prompt(query="Hello", history=[...])
+
+        # Use custom template
+        custom_template = '''
+        {%- for msg in messages -%}
+        {{ msg['role'] }}: {{ msg['content'] }}
+        {%- endfor -%}
+        '''
+        prompt = build_prompt(query="Hello", template=custom_template)
     """
-    result = ""
-
-    if system_prompt:
-        result += f"<|im▁start|>system\n{system_prompt}<|im▁end|>\n"
-
-    # (convert tuple format to ChatML)
+    # Convert history to message format
+    messages: List[MessageType] = []
     if history:
         for user_msg, assistant_msg in history:
-            result += f"<|im▁start|>user\n{user_msg}<|im▁end|>\n"
-            result += f"<|im▁start|>assistant\n{assistant_msg}<|im▁end|>\n"
+            messages.append({"role": "user", "content": user_msg})
+            messages.append({"role": "assistant", "content": assistant_msg})
+    messages.append({"role": "user", "content": query})
 
-    result += f"<|im▁start|>user\n{query}<|im▁end|>\n"
-    result += "<|im▁start|>assistant\n"
+    # Use provided template or default chatml template
+    template_str = template if template is not None else CHAT_TEMPLATES["chatml"]
 
-    return result
+    # Render template, keeping the final newline after the assistant header
+    jinja_template = Template(template_str, keep_trailing_newline=True)
+    return jinja_template.render(
+        messages=messages,
+        system_prompt=system_prompt,
+    )
 
 
 def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
@@ -303,18 +334,3 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
             EmbeddingEncoderCore instance
         """
         return EmbeddingEncoder(parameter)
-
-    @classmethod
-    def create(
-        cls, parameter: ModelParameter, request: GenerationRequest
-    ) -> GeneratorCore:
-        """Convenience method that delegates to create_generator.
-
-        Args:
-            parameter: Model parameters
-            request: Generation request
-
-        Returns:
-            Generator instance
-        """
-        return cls.create_generator(parameter, request)
diff --git a/pyproject.toml b/pyproject.toml
index 1fe090b..9c3ecec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
     "tqdm==4.67.1",
     "safetensors==0.5.3",
     "huggingface-hub==0.34.3",
+    "jinja2>=3.0.0",
     "fastapi",
     "uvicorn[standard]",
     "httpx",