fix: 修复工厂模式问题并增加chat-template设置
This commit is contained in:
parent
073baf105c
commit
aa5e03d7f6
|
|
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from jinja2 import Template
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.param_config import ModelParameter
|
||||
|
|
@ -9,39 +10,69 @@ from astrai.core.factory import BaseFactory
|
|||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
MessageType = Dict[str, str]
|
||||
|
||||
# Predefined chat templates using jinja2
|
||||
# Named jinja2 chat templates. Each template is rendered with two variables:
#   - ``system_prompt``: optional system prompt string (may be None/empty)
#   - ``messages``: list of {"role": ..., "content": ...} dicts
CHAT_TEMPLATES: Dict[str, str] = {
    # ChatML layout: every turn is "<|im▁start|>{role}\n{content}<|im▁end|>\n"
    # and the prompt ends with an open assistant turn.
    #
    # Plain "{% ... %}" tags are used on purpose. The "{%- ... -%}" form
    # strips whitespace (including newlines) on both sides of a tag, which
    # would remove the "\n" after every "<|im▁end|>" and glue turns together.
    "chatml": (
        "{% if system_prompt %}"
        "<|im▁start|>system\n{{ system_prompt }}<|im▁end|>\n"
        "{% endif %}"
        "{% for message in messages %}"
        "<|im▁start|>{{ message['role'] }}\n{{ message['content'] }}<|im▁end|>\n"
        "{% endfor %}"
        # jinja2 drops a single trailing newline from the template source by
        # default, so the final newline is emitted through an expression to
        # make sure the prompt always ends with "<|im▁start|>assistant\n".
        "<|im▁start|>assistant{{ '\\n' }}"
    ),
}
|
||||
|
||||
|
||||
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Build a prompt by rendering a jinja2 template over the conversation.

    The history tuples and the current query are converted to a flat list of
    ``{"role": ..., "content": ...}`` messages, which is rendered together
    with ``system_prompt`` through the selected template.

    Args:
        query (str): query string; appended as the last ``user`` message.
        system_prompt (Optional[str]): system prompt string, passed to the
            template as ``system_prompt``.
        history (Optional[HistoryType]): history list of (query, response)
            pairs; each pair becomes a ``user`` and an ``assistant`` message.
        template (Optional[str]): jinja2 template string. If None, uses the
            default ``chatml`` template from ``CHAT_TEMPLATES``.

    Returns:
        str: prompt string formatted according to the template.

    Example:
        # Use default template
        prompt = build_prompt(query="Hello", history=[...])

        # Use custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Convert history to message format; a missing/empty history contributes
    # no messages.
    messages: List[MessageType] = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": query})

    # Use provided template or default chatml template. ``is not None`` (not
    # truthiness) so an explicitly passed empty template is still honoured.
    template_str = template if template is not None else CHAT_TEMPLATES["chatml"]

    # keep_trailing_newline=True: jinja2 otherwise removes a single trailing
    # newline from the template source, which would cut the newline that
    # follows the final assistant header in ChatML-style templates.
    jinja_template = Template(template_str, keep_trailing_newline=True)
    return jinja_template.render(
        messages=messages,
        system_prompt=system_prompt,
    )
|
||||
|
||||
|
||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||
|
|
@ -303,18 +334,3 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
|
|||
EmbeddingEncoderCore instance
|
||||
"""
|
||||
return EmbeddingEncoder(parameter)
|
||||
|
||||
@classmethod
def create(
    cls, parameter: ModelParameter, request: GenerationRequest
) -> GeneratorCore:
    """Build a generator by forwarding both arguments to ``create_generator``.

    Args:
        parameter: Model parameters, passed through unchanged.
        request: Generation request, passed through unchanged.

    Returns:
        The object returned by ``cls.create_generator(parameter, request)``.
    """
    generator = cls.create_generator(parameter, request)
    return generator
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ dependencies = [
|
|||
"tqdm==4.67.1",
|
||||
"safetensors==0.5.3",
|
||||
"huggingface-hub==0.34.3",
|
||||
"jinja2>=3.0.0",
|
||||
"fastapi",
|
||||
"uvicorn[standard]",
|
||||
"httpx",
|
||||
|
|
|
|||
Loading…
Reference in New Issue