"""Chat service module - refactored into a step-by-step flow."""
import json
import uuid
import logging
from typing import List, Dict, Any, AsyncGenerator, Tuple, Optional

from luxx.models import Conversation, Message, LLMProvider
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.database import SessionLocal

logger = logging.getLogger(__name__)

# Hard cap on ReAct (tool-call) iterations to prevent infinite loops.
MAX_ITERATIONS = 20


def _sse_event(event: str, data: dict) -> str:
    """Format a Server-Sent Event string."""
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def get_llm_client(conversation: Conversation = None) -> Tuple[LLMClient, Optional[int]]:
    """Build an LLM client, preferring the conversation's configured provider.

    Returns:
        (client, max_tokens): ``max_tokens`` is the provider's limit, or
        ``None`` when falling back to the globally configured client.
    """
    if conversation and conversation.provider_id:
        db = SessionLocal()
        try:
            provider = db.query(LLMProvider).filter(
                LLMProvider.id == conversation.provider_id
            ).first()
            if provider:
                return LLMClient(
                    api_key=provider.api_key,
                    api_url=provider.base_url,
                    model=provider.default_model
                ), provider.max_tokens
        finally:
            db.close()
    return LLMClient(), None


class StreamState:
    """Holds streaming state across ReAct iterations."""

    def __init__(self):
        # Conversation-wide accumulators (survive across iterations).
        self.messages: List[Dict] = []
        self.all_steps: List[Dict] = []
        self.all_tool_calls: List[Dict] = []
        self.all_tool_results: List[Dict] = []
        self.step_index: int = 0
        self.total_usage: Dict[str, int] = {
            "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0
        }
        # Per-iteration state (reset by ChatService.reset_iteration_state).
        self.full_content: str = ""
        self.full_thinking: str = ""
        self.tool_calls_list: List[Dict] = []
        self.thinking_step_id: Optional[str] = None
        self.thinking_step_idx: Optional[int] = None
        self.text_step_id: Optional[str] = None
        self.text_step_idx: Optional[int] = None


class ChatService:
    """Chat service with step-by-step flow architecture."""

    def __init__(self):
        self.tool_executor = ToolExecutor()

    # ==================== Step 1: Initialize ====================

    def build_messages(self, conversation: Conversation, user_message: str) -> List[Dict[str, str]]:
        """Build the LLM message list: system prompt, history, new user message.

        History message content stored as JSON ``{"text": ...}`` is unwrapped
        to plain text; anything else is passed through verbatim.
        """
        messages = []
        if conversation.system_prompt:
            messages.append({"role": "system", "content": conversation.system_prompt})

        db = SessionLocal()
        try:
            for msg in db.query(Message).filter(
                Message.conversation_id == conversation.id
            ).order_by(Message.created_at).all():
                try:
                    obj = json.loads(msg.content) if msg.content else {}
                    content = obj.get("text", msg.content) if isinstance(obj, dict) else msg.content
                except (json.JSONDecodeError, TypeError):
                    content = msg.content
                messages.append({"role": msg.role, "content": content})
        finally:
            db.close()

        messages.append({"role": "user", "content": json.dumps({"text": user_message, "attachments": []})})
        return messages

    def init_stream_state(
        self, conversation: Conversation, user_message: str, enabled_tools: list
    ) -> Tuple[StreamState, LLMClient, List[Dict], str, Optional[int], Dict[str, Any]]:
        """Initialize streaming state.

        Returns:
            (state, llm, tools, model, max_tokens, tool_context)

        NOTE: the annotation now matches the actual 6-tuple returned (the
        previous annotation/docstring omitted ``tool_context``).
        """
        state = StreamState()
        state.messages = self.build_messages(conversation, user_message)

        # Only expose tools the caller explicitly enabled.
        tools = [
            t for t in registry.list_all()
            if t.get("function", {}).get("name") in enabled_tools
        ] if enabled_tools else []

        llm, max_tokens = get_llm_client(conversation)
        model = conversation.model or llm.default_model or "gpt-4"

        # Defaults; the orchestrator overlays the real request context.
        tool_context = {"workspace": None, "user_id": None, "username": None, "user_permission_level": 1}

        return state, llm, tools, model, max_tokens, tool_context

    # ==================== Step 2: Stream LLM ====================

    def parse_sse_line(self, sse_line: str) -> Tuple[Optional[str], Optional[str]]:
        """Parse an SSE frame into ``(event_type, data_str)``; either may be None."""
        event_type = data_str = None
        for line in sse_line.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:].strip()
            elif line.startswith('data: '):
                data_str = line[6:].strip()
        return event_type, data_str

    async def stream_llm_response(self, llm, model: str, messages: List[Dict], tools: list,
                                  temperature: float, max_tokens: int, thinking_enabled: bool):
        """Stream the LLM response, yielding ``(sse_line, parsed_chunk)`` pairs.

        ``parsed_chunk`` is None when the frame has no data payload or the
        payload is not valid JSON.

        BUG FIX: ``llm.stream_call`` is an async generator (it was consumed
        with ``async for`` before the refactor), so this wrapper must be
        ``async def`` with ``async for`` — the previous plain ``def``/``for``
        version raised ``TypeError`` when the orchestrator did
        ``async for`` over it.
        """
        async for sse_line in llm.stream_call(
            model=model, messages=messages, tools=tools,
            temperature=temperature, max_tokens=max_tokens or 8192,
            thinking_enabled=thinking_enabled
        ):
            _, data_str = self.parse_sse_line(sse_line)
            chunk = None
            if data_str:
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    chunk = None
            yield sse_line, chunk

    def process_delta(self, state: StreamState, delta: dict) -> List[str]:
        """Fold one streaming delta into *state*; return SSE events to emit.

        Thinking and text each open a new step the first time content arrives
        in an iteration; subsequent deltas re-emit the step with the full
        accumulated content. Tool-call fragments are merged by index.
        """
        events = []

        # Thinking / reasoning stream.
        reasoning = delta.get("reasoning_content", "")
        if reasoning:
            if not state.full_thinking:  # first reasoning token -> open step
                state.thinking_step_idx = state.step_index
                state.thinking_step_id = f"step-{state.step_index}"
                state.step_index += 1
            state.full_thinking += reasoning
            events.append(_sse_event("process_step", {
                "step": {"id": state.thinking_step_id, "index": state.thinking_step_idx,
                         "type": "thinking", "content": state.full_thinking}
            }))

        # Visible text stream.
        content = delta.get("content", "")
        if content:
            if not state.full_content:  # first text token -> open step
                state.text_step_idx = state.step_index
                state.text_step_id = f"step-{state.step_index}"
                state.step_index += 1
            state.full_content += content
            events.append(_sse_event("process_step", {
                "step": {"id": state.text_step_id, "index": state.text_step_idx,
                         "type": "text", "content": state.full_content}
            }))

        # Tool-call fragments: name/arguments arrive in pieces keyed by index.
        for tc in delta.get("tool_calls", []):
            idx = tc.get("index", 0)
            if idx >= len(state.tool_calls_list):
                state.tool_calls_list.append(
                    {"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}}
                )
            func = tc.get("function", {})
            if func.get("name"):
                state.tool_calls_list[idx]["function"]["name"] += func["name"]
            if func.get("arguments"):
                state.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]

        return events

    def save_steps(self, state: StreamState):
        """Append the current iteration's thinking/text steps to ``all_steps``."""
        if state.thinking_step_id:
            state.all_steps.append({"id": state.thinking_step_id, "index": state.thinking_step_idx,
                                    "type": "thinking", "content": state.full_thinking})
        if state.text_step_id:
            state.all_steps.append({"id": state.text_step_id, "index": state.text_step_idx,
                                    "type": "text", "content": state.full_content})

    # ==================== Step 3: Execute Tools ====================

    def execute_tools(self, state: StreamState, tool_context: Dict) -> Tuple[List[Dict], List[str]]:
        """Execute the accumulated tool calls; return ``(results, sse_events)``."""
        if not state.tool_calls_list:
            return [], []

        state.all_tool_calls.extend(state.tool_calls_list)
        tool_call_ids = []
        events = []

        # Emit a tool_call step per call before execution.
        for tc in state.tool_calls_list:
            step_id = f"step-{state.step_index}"
            tool_call_ids.append(step_id)
            state.step_index += 1
            step = {
                "id": step_id, "index": len(state.all_steps), "type": "tool_call",
                "id_ref": tc.get("id", ""), "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"]
            }
            state.all_steps.append(step)
            events.append(_sse_event("process_step", {"step": step}))

        results = self.tool_executor.process_tool_calls_parallel(state.tool_calls_list, tool_context)

        # Emit a tool_result step per result, referencing its tool_call step.
        for i, tr in enumerate(results):
            ref_id = tool_call_ids[i] if i < len(tool_call_ids) else f"tool-{i}"
            step_id = f"step-{state.step_index}"
            state.step_index += 1

            content = tr.get("content", "")
            success = True
            # Tool results are conventionally JSON with a "success" flag;
            # non-JSON content is treated as a success.
            try:
                obj = json.loads(content)
                if isinstance(obj, dict):
                    success = obj.get("success", True)
            except (json.JSONDecodeError, TypeError):
                pass

            step = {
                "id": step_id, "index": len(state.all_steps), "type": "tool_result",
                "id_ref": ref_id, "name": tr.get("name", ""), "content": content, "success": success
            }
            state.all_steps.append(step)
            events.append(_sse_event("process_step", {"step": step}))

            state.all_tool_results.append(
                {"role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": content}
            )

        return results, events

    def update_messages_for_next_iteration(self, state: StreamState, results: List[Dict]):
        """Append the assistant turn and tool results, then clear the result buffer."""
        state.messages.append(
            {"role": "assistant", "content": state.full_content or "", "tool_calls": state.tool_calls_list}
        )
        if results:
            state.messages.extend(state.all_tool_results[-len(results):])
        state.all_tool_results = []

    def reset_iteration_state(self, state: StreamState):
        """Reset the per-iteration fields ahead of the next LLM round-trip."""
        state.full_content = state.full_thinking = ""
        state.tool_calls_list = []
        state.thinking_step_id = state.thinking_step_idx = state.text_step_id = state.text_step_idx = None

    # ==================== Step 4: Finalize ====================

    def save_message(self, conversation_id: str, state: StreamState, token_count: int) -> str:
        """Persist the assistant message and return its id.

        Returning the id lets the caller put the *same* ``message_id`` in the
        ``done`` event that is stored in the database (previously two
        unrelated UUIDs were generated, so clients received an id that
        matched no row).
        """
        msg_id = str(uuid.uuid4())
        content_json: Dict[str, Any] = {"text": state.full_content, "steps": state.all_steps}
        if state.all_tool_calls:
            content_json["tool_calls"] = state.all_tool_calls

        db = SessionLocal()
        try:
            db.add(Message(
                id=msg_id,
                conversation_id=conversation_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=token_count,
                usage=json.dumps(state.total_usage) if state.total_usage else None
            ))
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()
        return msg_id

    # ==================== Main Orchestrator ====================

    @staticmethod
    def _extract_delta(chunk: dict) -> Optional[dict]:
        """Pull the streaming delta from a chunk, tolerating non-streaming
        (``message``) and non-standard (top-level ``content``) shapes."""
        choices = chunk.get("choices", [])
        if choices:
            delta = choices[0].get("delta", {})
            if delta:
                return delta
            content = choices[0].get("message", {}).get("content")
            if content:
                return {"content": content}
        content = chunk.get("content") or chunk.get("message", {}).get("content", "")
        return {"content": content} if content else None

    async def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        thinking_enabled: bool = False,
        enabled_tools: list = None,
        user_id: str = None,
        username: str = None,
        workspace: str = None,
        user_permission_level: int = 1
    ) -> AsyncGenerator[str, None]:
        """Main streaming orchestrator — step-by-step ReAct flow.

        Yields raw SSE event strings for direct forwarding to the client.
        """
        try:
            # Step 1: Initialize.
            state, llm, tools, model, max_tokens, tool_context = self.init_stream_state(
                conversation, user_message, enabled_tools or []
            )
            # BUG FIX: this used to be `tool_context.update` followed by
            # `({...})` on the next line — a bare attribute access plus a
            # stray parenthesized dict — so the request context was never
            # attached and every tool ran with the default (empty) context.
            tool_context.update({
                "user_id": user_id,
                "username": username,
                "workspace": workspace,
                "user_permission_level": user_permission_level
            })

            # ReAct loop: stream -> (optionally) run tools -> repeat.
            for _ in range(MAX_ITERATIONS):
                self.reset_iteration_state(state)

                # Step 2: Stream LLM.
                async for sse_line, chunk in self.stream_llm_response(
                    llm, model, state.messages, tools,
                    conversation.temperature, max_tokens,
                    thinking_enabled or conversation.thinking_enabled
                ):
                    # Transport-level error events terminate the stream.
                    event_type, data_str = self.parse_sse_line(sse_line)
                    if event_type == 'error':
                        try:
                            error_data = json.loads(data_str) if data_str else {}
                            message = error_data.get("content", "Unknown error")
                        except json.JSONDecodeError:
                            # Restored pre-refactor fallback: forward the raw
                            # payload when it is not JSON.
                            message = data_str
                        yield _sse_event("error", {"content": message})
                        return

                    if not chunk:
                        continue

                    # Track token usage (last reported snapshot wins).
                    if "usage" in chunk:
                        u = chunk["usage"]
                        state.total_usage = {
                            "prompt_tokens": u.get("prompt_tokens", 0),
                            "completion_tokens": u.get("completion_tokens", 0),
                            "total_tokens": u.get("total_tokens", 0)
                        }

                    # In-band API errors also terminate the stream.
                    if "error" in chunk:
                        yield _sse_event(
                            "error",
                            {"content": f"API Error: {chunk['error'].get('message', str(chunk['error']))}"}
                        )
                        return

                    delta = self._extract_delta(chunk)
                    if delta:
                        # Step 2b: fold the delta into state, forward events.
                        for event in self.process_delta(state, delta):
                            yield event

                # Record this iteration's thinking/text steps.
                self.save_steps(state)

                # Step 3: Execute tools if the model requested any.
                if state.tool_calls_list:
                    results, events = self.execute_tools(state, tool_context)
                    for event in events:
                        yield event
                    self.update_messages_for_next_iteration(state, results)
                    continue

                # Step 4: Finalize (no tool calls -> final answer).
                token_count = state.total_usage.get("completion_tokens", 0)
                msg_id = self.save_message(conversation.id, state, token_count)
                yield _sse_event("done", {
                    "message_id": msg_id, "token_count": token_count, "usage": state.total_usage
                })
                return

            # Max iterations exceeded: persist whatever we have, then report.
            if state.full_content or state.all_tool_calls:
                self.save_message(conversation.id, state, state.total_usage.get("completion_tokens", 0))
            yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

        except Exception as e:
            yield _sse_event("error", {"content": str(e)})


# Global service instance.
chat_service = ChatService()