Compare commits

...

1 Commit

Author SHA1 Message Date
ViperEkura 6961039db0 feat: 拆分task 逻辑 2026-04-16 21:52:07 +08:00
1 changed file with 296 additions and 361 deletions

View File

@ -1,9 +1,15 @@
"""Chat service module"""
"""Chat service module - Refactored with step-by-step flow"""
import json
import uuid
import logging
from typing import List, Dict, Any, AsyncGenerator, Tuple, Optional
# For Python < 3.9 compatibility
try:
from typing import List
except ImportError:
pass
from typing import List, Dict,AsyncGenerator
from luxx.models import Conversation, Message, LLMProvider
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
@ -11,85 +17,274 @@ from luxx.services.llm_client import LLMClient
from luxx.database import SessionLocal
logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 20
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
"""Format SSE event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Conversation = None):
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)"""
max_tokens = None
def get_llm_client(conversation: Conversation = None) -> Tuple[LLMClient, Optional[int]]:
"""Get LLM client from conversation's provider."""
if conversation and conversation.provider_id:
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
if provider:
max_tokens = provider.max_tokens
client = LLMClient(
return LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
return client, max_tokens
), provider.max_tokens
finally:
db.close()
return LLMClient(), None
class StreamState:
"""Holds streaming state across iterations."""
# Fallback to global config
client = LLMClient()
return client, max_tokens
def __init__(self):
    """Initialize empty streaming state for one chat request."""
    # --- conversation-level accumulators (persist across ReAct iterations) ---
    self.messages: List[Dict] = []            # chat history sent to the LLM
    self.all_steps: List[Dict] = []           # ordered thinking/text/tool steps
    self.all_tool_calls: List[Dict] = []      # every tool call made this request
    self.all_tool_results: List[Dict] = []    # role="tool" messages pending replay
    self.step_index: int = 0                  # monotonically increasing step id counter
    self.total_usage: Dict[str, int] = {
        "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0
    }
    # Current iteration state
    self.full_content: str = ""               # accumulated visible answer text
    self.full_thinking: str = ""              # accumulated reasoning text
    self.tool_calls_list: List[Dict] = []     # tool-call fragments being assembled
    self.thinking_step_id: Optional[str] = None
    self.thinking_step_idx: Optional[int] = None
    self.text_step_id: Optional[str] = None
    self.text_step_idx: Optional[int] = None
class ChatService:
"""Chat service with tool support"""
"""Chat service with step-by-step flow architecture."""
def __init__(self):
self.tool_executor = ToolExecutor()
def build_messages(
self,
conversation: Conversation,
include_system: bool = True
) -> List[Dict[str, str]]:
"""Build message list"""
# ==================== Step 1: Initialize ====================
def build_messages(self, conversation: Conversation, user_message: str) -> List[Dict[str, str]]:
"""Build message list including user message."""
messages = []
if include_system and conversation.system_prompt:
messages.append({
"role": "system",
"content": conversation.system_prompt
})
if conversation.system_prompt:
messages.append({"role": "system", "content": conversation.system_prompt})
db = SessionLocal()
try:
db_messages = db.query(Message).filter(
for msg in db.query(Message).filter(
Message.conversation_id == conversation.id
).order_by(Message.created_at).all()
for msg in db_messages:
# Parse JSON content if possible
).order_by(Message.created_at).all():
try:
content_obj = json.loads(msg.content) if msg.content else {}
if isinstance(content_obj, dict):
content = content_obj.get("text", msg.content)
else:
content = msg.content
obj = json.loads(msg.content) if msg.content else {}
content = obj.get("text", msg.content) if isinstance(obj, dict) else msg.content
except (json.JSONDecodeError, TypeError):
content = msg.content
messages.append({
"role": msg.role,
"content": content
})
messages.append({"role": msg.role, "content": content})
finally:
db.close()
messages.append({"role": "user", "content": json.dumps({"text": user_message, "attachments": []})})
return messages
def init_stream_state(self, conversation: Conversation, user_message: str, enabled_tools: list) -> Tuple[StreamState, LLMClient, List[Dict], str, Optional[int], Dict]:
    """Initialize per-request streaming state.

    Returns:
        (state, llm, tools, model, max_tokens, tool_context) — a SIX-element
        tuple (the previous annotation/docstring listed only five and omitted
        ``tool_context``). ``tool_context`` starts with placeholder values
        that the orchestrator overwrites with the real user/workspace info.
    """
    state = StreamState()
    state.messages = self.build_messages(conversation, user_message)
    # Only expose tools the caller explicitly enabled; no tools when the list is empty.
    tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else []
    llm, max_tokens = get_llm_client(conversation)
    # Fall back to the client's default model, then a hard-coded name.
    model = conversation.model or llm.default_model or "gpt-4"
    tool_context = {"workspace": None, "user_id": None, "username": None, "user_permission_level": 1}
    return state, llm, tools, model, max_tokens, tool_context
