Compare commits

...

1 Commits

Author SHA1 Message Date
ViperEkura 6961039db0 feat: 拆分task 逻辑 2026-04-16 21:52:07 +08:00
1 changed file with 296 additions and 361 deletions

View File

@@ -1,9 +1,15 @@
"""Chat service module""" """Chat service module - Refactored with step-by-step flow"""
import json import json
import uuid import uuid
import logging import logging
from typing import List, Dict, Any, AsyncGenerator, Tuple, Optional
# For Python < 3.9 compatibility
try:
from typing import List
except ImportError:
pass
from typing import List, Dict,AsyncGenerator
from luxx.models import Conversation, Message, LLMProvider from luxx.models import Conversation, Message, LLMProvider
from luxx.tools.executor import ToolExecutor from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry from luxx.tools.core import registry
@@ -11,85 +17,274 @@ from luxx.services.llm_client import LLMClient
from luxx.database import SessionLocal from luxx.database import SessionLocal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 20 MAX_ITERATIONS = 20
def _sse_event(event: str, data: dict) -> str: def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string.""" """Format SSE event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Optional[Conversation] = None) -> Tuple[LLMClient, Optional[int]]:
    """Resolve the LLM client for a conversation.

    When the conversation has a configured provider, build a client from the
    provider's credentials; otherwise fall back to the globally configured
    default client.

    Args:
        conversation: Conversation whose ``provider_id`` selects the provider;
            may be None.

    Returns:
        Tuple ``(client, max_tokens)`` where ``max_tokens`` is the provider's
        configured limit, or None when no provider applies.
    """
    if conversation and conversation.provider_id:
        db = SessionLocal()
        try:
            provider = db.query(LLMProvider).filter(
                LLMProvider.id == conversation.provider_id
            ).first()
            if provider:
                return LLMClient(
                    api_key=provider.api_key,
                    api_url=provider.base_url,
                    model=provider.default_model
                ), provider.max_tokens
        finally:
            # Always release the session, even if the query raises.
            db.close()
    # No conversation/provider configured: default client, no provider cap.
    return LLMClient(), None
class StreamState:
    """Mutable accumulator shared across ReAct-loop iterations.

    The cross-iteration fields collect the whole exchange (messages, steps,
    tool calls/results, token usage); the per-iteration fields hold the
    current LLM round's partial output and are cleared between rounds.
    """

    def __init__(self):
        # Cross-iteration accumulators.
        self.messages: List[Dict] = []
        self.all_steps: List[Dict] = []
        self.all_tool_calls: List[Dict] = []
        self.all_tool_results: List[Dict] = []
        self.step_index: int = 0
        self.total_usage: Dict[str, int] = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
        # Per-iteration scratch state.
        self.full_content: str = ""
        self.full_thinking: str = ""
        self.tool_calls_list: List[Dict] = []
        self.thinking_step_id: Optional[str] = None
        self.thinking_step_idx: Optional[int] = None
        self.text_step_id: Optional[str] = None
        self.text_step_idx: Optional[int] = None
class ChatService:
    """Chat service implementing the step-by-step ReAct streaming flow."""

    def __init__(self):
        # Executes the tool calls requested by the LLM (in parallel).
        self.tool_executor = ToolExecutor()
