fix: 修复工厂模式问题并增加chat-template设置
This commit is contained in:
parent
073baf105c
commit
aa5e03d7f6
|
|
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from jinja2 import Template
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
|
|
@ -9,39 +10,69 @@ from astrai.core.factory import BaseFactory
|
||||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||||
|
|
||||||
# Type aliases used throughout this module:
#   HistoryType  — a conversation history as (user_message, assistant_message) pairs
#   MessageType  — a single chat message as a {"role": ..., "content": ...} mapping
HistoryType = List[Tuple[str, str]]
MessageType = Dict[str, str]


# Predefined chat templates, rendered with jinja2.
# NOTE(review): the role markers use "▁" (U+2581) rather than the underscore
# seen in standard ChatML ("<|im_start|>") — preserved byte-for-byte here;
# confirm against the tokenizer's special tokens.
CHAT_TEMPLATES: Dict[str, str] = {
    "chatml": """{%- if system_prompt -%}
<|im▁start|>system
{{ system_prompt }}<|im▁end|>
{%- endif -%}
{%- for message in messages -%}
<|im▁start|>{{ message['role'] }}
{{ message['content'] }}<|im▁end|>
{%- endfor -%}
<|im▁start|>assistant
""",
}
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
    query: str,
    system_prompt: Optional[str] = None,
    history: Optional[HistoryType] = None,
    template: Optional[str] = None,
) -> str:
    """Build prompt using jinja2 template for query and history.

    Args:
        query (str): query string.
        system_prompt (Optional[str]): system prompt string.
        history (Optional[HistoryType]): history list of query and response.
        template (Optional[str]): jinja2 template string. If None, uses default chatml template.

    Returns:
        str: prompt string formatted according to the template.

    Example:
        # Use default template
        prompt = build_prompt(query="Hello", history=[...])

        # Use custom template
        custom_template = '''
        {%- for msg in messages -%}
        {{ msg['role'] }}: {{ msg['content'] }}
        {%- endfor -%}
        '''
        prompt = build_prompt(query="Hello", template=custom_template)
    """
    # Flatten the (user, assistant) history pairs into role/content messages.
    conversation: List[MessageType] = []
    for user_turn, assistant_turn in history or []:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})

    # The current query is always the final user message.
    conversation.append({"role": "user", "content": query})

    # Fall back to the built-in chatml template when none was supplied,
    # then render; the template decides how system_prompt is emitted.
    chosen = CHAT_TEMPLATES["chatml"] if template is None else template
    return Template(chosen).render(
        messages=conversation,
        system_prompt=system_prompt,
    )
||||||
|
|
||||||
|
|
||||||
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
def pad_sequence(ids_list: List[List[int]], pad_id: int) -> Tuple[List[List[int]], int]:
|
||||||
|
|
@ -303,18 +334,3 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
|
||||||
EmbeddingEncoderCore instance
|
EmbeddingEncoderCore instance
|
||||||
"""
|
"""
|
||||||
return EmbeddingEncoder(parameter)
|
return EmbeddingEncoder(parameter)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, parameter: ModelParameter, request: GenerationRequest
|
|
||||||
) -> GeneratorCore:
|
|
||||||
"""Convenience method that delegates to create_generator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parameter: Model parameters
|
|
||||||
request: Generation request
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generator instance
|
|
||||||
"""
|
|
||||||
return cls.create_generator(parameter, request)
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ dependencies = [
|
||||||
"tqdm==4.67.1",
|
"tqdm==4.67.1",
|
||||||
"safetensors==0.5.3",
|
"safetensors==0.5.3",
|
||||||
"huggingface-hub==0.34.3",
|
"huggingface-hub==0.34.3",
|
||||||
|
"jinja2>=3.0.0",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"uvicorn[standard]",
|
"uvicorn[standard]",
|
||||||
"httpx",
|
"httpx",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue