263 lines
9.1 KiB
Python
263 lines
9.1 KiB
Python
"""聊天服务模块"""
|
|
import json
|
|
import re
|
|
from typing import Dict, List, Optional, Any, Generator
|
|
from datetime import datetime
|
|
|
|
from alcor.models import Conversation, Message
|
|
from alcor.tools.executor import ToolExecutor
|
|
from alcor.tools.core import registry
|
|
from alcor.services.llm_client import llm_client, LLMClient
|
|
|
|
|
|
# Maximum number of LLM/tool-call iterations per request, to prevent
# infinite tool-calling loops.
MAX_ITERATIONS = 10
|
|
|
|
|
|
class ChatService:
    """Chat service.

    Orchestrates LLM calls and tool execution for a conversation, in both
    streaming and non-streaming modes. Tool-call rounds are capped at
    ``max_iterations`` to prevent infinite loops.
    """

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
        max_iterations: int = MAX_ITERATIONS
    ):
        """
        Args:
            llm_client: LLM client to use; falls back to the module-level
                shared client when ``None``.
            max_iterations: Upper bound on LLM/tool-call rounds per request.
        """
        # BUG FIX: the parameter shadows the module-level ``llm_client``
        # import, so the original ``llm_client or llm_client`` could never
        # reach the default and left ``self.llm_client`` as None.
        if llm_client is None:
            from alcor.services.llm_client import llm_client as _default_client
            llm_client = _default_client
        self.llm_client = llm_client
        # Tool results are cached for 5 minutes.
        self.tool_executor = ToolExecutor(enable_cache=True, cache_ttl=300)
        self.max_iterations = max_iterations

    @staticmethod
    def _append_tool_results(
        messages: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> None:
        """Append one ``role: tool`` message per tool result to *messages*."""
        for tr in tool_results:
            messages.append({
                "role": "tool",
                "tool_call_id": tr.get("tool_call_id"),
                "content": tr.get("content", ""),
                "name": tr.get("name")
            })

    def build_messages(
        self,
        conversation: Conversation,
        include_system: bool = True
    ) -> List[Dict[str, str]]:
        """Build the LLM message list from the conversation's stored history.

        Stored message content is JSON; a dict payload contributes its
        ``"text"`` field, non-JSON content is passed through unchanged.

        Args:
            conversation: Conversation whose messages are loaded
                (ordered by creation time).
            include_system: Whether to prepend the system prompt, if any.

        Returns:
            List of ``{"role": ..., "content": ...}`` dicts.
        """
        messages: List[Dict[str, str]] = []

        # System prompt goes first, when present and requested.
        if include_system and conversation.system_prompt:
            messages.append({
                "role": "system",
                "content": conversation.system_prompt
            })

        # Historical messages in chronological order.
        for msg in conversation.messages.order_by(Message.created_at).all():
            try:
                content_data = json.loads(msg.content) if msg.content else {}
                if isinstance(content_data, dict):
                    text = content_data.get("text", "")
                else:
                    text = str(msg.content)
            except json.JSONDecodeError:
                # Plain (non-JSON) content: use as-is.
                text = msg.content

            messages.append({
                "role": msg.role,
                "content": text
            })

        return messages

    def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        tools_enabled: bool = True,
        context: Optional[Dict[str, Any]] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Streaming response generator.

        Yielded event types:
        - ``process_step``: thinking/text/tool_call/tool_result steps
        - ``done``: final response complete
        - ``error``: on failure

        Args:
            conversation: Conversation providing model settings and history.
            user_message: New user message to answer.
            tools_enabled: Whether registered tools are offered to the LLM.
            context: Reserved for callers; currently unused.
        """
        try:
            # Build the prompt: stored history + new user message.
            messages = self.build_messages(conversation)
            messages.append({
                "role": "user",
                "content": user_message
            })

            # Tool schemas offered to the LLM (None disables tool use).
            tools = registry.list_all() if tools_enabled else None

            # Iterative LLM/tool loop.
            iteration = 0
            full_response = ""                  # cumulative text across all rounds
            tool_calls_buffer: List[Dict] = []  # every tool call, for the final event

            while iteration < self.max_iterations:
                iteration += 1

                tool_calls_this_round = None
                round_content = ""              # text produced by this round only

                # Stream one LLM turn.
                for event in self.llm_client.stream(
                    model=conversation.model,
                    messages=messages,
                    tools=tools,
                    temperature=conversation.temperature,
                    max_tokens=conversation.max_tokens,
                    thinking_enabled=conversation.thinking_enabled
                ):
                    event_type = event.get("type")

                    if event_type == "content_delta":
                        # Incremental content.
                        content = event.get("content", "")
                        if content:
                            full_response += content
                            round_content += content
                            yield {
                                "type": "process_step",
                                "step_type": "text",
                                "content": content
                            }

                    elif event_type == "done":
                        # Turn finished; capture any requested tool calls.
                        tool_calls_this_round = event.get("tool_calls")

                # Handle tool calls.
                if tool_calls_this_round and tools_enabled:
                    yield {
                        "type": "process_step",
                        "step_type": "tool_call",
                        "tool_calls": tool_calls_this_round
                    }

                    # Execute tools (in parallel).
                    tool_results = self.tool_executor.process_tool_calls_parallel(
                        tool_calls_this_round
                    )

                    for result in tool_results:
                        yield {
                            "type": "process_step",
                            "step_type": "tool_result",
                            "result": result
                        }

                    # BUG FIX: record only this round's text on the assistant
                    # turn (the original appended the cumulative
                    # ``full_response``, duplicating earlier rounds' text in
                    # the message history on every subsequent round).
                    messages.append({
                        "role": "assistant",
                        "content": round_content,
                        "tool_calls": tool_calls_this_round
                    })

                    # Feed tool results back to the LLM.
                    self._append_tool_results(messages, tool_results)

                    tool_calls_buffer.extend(tool_calls_this_round)
                else:
                    # No tool calls: the answer is complete.
                    break

                # Defensive exit (mirrors the original's redundant check).
                if not tool_calls_this_round or not tools_enabled:
                    break

            # Final completion event.
            yield {
                "type": "done",
                "content": full_response,
                "tool_calls": tool_calls_buffer if tool_calls_buffer else None,
                "iterations": iteration
            }

        except Exception as e:
            # Surface any failure as a terminal error event.
            yield {
                "type": "error",
                "error": str(e)
            }

    def non_stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        tools_enabled: bool = True,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Non-streaming response.

        Runs the same iterative LLM/tool loop as :meth:`stream_response`
        but returns a single result dict.

        Args:
            conversation: Conversation providing model settings and history.
            user_message: New user message to answer.
            tools_enabled: Whether registered tools are offered to the LLM.
            context: Reserved for callers; currently unused.

        Returns:
            ``{"success": True, "content": ..., "tool_calls": [...],
            "iterations": n}`` on success, or
            ``{"success": False, "error": ...}`` on failure.
        """
        try:
            messages = self.build_messages(conversation)
            messages.append({
                "role": "user",
                "content": user_message
            })

            tools = registry.list_all() if tools_enabled else None

            # Iterative LLM/tool loop.
            iteration = 0
            full_response = ""
            all_tool_calls: List[Dict[str, Any]] = []

            while iteration < self.max_iterations:
                iteration += 1

                response = self.llm_client.call(
                    model=conversation.model,
                    messages=messages,
                    tools=tools,
                    stream=False,
                    temperature=conversation.temperature,
                    max_tokens=conversation.max_tokens
                )

                # Only the latest round's text is kept as the final answer.
                full_response = response.content
                tool_calls = response.tool_calls

                if tool_calls and tools_enabled:
                    # Execute tools and feed results back to the LLM.
                    tool_results = self.tool_executor.process_tool_calls_parallel(tool_calls)
                    all_tool_calls.extend(tool_calls)

                    messages.append({
                        "role": "assistant",
                        "content": full_response,
                        "tool_calls": tool_calls
                    })
                    self._append_tool_results(messages, tool_results)
                else:
                    # No tool calls: record the final assistant turn and stop.
                    messages.append({
                        "role": "assistant",
                        "content": full_response
                    })
                    break

            return {
                "success": True,
                "content": full_response,
                "tool_calls": all_tool_calls,
                "iterations": iteration
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
|
|
|
|
|
|
# Global chat service instance shared by the application.
chat_service = ChatService()
|