def build_messages( # ==================== Step 1: Initialize ====================
self,
conversation: Conversation, def build_messages(self, conversation: Conversation, user_message: str) -> List[Dict[str, str]]:
include_system: bool = True """Build message list including user message."""
) -> List[Dict[str, str]]:
"""Build message list"""
messages = [] messages = []
if conversation.system_prompt:
if include_system and conversation.system_prompt: messages.append({"role": "system", "content": conversation.system_prompt})
messages.append({
"role": "system",
"content": conversation.system_prompt
})
db = SessionLocal() db = SessionLocal()
try: try:
db_messages = db.query(Message).filter( for msg in db.query(Message).filter(
Message.conversation_id == conversation.id Message.conversation_id == conversation.id
).order_by(Message.created_at).all() ).order_by(Message.created_at).all():
for msg in db_messages:
# Parse JSON content if possible
try: try:
content_obj = json.loads(msg.content) if msg.content else {} obj = json.loads(msg.content) if msg.content else {}
if isinstance(content_obj, dict): content = obj.get("text", msg.content) if isinstance(obj, dict) else msg.content
content = content_obj.get("text", msg.content)
else:
content = msg.content
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
content = msg.content content = msg.content
messages.append({"role": msg.role, "content": content})
messages.append({
"role": msg.role,
"content": content
})
finally: finally:
db.close() db.close()
messages.append({"role": "user", "content": json.dumps({"text": user_message, "attachments": []})})
return messages return messages
def init_stream_state(self, conversation: Conversation, user_message: str, enabled_tools: list) -> Tuple[StreamState, LLMClient, Dict, str, Optional[int]]:
"""Initialize streaming state. Returns: (state, llm, tools, model, max_tokens)"""
state = StreamState()
state.messages = self.build_messages(conversation, user_message)
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else []
llm, max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
tool_context = {"workspace": None, "user_id": None, "username": None, "user_permission_level": 1}
return state, llm, tools, model, max_tokens, tool_context
# ==================== Step 2: Stream LLM ====================
def parse_sse_line(self, sse_line: str) -> Tuple[Optional[str], Optional[str]]:
"""Parse SSE line into (event_type, data_str)."""
event_type = data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
return event_type, data_str
def stream_llm_response(self, llm, model: str, messages: List[Dict], tools: list,
temperature: float, max_tokens: int, thinking_enabled: bool):
"""
Stream LLM response and yield (sse_line, parsed_chunk) pairs.
"""
for sse_line in llm.stream_call(
model=model, messages=messages, tools=tools,
temperature=temperature, max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled
):
_, data_str = self.parse_sse_line(sse_line)
chunk = None
if data_str:
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
pass
yield sse_line, chunk
def process_delta(self, state: StreamState, delta: dict) -> List[str]:
"""
Process a single delta, return list of SSE event strings.
"""
events = []
# Handle thinking/reasoning
reasoning = delta.get("reasoning_content", "")
if reasoning:
if not state.full_thinking:
state.thinking_step_idx = state.step_index
state.thinking_step_id = f"step-{state.step_index}"
state.step_index += 1
state.full_thinking += reasoning
events.append(_sse_event("process_step", {
"step": {"id": state.thinking_step_id, "index": state.thinking_step_idx, "type": "thinking", "content": state.full_thinking}
}))
# Handle content
content = delta.get("content", "")
if content:
if not state.full_content:
state.text_step_idx = state.step_index
state.text_step_id = f"step-{state.step_index}"
state.step_index += 1
state.full_content += content
events.append(_sse_event("process_step", {
"step": {"id": state.text_step_id, "index": state.text_step_idx, "type": "text", "content": state.full_content}
}))
# Handle tool calls
for tc in delta.get("tool_calls", []):
idx = tc.get("index", 0)
if idx >= len(state.tool_calls_list):
state.tool_calls_list.append({"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}})
func = tc.get("function", {})
if func.get("name"):
state.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
state.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
return events
def save_steps(self, state: StreamState):
"""Save current iteration steps to all_steps."""
if state.thinking_step_id:
state.all_steps.append({"id": state.thinking_step_id, "index": state.thinking_step_idx, "type": "thinking", "content": state.full_thinking})
if state.text_step_id:
state.all_steps.append({"id": state.text_step_id, "index": state.text_step_idx, "type": "text", "content": state.full_content})
# ==================== Step 3: Execute Tools ====================
def execute_tools(self, state: StreamState, tool_context: Dict) -> Tuple[List[Dict], List[str]]:
"""
Execute tools and return (results, events).
"""
if not state.tool_calls_list:
return [], []
state.all_tool_calls.extend(state.tool_calls_list)
tool_call_ids = []
events = []
# Yield tool_call steps
for tc in state.tool_calls_list:
step_id = f"step-{state.step_index}"
tool_call_ids.append(step_id)
state.step_index += 1
step = {
"id": step_id, "index": len(state.all_steps), "type": "tool_call",
"id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"]
}
state.all_steps.append(step)
events.append(_sse_event("process_step", {"step": step}))
# Execute tools
results = self.tool_executor.process_tool_calls_parallel(state.tool_calls_list, tool_context)
# Yield tool_result steps
for i, tr in enumerate(results):
ref_id = tool_call_ids[i] if i < len(tool_call_ids) else f"tool-{i}"
step_id = f"step-{state.step_index}"
state.step_index += 1
content = tr.get("content", "")
success = True
try:
obj = json.loads(content)
if isinstance(obj, dict):
success = obj.get("success", True)
except:
pass
step = {
"id": step_id, "index": len(state.all_steps), "type": "tool_result",
"id_ref": ref_id, "name": tr.get("name", ""), "content": content, "success": success
}
state.all_steps.append(step)
events.append(_sse_event("process_step", {"step": step}))
state.all_tool_results.append({"role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": content})
return results, events
def update_messages_for_next_iteration(self, state: StreamState, results: List[Dict]):
"""Update messages list with assistant response and tool results for next iteration."""
state.messages.append({"role": "assistant", "content": state.full_content or "", "tool_calls": state.tool_calls_list})
if results:
state.messages.extend(state.all_tool_results[-len(results):])
state.all_tool_results = []
def reset_iteration_state(self, state: StreamState):
"""Reset state for next iteration."""
state.full_content = state.full_thinking = ""
state.tool_calls_list = []
state.thinking_step_id = state.thinking_step_idx = state.text_step_id = state.text_step_idx = None
# ==================== Step 4: Finalize ====================
def save_message(self, conversation_id: str, state: StreamState, token_count: int):
"""Save assistant message to database."""
content_json = {"text": state.full_content, "steps": state.all_steps}
if state.all_tool_calls:
content_json["tool_calls"] = state.all_tool_calls
db = SessionLocal()
try:
db.add(Message(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
role="assistant",
content=json.dumps(content_json, ensure_ascii=False),
token_count=token_count,
usage=json.dumps(state.total_usage) if state.total_usage else None
))
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
# ==================== Main Orchestrator ====================
async def stream_response( async def stream_response(
self, self,
conversation: Conversation, conversation: Conversation,
@@ -100,360 +295,100 @@ class ChatService:
username: str = None, username: str = None,
workspace: str = None, workspace: str = None,
user_permission_level: int = 1 user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]: ) -> AsyncGenerator[str, None]:
""" """Main streaming orchestrator - step-by-step flow."""
Streaming response generator
Yields raw SSE event strings for direct forwarding.
"""
try: try:
messages = self.build_messages(conversation) # Step 1: Initialize
state, llm, tools, model, max_tokens, tool_context = self.init_stream_state(
messages.append({ conversation, user_message, enabled_tools or []
"role": "user", )
"content": json.dumps({"text": user_message, "attachments": []}) tool_context.update
({
"user_id": user_id,
"username": username,
"workspace": workspace,
"user_permission_level": user_permission_level
}) })
# Get tools based on enabled_tools filter # ReAct loop
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# State tracking
all_steps = []
all_tool_calls = []
all_tool_results = []
step_index = 0
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
# Global step IDs for thinking and text (persist across iterations)
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
for _ in range(MAX_ITERATIONS): for _ in range(MAX_ITERATIONS):
# Stream from LLM self.reset_iteration_state(state)
full_content = ""
full_thinking = ""
tool_calls_list = []
# Step tracking - use unified step-{index} format # Step 2: Stream LLM
thinking_step_id = None async for sse_line, chunk in self.stream_llm_response(
thinking_step_idx = None llm, model, state.messages, tools,
text_step_id = None conversation.temperature, max_tokens,
text_step_idx = None thinking_enabled or conversation.thinking_enabled
async for sse_line in llm.stream_call(
model=model,
messages=messages,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled
): ):
# Parse SSE line # Handle error events
# Format: "event: xxx\ndata: {...}\n\n" event_type, data_str = self.parse_sse_line(sse_line)
event_type = None if event_type == 'error':
data_str = None error_data = json.loads(data_str) if data_str else {}
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
return
for line in sse_line.strip().split('\n'): if not chunk:
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
continue continue
# Handle error events from LLM # Extract usage
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
except json.JSONDecodeError:
yield _sse_event("error", {"content": data_str})
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
return
# 提取 API 返回的 usage 信息
if "usage" in chunk: if "usage" in chunk:
usage = chunk["usage"] u = chunk["usage"]
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) state.total_usage = {
total_usage["completion_tokens"] = usage.get("completion_tokens", 0) "prompt_tokens": u.get("prompt_tokens", 0),
total_usage["total_tokens"] = usage.get("total_tokens", 0) "completion_tokens": u.get("completion_tokens", 0),
"total_tokens": u.get("total_tokens", 0)
}
# Check for error in response # Check for API errors
if "error" in chunk: if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"])) yield _sse_event("error", {"content": f"API Error: {chunk['error'].get('message', str(chunk['error']))}"})
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
return return
# Get delta # Get delta
choices = chunk.get("choices", [])
delta = None delta = None
choices = chunk.get("choices", [])
if choices: if choices:
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
# If no delta but has message (non-streaming response)
if not delta: if not delta:
message = choices[0].get("message", {}) content = choices[0].get("message", {}).get("content")
if message.get("content"): if content:
delta = {"content": message["content"]} delta = {"content": content}
if not delta: if not delta:
# Check if there's any content in the response (for non-standard LLM responses)
content = chunk.get("content") or chunk.get("message", {}).get("content", "") content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content: if content:
delta = {"content": content} delta = {"content": content}
if delta: if delta:
# Handle reasoning (thinking) # Step 2b: Process delta
reasoning = delta.get("reasoning_content", "") for event in self.process_delta(state, delta):
if reasoning: yield event
prev_thinking_len = len(full_thinking)
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
# Handle content
content = delta.get("content", "")
if content:
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
}
})
# Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(tool_calls_list):
tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save thinking step # Save steps after streaming
if thinking_step_id is not None: self.save_steps(state)
all_steps.append({
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
})
# Save text step # Step 3: Execute tools if present
if text_step_id is not None: if state.tool_calls_list:
all_steps.append({ results, events = self.execute_tools(state, tool_context)
"id": text_step_id, for event in events:
"index": text_step_idx, yield event
"type": "text", self.update_messages_for_next_iteration(state, results)
"content": full_content
})
# Handle tool calls
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
# Yield tool_call steps - use unified step-{index} format
tool_call_step_ids = [] # Track step IDs for tool calls
for tc in tool_calls_list:
call_step_idx = step_index
call_step_id = f"step-{step_index}"
tool_call_step_ids.append(call_step_id)
step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Execute tools
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context
)
# Yield tool_result steps - use unified step-{index} format
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index
result_step_id = f"step-{step_index}"
step_index += 1
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
})
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": full_content or "",
"tool_calls": tool_calls_list
})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = []
continue continue
# No tool calls - final iteration, save message # Step 4: Finalize (no tool calls)
msg_id = str(uuid.uuid4()) token_count = state.total_usage.get("completion_tokens", 0)
self.save_message(conversation.id, state, token_count)
actual_token_count = total_usage.get("completion_tokens", 0) yield _sse_event("done", {"message_id": str(uuid.uuid4()), "token_count": token_count, "usage": state.total_usage})
logger.info(f"total_usage: {total_usage}")
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
return return
# Max iterations exceeded - save message before error # Max iterations exceeded
if full_content or all_tool_calls: if state.full_content or state.all_tool_calls:
msg_id = str(uuid.uuid4()) self.save_message(conversation.id, state, state.total_usage.get("completion_tokens", 0))
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
except Exception as e: except Exception as e:
yield _sse_event("error", {"content": str(e)}) yield _sse_event("error", {"content": str(e)})
def _save_message(
self,
conversation_id: str,
msg_id: str,
full_content: str,
all_tool_calls: list,
all_tool_results: list,
all_steps: list,
token_count: int = 0,
usage: dict = None
):
"""Save the assistant message to database."""
content_json = {
"text": full_content,
"steps": all_steps
}
if all_tool_calls:
content_json["tool_calls"] = all_tool_calls
db = SessionLocal()
try:
msg = Message(
id=msg_id,
conversation_id=conversation_id,
role="assistant",
content=json.dumps(content_json, ensure_ascii=False),
token_count=token_count,
usage=json.dumps(usage) if usage else None
)
db.add(msg)
db.commit()
except Exception as e:
db.rollback()
raise
finally:
db.close()
# Module-level singleton used by the API layer.
chat_service = ChatService()