# ==================== Step 2: Stream LLM ====================
def parse_sse_line(self, sse_line: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a raw SSE frame into its (event_type, data_str) components.

    Either component is None when the corresponding field is absent.
    """
    event_type: Optional[str] = None
    data_str: Optional[str] = None
    for field in sse_line.strip().split('\n'):
        if field.startswith('event: '):
            event_type = field[len('event: '):].strip()
        elif field.startswith('data: '):
            data_str = field[len('data: '):].strip()
    return event_type, data_str
async def stream_llm_response(self, llm, model: str, messages: List[Dict], tools: list,
                              temperature: float, max_tokens: int, thinking_enabled: bool):
    """Stream the LLM response, yielding (sse_line, parsed_chunk) pairs.

    ``parsed_chunk`` is the JSON-decoded ``data:`` payload, or None when the
    frame carries no data or the payload is not valid JSON.

    Fix: ``llm.stream_call`` is an async generator (the pre-refactor code
    consumed it with ``async for``), and the orchestrator iterates this
    method with ``async for`` — so this must be an async generator too;
    the previous sync ``for``/``def`` version could not iterate it.
    """
    async for sse_line in llm.stream_call(
        model=model, messages=messages, tools=tools,
        temperature=temperature, max_tokens=max_tokens or 8192,
        thinking_enabled=thinking_enabled
    ):
        _, data_str = self.parse_sse_line(sse_line)
        chunk = None
        if data_str:
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                pass  # non-JSON payloads are forwarded with chunk=None
        yield sse_line, chunk
def process_delta(self, state: StreamState, delta: dict) -> List[str]:
    """Fold one streaming delta into *state*; return the SSE events to emit."""
    events: List[str] = []

    # Reasoning ("thinking") tokens: open a new step on the first token of
    # the iteration, then re-emit the full accumulated text on every update.
    reasoning = delta.get("reasoning_content", "")
    if reasoning:
        if not state.full_thinking:
            state.thinking_step_idx = state.step_index
            state.thinking_step_id = f"step-{state.step_index}"
            state.step_index += 1
        state.full_thinking += reasoning
        thinking_step = {
            "id": state.thinking_step_id,
            "index": state.thinking_step_idx,
            "type": "thinking",
            "content": state.full_thinking,
        }
        events.append(_sse_event("process_step", {"step": thinking_step}))

    # Visible answer text: same open-once / re-emit-accumulated pattern.
    content = delta.get("content", "")
    if content:
        if not state.full_content:
            state.text_step_idx = state.step_index
            state.text_step_id = f"step-{state.step_index}"
            state.step_index += 1
        state.full_content += content
        text_step = {
            "id": state.text_step_id,
            "index": state.text_step_idx,
            "type": "text",
            "content": state.full_content,
        }
        events.append(_sse_event("process_step", {"step": text_step}))

    # Tool-call fragments arrive indexed: grow the list when a new index
    # appears, then concatenate name/argument pieces onto that entry.
    for fragment in delta.get("tool_calls", []):
        idx = fragment.get("index", 0)
        if idx >= len(state.tool_calls_list):
            state.tool_calls_list.append(
                {"id": fragment.get("id", ""), "type": "function",
                 "function": {"name": "", "arguments": ""}}
            )
        fn = fragment.get("function", {})
        if fn.get("name"):
            state.tool_calls_list[idx]["function"]["name"] += fn["name"]
        if fn.get("arguments"):
            state.tool_calls_list[idx]["function"]["arguments"] += fn["arguments"]

    return events
def save_steps(self, state: StreamState):
    """Append this iteration's finished thinking/text steps to the step log."""
    finished = []
    if state.thinking_step_id:
        finished.append(("thinking", state.thinking_step_id,
                         state.thinking_step_idx, state.full_thinking))
    if state.text_step_id:
        finished.append(("text", state.text_step_id,
                         state.text_step_idx, state.full_content))
    # Thinking (if any) is recorded before text, matching emission order.
    for step_type, step_id, step_idx, text in finished:
        state.all_steps.append(
            {"id": step_id, "index": step_idx, "type": step_type, "content": text}
        )
# ==================== Step 3: Execute Tools ====================
def execute_tools(self, state: StreamState, tool_context: Dict) -> Tuple[List[Dict], List[str]]:
    """Execute the accumulated tool calls.

    Emits one ``tool_call`` step per pending call, runs all calls in
    parallel through the tool executor, then emits one ``tool_result``
    step per result, linked back to its call step via ``id_ref``.

    Args:
        state: Streaming state holding ``tool_calls_list`` to execute.
        tool_context: User/workspace context forwarded to the executor.

    Returns:
        (results, events): raw executor result dicts and the SSE event
        strings to forward to the client.
    """
    if not state.tool_calls_list:
        return [], []
    state.all_tool_calls.extend(state.tool_calls_list)
    tool_call_ids = []
    events = []
    # Announce each pending tool call as its own step.
    for tc in state.tool_calls_list:
        step_id = f"step-{state.step_index}"
        tool_call_ids.append(step_id)
        state.step_index += 1
        step = {
            "id": step_id, "index": len(state.all_steps), "type": "tool_call",
            "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"]
        }
        state.all_steps.append(step)
        events.append(_sse_event("process_step", {"step": step}))
    # Execute all tool calls concurrently.
    results = self.tool_executor.process_tool_calls_parallel(state.tool_calls_list, tool_context)
    # Emit one tool_result step per result.
    for i, tr in enumerate(results):
        # NOTE(review): fallback id "tool-{i}" differs from the pre-refactor
        # "step-{i}" — confirm which format the frontend expects.
        ref_id = tool_call_ids[i] if i < len(tool_call_ids) else f"tool-{i}"
        step_id = f"step-{state.step_index}"
        state.step_index += 1
        content = tr.get("content", "")
        success = True
        try:
            obj = json.loads(content)
            if isinstance(obj, dict):
                success = obj.get("success", True)
        # Fix: was a bare ``except:`` which also swallowed SystemExit and
        # KeyboardInterrupt; only JSON-decode failures should be tolerated.
        except (json.JSONDecodeError, TypeError):
            pass  # non-JSON tool output: treat as successful plain text
        step = {
            "id": step_id, "index": len(state.all_steps), "type": "tool_result",
            "id_ref": ref_id, "name": tr.get("name", ""), "content": content, "success": success
        }
        state.all_steps.append(step)
        events.append(_sse_event("process_step", {"step": step}))
        state.all_tool_results.append({"role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": content})
    return results, events
