Luxx/luxx/services/chat.py

397 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Chat service module"""
import json
import uuid
from typing import List, Dict, Any, AsyncGenerator, Optional
from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.config import config
# Upper bound on LLM rounds within a single streamed response. Every round
# that ends in tool calls consumes one iteration, so this caps runaway
# tool-call loops.
MAX_ITERATIONS = 10
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Conversation = None):
    """Resolve the LLM client for a conversation.

    If the conversation names a provider (``provider_id``) and the matching
    ``LLMProvider`` row exists, the client is built from that provider's
    ``api_key``, ``base_url`` and ``default_model``. The provider's
    ``max_tokens`` is returned with it. In every other case a default
    ``LLMClient()`` is returned together with ``None``.

    Returns:
        A ``(client, max_tokens)`` tuple.
    """
    if not (conversation and conversation.provider_id):
        # No provider pinned on the conversation: use the global configuration.
        return LLMClient(), None

    from luxx.models import LLMProvider
    from luxx.database import SessionLocal

    session = SessionLocal()
    try:
        provider = (
            session.query(LLMProvider)
            .filter(LLMProvider.id == conversation.provider_id)
            .first()
        )
        if provider:
            # Read the provider's attributes while the session is still open.
            provider_limit = provider.max_tokens
            provider_client = LLMClient(
                api_key=provider.api_key,
                api_url=provider.base_url,
                model=provider.default_model,
            )
            return provider_client, provider_limit
    finally:
        session.close()

    # The referenced provider row was not found: fall back to the global configuration.
    return LLMClient(), None
class ChatService:
    """Chat service with tool support.

    Builds the message history for a conversation and streams the LLM reply
    as SSE strings. Tool calls requested by the model are run through
    ``ToolExecutor``, and their results are fed back for another round. The
    final assistant message is then persisted.
    """

    def __init__(self):
        self.tool_executor = ToolExecutor()

    def build_messages(
        self,
        conversation: Conversation,
        include_system: bool = True
    ) -> List[Dict[str, str]]:
        """Build the chat message list for a conversation.

        Args:
            conversation: Conversation whose stored messages are loaded.
            include_system: When true and ``conversation.system_prompt`` is
                set, prepend it as a ``system`` message.

        Returns:
            A list of ``{"role", "content"}`` dicts ordered by ``created_at``.
            Stored contents that parse as a JSON object are unwrapped to their
            ``"text"`` field. Any other content is passed through unchanged.
        """
        from luxx.database import SessionLocal

        messages = []
        if include_system and conversation.system_prompt:
            messages.append({
                "role": "system",
                "content": conversation.system_prompt
            })

        db = SessionLocal()
        try:
            db_messages = db.query(Message).filter(
                Message.conversation_id == conversation.id
            ).order_by(Message.created_at).all()
            for msg in db_messages:
                messages.append({
                    "role": msg.role,
                    "content": self._extract_text(msg.content)
                })
        finally:
            db.close()
        return messages

    @staticmethod
    def _extract_text(raw):
        """Return the ``"text"`` field of a JSON-object body, else *raw* unchanged.

        An empty/``None`` body is treated as ``{}``, so the result is *raw* itself.
        """
        try:
            parsed = json.loads(raw) if raw else {}
        except (json.JSONDecodeError, TypeError):
            return raw
        if isinstance(parsed, dict):
            return parsed.get("text", raw)
        return raw

    @staticmethod
    def _tool_name(tool):
        """Return a tool schema's name, or ``None`` if it has none.

        Accepts a flat ``{"name": ...}`` layout or a nested
        ``{"function": {"name": ...}}`` layout.
        NOTE(review): the exact shape returned by ``registry.list_all()`` is
        not visible here, so both layouts are accepted.
        """
        if not isinstance(tool, dict):
            return None
        name = tool.get("name")
        if name:
            return name
        function = tool.get("function")
        return function.get("name") if isinstance(function, dict) else None

    def _select_tools(self, tools_enabled, enabled_tools):
        """Return the tool schemas to offer the LLM, or ``None`` for no tools.

        When *enabled_tools* is not ``None``, only tools whose name appears in
        it are kept. An empty selection is normalized to ``None``.
        """
        if not tools_enabled:
            return None
        available = registry.list_all()
        if enabled_tools is not None:
            allowed = set(enabled_tools)
            available = [t for t in available if self._tool_name(t) in allowed]
        return list(available) if available else None

    @staticmethod
    def _parse_sse_line(sse_line):
        """Split one SSE frame (``event:`` / ``data:`` lines) into ``(event_type, data_str)``.

        Either element is ``None`` when the frame lacks that field. If a field
        appears more than once, the last occurrence wins.
        """
        event_type = None
        data_str = None
        for line in sse_line.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:].strip()
            elif line.startswith('data: '):
                data_str = line[6:].strip()
        return event_type, data_str

    @staticmethod
    def _error_message(data_str):
        """Extract a human-readable message from an upstream ``error`` event payload.

        Returns the payload's ``"content"`` (default ``"Unknown error"``) when the
        payload is a JSON object. Otherwise returns the raw payload string.
        """
        try:
            payload = json.loads(data_str)
        except json.JSONDecodeError:
            return data_str
        if isinstance(payload, dict):
            return payload.get("content", "Unknown error")
        return data_str

    @staticmethod
    def _merge_tool_call_deltas(pending: Dict[int, Dict[str, Any]], deltas) -> None:
        """Fold streamed ``tool_calls`` fragments into *pending*, keyed by ``index``.

        Name and argument fragments are concatenated in arrival order. The call
        id comes from the first fragment that carries one. Keying by index
        means sparse or out-of-order indices cannot raise ``IndexError``.
        """
        for tc in deltas:
            idx = tc.get("index") or 0
            entry = pending.get(idx)
            if entry is None:
                entry = pending[idx] = {
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""}
                }
            if tc.get("id") and not entry["id"]:
                entry["id"] = tc["id"]
            func = tc.get("function") or {}
            if func.get("name"):
                entry["function"]["name"] += func["name"]
            if func.get("arguments"):
                entry["function"]["arguments"] += func["arguments"]

    async def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        tools_enabled: bool = True,
        enabled_tools: Optional[List[str]] = None,
        thinking_enabled: bool = False
    ) -> AsyncGenerator[str, None]:
        """Stream an assistant reply, executing tool calls between LLM rounds.

        Yields raw SSE event strings for direct forwarding:

        - ``process_step`` events for thinking, text, tool_call and tool_result steps;
        - then exactly one terminal event: ``done`` (with ``message_id`` and an
          approximate ``token_count``) or ``error``.

        Args:
            conversation: Conversation providing history, model and settings.
            user_message: Text of the new user turn.
            tools_enabled: Offer tools to the LLM at all.
            enabled_tools: Optional allow-list of tool names. ``None`` means all tools.
            thinking_enabled: Request reasoning output. It is also enabled when
                ``conversation.thinking_enabled`` is set.
        """
        try:
            messages = self.build_messages(conversation)
            # Send the plain text. build_messages likewise unwraps stored
            # {"text": ...} envelopes, so the LLM never sees the JSON wrapper.
            messages.append({
                "role": "user",
                "content": user_message
            })

            tools = self._select_tools(tools_enabled, enabled_tools)

            llm, provider_max_tokens = get_llm_client(conversation)
            model = conversation.model or llm.default_model or "gpt-4"
            # Use the provider's max_tokens unless the conversation sets its own.
            max_tokens = getattr(conversation, "max_tokens", None) or provider_max_tokens

            all_steps = []
            all_tool_calls = []
            all_tool_results = []
            step_index = 0  # global, monotonically increasing across rounds

            for iteration in range(MAX_ITERATIONS):
                print(f"[CHAT] Starting iteration {iteration + 1}, messages: {len(messages)}")

                full_content = ""
                full_thinking = ""
                pending_calls: Dict[int, Dict[str, Any]] = {}
                # Step ids are per round: a thinking/text step is opened the
                # first time that stream produces output in this round.
                thinking_step_id = thinking_step_idx = None
                text_step_id = text_step_idx = None

                async for sse_line in llm.stream_call(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=conversation.temperature,
                    max_tokens=max_tokens or 8192,
                    thinking_enabled=thinking_enabled or conversation.thinking_enabled
                ):
                    event_type, data_str = self._parse_sse_line(sse_line)
                    if data_str is None:
                        continue

                    if event_type == 'error':
                        yield _sse_event("error", {"content": self._error_message(data_str)})
                        return

                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Non-JSON payloads (e.g. a "[DONE]" marker, if forwarded) are skipped.
                        continue
                    if not isinstance(chunk, dict):
                        continue

                    choices = chunk.get("choices") or []
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}

                    reasoning = delta.get("reasoning_content")
                    if reasoning:
                        if thinking_step_id is None:
                            thinking_step_idx = step_index
                            thinking_step_id = f"step-{step_index}"
                            step_index += 1
                        full_thinking += reasoning
                        yield _sse_event("process_step", {
                            "step": {
                                "id": thinking_step_id,
                                "index": thinking_step_idx,
                                "type": "thinking",
                                "content": full_thinking
                            }
                        })

                    content = delta.get("content")
                    if content:
                        if text_step_id is None:
                            text_step_idx = step_index
                            text_step_id = f"step-{step_index}"
                            step_index += 1
                        full_content += content
                        yield _sse_event("process_step", {
                            "step": {
                                "id": text_step_id,
                                "index": text_step_idx,
                                "type": "text",
                                "content": full_content
                            }
                        })

                    # `or []` guards against providers sending "tool_calls": null.
                    self._merge_tool_call_deltas(pending_calls, delta.get("tool_calls") or [])

                tool_calls_list = [pending_calls[i] for i in sorted(pending_calls)]

                if thinking_step_id is not None:
                    all_steps.append({
                        "id": thinking_step_id,
                        "index": thinking_step_idx,
                        "type": "thinking",
                        "content": full_thinking
                    })
                if text_step_id is not None:
                    all_steps.append({
                        "id": text_step_id,
                        "index": text_step_idx,
                        "type": "text",
                        "content": full_content
                    })

                if tool_calls_list:
                    all_tool_calls.extend(tool_calls_list)

                    tool_call_step_ids = []
                    for tc in tool_calls_list:
                        call_step = {
                            "id": f"step-{step_index}",
                            "index": step_index,
                            "type": "tool_call",
                            "id_ref": tc.get("id", ""),
                            "name": tc["function"]["name"],
                            "arguments": tc["function"]["arguments"]
                        }
                        step_index += 1
                        tool_call_step_ids.append(call_step["id"])
                        all_steps.append(call_step)
                        yield _sse_event("process_step", {"step": call_step})

                    tool_results = self.tool_executor.process_tool_calls_parallel(
                        tool_calls_list, {}
                    )

                    round_tool_messages = []
                    for i, tr in enumerate(tool_results):
                        # Results are paired with calls by position (as before).
                        # NOTE(review): assumes the executor preserves call order.
                        # A surplus result gets an empty reference rather than
                        # pointing at an unrelated step.
                        call_ref = tool_call_step_ids[i] if i < len(tool_call_step_ids) else ""
                        result_step = {
                            "id": f"step-{step_index}",
                            "index": step_index,
                            "type": "tool_result",
                            "id_ref": call_ref,  # id of the matching tool_call step
                            "name": tr.get("name", ""),
                            "content": tr.get("content", "")
                        }
                        step_index += 1
                        all_steps.append(result_step)
                        yield _sse_event("process_step", {"step": result_step})
                        round_tool_messages.append({
                            "role": "tool",
                            "tool_call_id": tr.get("tool_call_id", ""),
                            "content": tr.get("content", "")
                        })

                    # Feed this round's call + results back for the next LLM round.
                    messages.append({
                        "role": "assistant",
                        "content": full_content or "",
                        "tool_calls": tool_calls_list
                    })
                    messages.extend(round_tool_messages)
                    # Accumulate across rounds so the results reach _save_message.
                    all_tool_results.extend(round_tool_messages)
                    continue

                # No tool calls: this is the final round, so persist and finish.
                msg_id = str(uuid.uuid4())
                self._save_message(
                    conversation.id,
                    msg_id,
                    full_content,
                    all_tool_calls,
                    all_tool_results,
                    all_steps
                )
                yield _sse_event("done", {
                    "message_id": msg_id,
                    "token_count": len(full_content) // 4  # rough chars/4 estimate
                })
                return

            yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
        except Exception as e:
            # Top-level boundary: surface any failure to the client as an SSE error.
            print(f"[CHAT] Exception: {type(e).__name__}: {str(e)}")
            yield _sse_event("error", {"content": str(e)})

    def _save_message(
        self,
        conversation_id: str,
        msg_id: str,
        full_content: str,
        all_tool_calls: list,
        all_tool_results: list,
        all_steps: list
    ):
        """Persist the assistant message as a JSON document.

        The stored content holds ``text`` and ``steps``, plus ``tool_calls`` and
        ``tool_results`` when non-empty. This is best-effort: a database error
        is printed and rolled back, not raised.
        """
        from luxx.database import SessionLocal

        content_json = {
            "text": full_content,
            "steps": all_steps
        }
        if all_tool_calls:
            content_json["tool_calls"] = all_tool_calls
        if all_tool_results:
            content_json["tool_results"] = all_tool_results

        db = SessionLocal()
        try:
            msg = Message(
                id=msg_id,
                conversation_id=conversation_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=len(full_content) // 4
            )
            db.add(msg)
            db.commit()
        except Exception as e:
            print(f"[CHAT] Failed to save message: {e}")
            db.rollback()
        finally:
            db.close()
# Module-level ChatService instance, created at import time (its __init__
# constructs a ToolExecutor).
chat_service = ChatService()