fix: resolve factory pattern issue and add chat-template settings

ViperEkura 2026-04-04 12:05:05 +08:00
parent 073baf105c
commit aa5e03d7f6
2 changed files with 46 additions and 29 deletions


@@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Tuple, Union
import torch
from jinja2 import Template
from torch import Tensor
from astrai.config.param_config import ModelParameter
@@ -9,39 +10,69 @@ from astrai.core.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]
# Predefined chat templates using jinja2
CHAT_TEMPLATES: Dict[str, str] = {
"chatml": """{%- if system_prompt -%}
<imstart>system
{{ system_prompt }}<imend>
{%- endif -%}
{%- for message in messages -%}
<imstart>{{ message['role'] }}
{{ message['content'] }}<imend>
{%- endfor -%}
<imstart>assistant
""",
}
def build_prompt(
query: str,
system_prompt: Optional[str] = None,
history: Optional[HistoryType] = None,
template: Optional[str] = None,
) -> str:
"""
Build prompt in ChatML format for query and history.
"""Build prompt using jinja2 template for query and history.
Args:
query (str): query string.
system_prompt (Optional[str]): system prompt string.
history (Optional[HistoryType]): history list of query and response.
template (Optional[str]): jinja2 template string. If None, uses default chatml template.
Returns:
str: prompt string in ChatML format.
str: prompt string formatted according to the template.
Example:
# Use default template
prompt = build_prompt(query="Hello", history=[...])
# Use custom template
custom_template = '''
{%- for msg in messages -%}
{{ msg['role'] }}: {{ msg['content'] }}
{%- endfor -%}
'''
prompt = build_prompt(query="Hello", template=custom_template)
"""
result = ""
if system_prompt:
result += f"<im▁start>system\n{system_prompt}<im▁end>\n"
# (convert tuple format to ChatML)
# Convert history to message format
messages: List[MessageType] = []
if history:
for user_msg, assistant_msg in history:
result += f"<im▁start>user\n{user_msg}<im▁end>\n"
result += f"<im▁start>assistant\n{assistant_msg}<im▁end>\n"
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": query})
result += f"<im▁start>user\n{query}<im▁end>\n"
result += "<im▁start>assistant\n"
# Use provided template or default chatml template
template_str = template if template is not None else CHAT_TEMPLATES["chatml"]
return result
# Render template
jinja_template = Template(template_str)
return jinja_template.render(
messages=messages,
system_prompt=system_prompt,
)
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
@@ -303,18 +334,3 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)
@classmethod
def create(
cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
"""Convenience method that delegates to create_generator.
Args:
parameter: Model parameters
request: Generation request
Returns:
Generator instance
"""
return cls.create_generator(parameter, request)


@@ -15,6 +15,7 @@ dependencies = [
"tqdm==4.67.1",
"safetensors==0.5.3",
"huggingface-hub==0.34.3",
"jinja2>=3.0.0",
"fastapi",
"uvicorn[standard]",
"httpx",