def update_messages_for_next_iteration(self, state: StreamState, results: List[Dict]):
    """Update messages list with assistant response and tool results for next iteration."""
    # Echo the assistant turn (including its tool calls) back into the history
    # so the next LLM round sees what it asked for.
    state.messages.append({"role": "assistant", "content": state.full_content or "", "tool_calls": state.tool_calls_list})
    if results:
        # Append only the tool results produced in this iteration.
        state.messages.extend(state.all_tool_results[-len(results):])
    # NOTE(review): indentation was lost in this view — the clear may belong
    # inside the ``if results:`` branch. Entries are only appended alongside
    # results, so an unconditional clear should be equivalent; confirm.
    state.all_tool_results = []
def reset_iteration_state(self, state: StreamState):
    """Clear the per-iteration fields before the next LLM round trip."""
    state.full_content = ""
    state.full_thinking = ""
    state.tool_calls_list = []
    state.thinking_step_id = None
    state.thinking_step_idx = None
    state.text_step_id = None
    state.text_step_idx = None
# ==================== Step 4: Finalize ====================
def save_message(self, conversation_id: str, state: StreamState, token_count: int):
    """Persist the final assistant message (text, steps, tool calls) to the DB."""
    payload = {"text": state.full_content, "steps": state.all_steps}
    if state.all_tool_calls:
        payload["tool_calls"] = state.all_tool_calls
    db = SessionLocal()
    try:
        record = Message(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            role="assistant",
            content=json.dumps(payload, ensure_ascii=False),
            token_count=token_count,
            usage=json.dumps(state.total_usage) if state.total_usage else None
        )
        db.add(record)
        db.commit()
    except Exception:
        # Roll back the failed transaction but surface the error to the caller.
        db.rollback()
        raise
    finally:
        db.close()
