fix: 修复工厂模式问题并增加chat-template设置

This commit is contained in:
ViperEkura 2026-04-04 12:05:05 +08:00
parent 073baf105c
commit aa5e03d7f6
2 changed files with 46 additions and 29 deletions

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Tuple, Union
import torch
from jinja2 import Template
from torch import Tensor
from astrai.config.param_config import ModelParameter
@@ -9,39 +10,69 @@ from astrai.core.factory import BaseFactory
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
HistoryType = List[Tuple[str, str]]
# A single chat message as passed to the template: {"role": ..., "content": ...}.
MessageType = Dict[str, str]
# Predefined chat templates using jinja2
# Each value is a jinja2 source string receiving two variables:
#   system_prompt (Optional[str]) — emitted as a leading system turn when truthy
#   messages (List[MessageType]) — rendered in order
# and ending with an open assistant turn for the model to complete.
# NOTE(review): the turn markers appear here as <imstart>/<imend>; canonical
# ChatML uses <|im_start|>/<|im_end|> — confirm these literals against the
# tokenizer's special tokens.
CHAT_TEMPLATES: Dict[str, str] = {
"chatml": """{%- if system_prompt -%}
<imstart>system
{{ system_prompt }}<imend>
{%- endif -%}
{%- for message in messages -%}
<imstart>{{ message['role'] }}
{{ message['content'] }}<imend>
{%- endfor -%}
<imstart>assistant
""",
}
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Build prompt using jinja2 template for query and history.

    The (user, assistant) history pairs and the current query are flattened
    into a list of ``{"role": ..., "content": ...}`` messages, which is then
    rendered by the selected jinja2 template together with the optional
    system prompt.

    Args:
        query (str): query string.
        system_prompt (Optional[str]): system prompt string; the default
            template emits a system turn only when this is truthy.
        history (Optional[HistoryType]): history list of query and response.
        template (Optional[str]): jinja2 template string. If None, uses
            default chatml template.

    Returns:
        str: prompt string formatted according to the template.

    Example:
        # Use default template
        prompt = build_prompt(query="Hello", history=[("hi", "hello!")])

        # Use custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Convert history to message format (flatten tuple pairs into role turns).
    messages: List[MessageType] = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    # The current query is always the final user turn.
    messages.append({"role": "user", "content": query})

    # Use provided template or default chatml template.
    template_str = template if template is not None else CHAT_TEMPLATES["chatml"]

    # Render template. NOTE(review): jinja2 autoescape is off by default,
    # which is appropriate for plain-text prompts — confirm no HTML context.
    jinja_template = Template(template_str)
    return jinja_template.render(
        messages=messages,
        system_prompt=system_prompt,
    )
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
@@ -303,18 +334,3 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
EmbeddingEncoderCore instance
"""
return EmbeddingEncoder(parameter)
@classmethod
def create(
    cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
    """Shortcut constructor that simply forwards to ``create_generator``.

    Args:
        parameter: Model parameters
        request: Generation request

    Returns:
        Generator instance
    """
    generator = cls.create_generator(parameter, request)
    return generator

View File

@@ -15,6 +15,7 @@ dependencies = [
"tqdm==4.67.1",
"safetensors==0.5.3",
"huggingface-hub==0.34.3",
"jinja2>=3.0.0",
"fastapi",
"uvicorn[standard]",
"httpx",