# ==================== Main Orchestrator ====================
async def stream_response(
self,
conversation: Conversation,
@ -100,360 +295,100 @@ class ChatService:
username: str = None,
workspace: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Streaming response generator
Yields raw SSE event strings for direct forwarding.
"""
) -> AsyncGenerator[str, None]:
"""Main streaming orchestrator - step-by-step flow."""
try:
messages = self.build_messages(conversation)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
# Step 1: Initialize
state, llm, tools, model, max_tokens, tool_context = self.init_stream_state(
conversation, user_message, enabled_tools or []
)
tool_context.update
({
"user_id": user_id,
"username": username,
"workspace": workspace,
"user_permission_level": user_permission_level
})
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# State tracking
all_steps = []
all_tool_calls = []
all_tool_results = []
step_index = 0
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
# Global step IDs for thinking and text (persist across iterations)
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
# ReAct loop
for _ in range(MAX_ITERATIONS):
# Stream from LLM
full_content = ""
full_thinking = ""
tool_calls_list = []
self.reset_iteration_state(state)
# Step tracking - use unified step-{index} format
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
async for sse_line in llm.stream_call(
model=model,
messages=messages,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled
# Step 2: Stream LLM
async for sse_line, chunk in self.stream_llm_response(
llm, model, state.messages, tools,
conversation.temperature, max_tokens,
thinking_enabled or conversation.thinking_enabled
):
# Parse SSE line
# Format: "event: xxx\ndata: {...}\n\n"
event_type = None
data_str = None
# Handle error events
event_type, data_str = self.parse_sse_line(sse_line)
if event_type == 'error':
error_data = json.loads(data_str) if data_str else {}
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
return
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
if not chunk:
continue
# Handle error events from LLM
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
except json.JSONDecodeError:
yield _sse_event("error", {"content": data_str})
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
return
# 提取 API 返回的 usage 信息
# Extract usage
if "usage" in chunk:
usage = chunk["usage"]
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
total_usage["total_tokens"] = usage.get("total_tokens", 0)
u = chunk["usage"]
state.total_usage = {
"prompt_tokens": u.get("prompt_tokens", 0),
"completion_tokens": u.get("completion_tokens", 0),
"total_tokens": u.get("total_tokens", 0)
}
# Check for error in response
# Check for API errors
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
yield _sse_event("error", {"content": f"API Error: {chunk['error'].get('message', str(chunk['error']))}"})
return
# Get delta
choices = chunk.get("choices", [])
delta = None
choices = chunk.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
# If no delta but has message (non-streaming response)
if not delta:
message = choices[0].get("message", {})
if message.get("content"):
delta = {"content": message["content"]}
content = choices[0].get("message", {}).get("content")
if content:
delta = {"content": content}
if not delta:
# Check if there's any content in the response (for non-standard LLM responses)
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
delta = {"content": content}
if delta:
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
prev_thinking_len = len(full_thinking)
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
# Handle content
content = delta.get("content", "")
if content:
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
}
})
# Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(tool_calls_list):
tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Step 2b: Process delta
for event in self.process_delta(state, delta):
yield event
# Save thinking step
if thinking_step_id is not None:
all_steps.append({
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
})
# Save steps after streaming
self.save_steps(state)
# Save text step
if text_step_id is not None:
all_steps.append({
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
})
# Handle tool calls
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
# Yield tool_call steps - use unified step-{index} format
tool_call_step_ids = [] # Track step IDs for tool calls
for tc in tool_calls_list:
call_step_idx = step_index
call_step_id = f"step-{step_index}"
tool_call_step_ids.append(call_step_id)
step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Execute tools
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context
)
# Yield tool_result steps - use unified step-{index} format
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index
result_step_id = f"step-{step_index}"
step_index += 1
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
})
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": full_content or "",
"tool_calls": tool_calls_list
})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = []
# Step 3: Execute tools if present
if state.tool_calls_list:
results, events = self.execute_tools(state, tool_context)
for event in events:
yield event
self.update_messages_for_next_iteration(state, results)
continue
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
actual_token_count = total_usage.get("completion_tokens", 0)
logger.info(f"total_usage: {total_usage}")
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
# Step 4: Finalize (no tool calls)
token_count = state.total_usage.get("completion_tokens", 0)
self.save_message(conversation.id, state, token_count)
yield _sse_event("done", {"message_id": str(uuid.uuid4()), "token_count": token_count, "usage": state.total_usage})
return
# Max iterations exceeded - save message before error
if full_content or all_tool_calls:
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
# Max iterations exceeded
if state.full_content or state.all_tool_calls:
self.save_message(conversation.id, state, state.total_usage.get("completion_tokens", 0))
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
except Exception as e:
yield _sse_event("error", {"content": str(e)})
def _save_message(
    self,
    conversation_id: str,
    msg_id: str,
    full_content: str,
    all_tool_calls: list,
    all_tool_results: list,
    all_steps: list,
    token_count: int = 0,
    usage: dict = None
):
    """Save the assistant message to database.

    Args:
        conversation_id: Id of the owning conversation.
        msg_id: Pre-generated id for the new message row.
        full_content: Final assistant text.
        all_tool_calls: Accumulated tool-call dicts (stored only when non-empty).
        all_tool_results: Unused here; kept for interface compatibility.
        all_steps: Ordered step dicts (thinking/text/tool_call/tool_result).
        token_count: Completion token count reported by the provider.
        usage: Full usage dict, serialized to JSON when present.

    Raises:
        Re-raises any database error after rolling back the transaction.
    """
    content_json = {
        "text": full_content,
        "steps": all_steps
    }
    if all_tool_calls:
        content_json["tool_calls"] = all_tool_calls
    db = SessionLocal()
    try:
        msg = Message(
            id=msg_id,
            conversation_id=conversation_id,
            role="assistant",
            content=json.dumps(content_json, ensure_ascii=False),
            token_count=token_count,
            usage=json.dumps(usage) if usage else None
        )
        db.add(msg)
        db.commit()
    except Exception:
        # Fix: the exception was bound ``as e`` but never used; rollback
        # then re-raise so callers still see the failure.
        db.rollback()
        raise
    finally:
        db.close()
# Module-level singleton chat service shared by the application.
chat_service = ChatService()