diff --git a/backend/routes/__init__.py b/backend/routes/__init__.py index 5996c11..4919ba9 100644 --- a/backend/routes/__init__.py +++ b/backend/routes/__init__.py @@ -17,7 +17,7 @@ def register_routes(app: Flask): client = LLMClient(MODEL_CONFIG) init_chat_service(client) - # Register LLM client in service locator so tools (e.g. agent_task) can access it + # Register LLM client in service locator so tools (e.g. multi_agent) can access it from backend.tools import register_service register_service("llm_client", client) diff --git a/backend/services/chat.py b/backend/services/chat.py index 3d85987..d7b0d2e 100644 --- a/backend/services/chat.py +++ b/backend/services/chat.py @@ -61,6 +61,9 @@ class ChatService: """ conv_id = conv.id conv_model = conv.model + conv_max_tokens = conv.max_tokens + conv_temperature = conv.temperature + conv_thinking_enabled = conv.thinking_enabled app = current_app._get_current_object() tools = registry.list_all() if tools_enabled else None initial_messages = build_messages(conv, project_id) @@ -85,7 +88,9 @@ class ChatService: for iteration in range(MAX_ITERATIONS): try: stream_result = self._stream_llm_response( - app, conv_id, messages, tools, tool_choice, step_index + app, messages, tools, tool_choice, step_index, + conv_model, conv_max_tokens, conv_temperature, + conv_thinking_enabled, ) except requests.exceptions.HTTPError as e: resp = e.response @@ -185,7 +190,7 @@ class ChatService: # Append assistant message + tool results for the next iteration messages.append({ "role": "assistant", - "content": full_content or None, + "content": full_content or "", "tool_calls": tool_calls_list, }) messages.extend(tool_results) @@ -232,7 +237,8 @@ class ChatService: # ------------------------------------------------------------------ def _stream_llm_response( - self, app, conv_id, messages, tools, tool_choice, step_index + self, app, messages, tools, tool_choice, step_index, + model, max_tokens, temperature, thinking_enabled, ): """Call LLM streaming API and parse the response. @@ -253,13 +259,12 @@ class ChatService: sse_chunks = [] # Collect SSE events to yield later with app.app_context(): - active_conv = db.session.get(Conversation, conv_id) resp = self.llm.call( - model=active_conv.model, + model=model, messages=messages, - max_tokens=active_conv.max_tokens, - temperature=active_conv.temperature, - thinking_enabled=active_conv.thinking_enabled, + max_tokens=max_tokens, + temperature=temperature, + thinking_enabled=thinking_enabled, tools=tools, tool_choice=tool_choice, stream=True, @@ -327,39 +332,8 @@ class ChatService: sse_chunks, ) - def _execute_tools_safe(self, app, executor, tool_calls_list, context): - """Execute tool calls with top-level error wrapping. - - If an unexpected exception occurs during tool execution, it is - converted into error tool results instead of crashing the stream. - """ - try: - if len(tool_calls_list) > 1: - with app.app_context(): - tool_results = executor.process_tool_calls_parallel( - tool_calls_list, context, max_workers=TOOL_MAX_WORKERS - ) - else: - with app.app_context(): - tool_results = executor.process_tool_calls( - tool_calls_list, context - ) - except Exception as e: - logger.exception("Error during tool execution") - tool_results = [ - { - "role": "tool", - "tool_call_id": tc["id"], - "name": tc["function"]["name"], - "content": json.dumps({ - "success": False, - "error": f"Tool execution failed: {e}", - }, ensure_ascii=False), - } - for tc in tool_calls_list - ] - - # Truncate oversized tool result content + def _truncate_tool_results(self, tool_results): + """Truncate oversized tool result content in-place and return the list.""" for tr in tool_results: if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH: try: @@ -380,9 +354,45 @@ class ChatService: ensure_ascii=False, default=str, )[:TOOL_RESULT_MAX_LENGTH] - return tool_results + def _execute_tools_safe(self, app, executor, tool_calls_list, context): + """Execute tool calls with top-level error wrapping. + + If an unexpected exception occurs during tool execution, it is + converted into error tool results instead of crashing the stream. + """ + try: + if len(tool_calls_list) > 1: + with app.app_context(): + return self._truncate_tool_results( + executor.process_tool_calls_parallel( + tool_calls_list, context, max_workers=TOOL_MAX_WORKERS + ) + ) + else: + with app.app_context(): + return self._truncate_tool_results( + executor.process_tool_calls( + tool_calls_list, context + ) + ) + except Exception as e: + logger.exception("Error during tool execution") + tool_results = [ + { + "role": "tool", + "tool_call_id": tc["id"], + "name": tc["function"]["name"], + "content": json.dumps({ + "success": False, + "error": f"Tool execution failed: {e}", + }, ensure_ascii=False), + } + for tc in tool_calls_list + ] + return self._truncate_tool_results(tool_results) + def _save_message( self, app, conv_id, conv_model, msg_id, full_content, all_tool_calls, all_tool_results, diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py index 6b436d6..49433b6 100644 --- a/backend/tools/__init__.py +++ b/backend/tools/__init__.py @@ -21,7 +21,7 @@ from backend.tools.executor import ToolExecutor # --------------------------------------------------------------------------- -# Service locator – allows tools (e.g. agent_task) to access LLM client +# Service locator – allows tools (e.g. multi_agent) to access LLM client # --------------------------------------------------------------------------- _services: dict = {} diff --git a/backend/tools/builtin/agent.py b/backend/tools/builtin/agent.py index fef021f..0b3f7a1 100644 --- a/backend/tools/builtin/agent.py +++ b/backend/tools/builtin/agent.py @@ -1,8 +1,7 @@ -"""Multi-agent tools for concurrent and batch task execution. +"""Multi-agent tool for spawning concurrent sub-agents. Provides: -- parallel_execute: Run multiple tool calls concurrently -- agent_task: Spawn sub-agents with their own LLM conversation loops +- multi_agent: Spawn sub-agents with independent LLM conversation loops """ import json from concurrent.futures import ThreadPoolExecutor, as_completed @@ -13,118 +12,36 @@ from backend.tools.core import registry from backend.tools.executor import ToolExecutor -# --------------------------------------------------------------------------- -# parallel_execute – run multiple tool calls concurrently -# --------------------------------------------------------------------------- +def _to_executor_calls(tool_calls: list, id_prefix: str = "tc") -> list: + """Normalize tool calls into executor-compatible format. -@tool( - name="parallel_execute", - description=( - "Execute multiple tool calls concurrently for better performance. " - "Use when you have several independent operations that don't depend on each other " - "(e.g. reading multiple files, running multiple searches, fetching several pages). " - "Results are returned in the same order as the input." - ), - parameters={ - "type": "object", - "properties": { - "tool_calls": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "Tool name to execute", - }, - "arguments": { - "type": "object", - "description": "Arguments for the tool", - }, - }, - "required": ["name", "arguments"], - }, - "description": "List of tool calls to execute in parallel (max 10)", - }, - "concurrency": { - "type": "integer", - "description": "Max concurrent executions (1-5, default 3)", - "default": 3, - }, - }, - "required": ["tool_calls"], - }, - category="agent", -) -def parallel_execute(arguments: dict) -> dict: - """Execute multiple tool calls concurrently. - - Args: - arguments: { - "tool_calls": [ - {"name": "file_read", "arguments": {"path": "a.py"}}, - {"name": "web_search", "arguments": {"query": "python"}} - ], - "concurrency": 3, - "_project_id": "..." // injected by executor - } - - Returns: - {"results": [{index, tool_name, success, data/error}]} + Accepts two input shapes: + - LLM format: {"function": {"name": ..., "arguments": ...}} + - Simple format: {"name": ..., "arguments": ...} """ - tool_calls = arguments["tool_calls"] - concurrency = min(max(arguments.get("concurrency", 3), 1), 5) - - if len(tool_calls) > 10: - return {"success": False, "error": "Maximum 10 tool calls allowed per parallel execution"} - - # Build executor context from injected fields - context = {} - project_id = arguments.get("_project_id") - if project_id: - context["project_id"] = project_id - - # Format tool_calls into executor-compatible format executor_calls = [] for i, tc in enumerate(tool_calls): - executor_calls.append({ - "id": f"pe-{i}", - "type": "function", - "function": { - "name": tc["name"], - "arguments": json.dumps(tc["arguments"], ensure_ascii=False), - }, - }) + if "function" in tc: + func = tc["function"] + executor_calls.append({ + "id": tc.get("id", f"{id_prefix}-{i}"), + "type": tc.get("type", "function"), + "function": { + "name": func["name"], + "arguments": func["arguments"], + }, + }) + else: + executor_calls.append({ + "id": f"{id_prefix}-{i}", + "type": "function", + "function": { + "name": tc["name"], + "arguments": json.dumps(tc["arguments"], ensure_ascii=False), + }, + }) + return executor_calls - # Use ToolExecutor for proper context injection, caching and dedup - executor = ToolExecutor(registry=registry, enable_cache=False) - executor_results = executor.process_tool_calls_parallel( - executor_calls, context, max_workers=concurrency - ) - - # Format output - results = [] - for er in executor_results: - try: - content = json.loads(er["content"]) if isinstance(er["content"], str) else er["content"] - except (json.JSONDecodeError, TypeError): - content = {"success": False, "error": "Failed to parse result"} - results.append({ - "index": len(results), - "tool_name": er["name"], - **content, - }) - - return { - "success": True, - "results": results, - "total": len(results), - } - - -# --------------------------------------------------------------------------- -# agent_task – spawn sub-agents with independent LLM conversation loops -# --------------------------------------------------------------------------- def _run_sub_agent( task_name: str, @@ -160,7 +77,9 @@ def _run_sub_agent( tools = all_tools executor = ToolExecutor(registry=registry) - context = {"project_id": project_id} if project_id else None + context = {"model": model} + if project_id: + context["project_id"] = project_id # System prompt: instruction + reminder to give a final text answer system_msg = ( @@ -170,13 +89,17 @@ def _run_sub_agent( ) messages = [{"role": "system", "content": system_msg}] - for _ in range(max_iterations): + for i in range(max_iterations): + is_final = (i == max_iterations - 1) try: with app.app_context(): resp = llm_client.call( model=model, messages=messages, - tools=tools if tools else None, + # On the last iteration, don't pass tools so the LLM is + # forced to produce a final text response instead of calling + # more tools. + tools=None if is_final else (tools if tools else None), stream=False, max_tokens=min(max_tokens, 4096), temperature=0.7, @@ -196,19 +119,15 @@ def _run_sub_agent( message = choice["message"] if message.get("tool_calls"): - messages.append(message) + # Only extract needed fields — LLM response may contain extra + # fields (e.g. reasoning_content) that the API rejects on re-send + messages.append({ + "role": "assistant", + "content": message.get("content") or "", + "tool_calls": message["tool_calls"], + }) tc_list = message["tool_calls"] - # Convert OpenAI tool_calls to executor format - executor_calls = [] - for tc in tc_list: - executor_calls.append({ - "id": tc.get("id", ""), - "type": tc.get("type", "function"), - "function": { - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }, - }) + executor_calls = _to_executor_calls(tc_list) tool_results = executor.process_tool_calls(executor_calls, context) messages.extend(tool_results) else: @@ -226,7 +145,7 @@ def _run_sub_agent( "error": str(e), } - # Exhausted iterations without final response — return last LLM output if any + # Exhausted iterations without final response return { "task_name": task_name, "success": True, @@ -234,49 +153,49 @@ def _run_sub_agent( } -# @tool( -# name="agent_task", -# description=( -# "Spawn one or more sub-agents to work on tasks concurrently. " -# "Each agent runs its own independent conversation with the LLM and can use tools. " -# "Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. " -# "Each agent is limited to 3 iterations and 4096 tokens to control cost." -# ), -# parameters={ -# "type": "object", -# "properties": { -# "tasks": { -# "type": "array", -# "items": { -# "type": "object", -# "properties": { -# "name": { -# "type": "string", -# "description": "Short name/identifier for this task", -# }, -# "instruction": { -# "type": "string", -# "description": "Detailed instruction for the sub-agent", -# }, -# "tools": { -# "type": "array", -# "items": {"type": "string"}, -# "description": ( -# "Tool names this agent can use (empty = all tools). " -# "e.g. ['file_read', 'file_list', 'web_search']" -# ), -# }, -# }, -# "required": ["name", "instruction"], -# }, -# "description": "Tasks for parallel sub-agents (max 5)", -# }, -# }, -# "required": ["tasks"], -# }, -# category="agent", -# ) -def agent_task(arguments: dict) -> dict: +@tool( + name="multi_agent", + description=( + "Spawn multiple sub-agents to work on tasks concurrently. " + "Each agent runs its own independent conversation with the LLM and can use tools. " + "Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. " + "Each agent is limited to 3 iterations and 4096 tokens to control cost." + ), + parameters={ + "type": "object", + "properties": { + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Short name/identifier for this task", + }, + "instruction": { + "type": "string", + "description": "Detailed instruction for the sub-agent", + }, + "tools": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "Tool names this agent can use (empty = all tools). " + "e.g. ['file_read', 'file_list', 'web_search']" + ), + }, + }, + "required": ["name", "instruction"], + }, + "description": "Tasks for parallel sub-agents (max 5)", + }, + }, + "required": ["tasks"], + }, + category="agent", +) +def multi_agent(arguments: dict) -> dict: """Spawn sub-agents to work on tasks concurrently. Args: @@ -296,7 +215,7 @@ def agent_task(arguments: dict) -> dict: } Returns: - {"success": true, "results": [{task_name, success, response/error}]} + {"success": true, "results": [{task_name, success, response/error}], "total": int} """ from flask import current_app @@ -309,7 +228,8 @@ def agent_task(arguments: dict) -> dict: app = current_app._get_current_object() # Use injected model/project_id from executor context, fall back to defaults - model = arguments.get("_model", "glm-5") + from backend.config import DEFAULT_MODEL + model = arguments.get("_model") or DEFAULT_MODEL project_id = arguments.get("_project_id") # Execute agents concurrently (max 3 at a time) diff --git a/backend/tools/executor.py b/backend/tools/executor.py index 6cdb9b8..926e156 100644 --- a/backend/tools/executor.py +++ b/backend/tools/executor.py @@ -56,21 +56,80 @@ class ToolExecutor: """Inject context fields into tool arguments in-place. - file_* tools: inject project_id - - agent_task: inject model and project_id (prefixed with _ to avoid collisions) - - parallel_execute: inject project_id (prefixed with _ to avoid collisions) + - agent tools (multi_agent): inject _model and _project_id """ if not context: return if name.startswith("file_") and "project_id" in context: args["project_id"] = context["project_id"] - if name == "agent_task": + if name == "multi_agent": if "model" in context: args["_model"] = context["model"] if "project_id" in context: args["_project_id"] = context["project_id"] - if name == "parallel_execute": - if "project_id" in context: - args["_project_id"] = context["project_id"] + + def _prepare_call( + self, + call: dict, + context: Optional[dict], + seen_calls: set, + ) -> tuple: + """Parse, inject context, check dedup/cache for a single tool call. + + Returns a tagged tuple: + ("error", call_id, name, error_msg) + ("cached", call_id, name, result_dict) -- dedup or cache hit + ("execute", call_id, name, args, cache_key) + """ + name = call["function"]["name"] + args_str = call["function"]["arguments"] + call_id = call["id"] + + # Parse JSON arguments + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except json.JSONDecodeError: + return ("error", call_id, name, "Invalid JSON arguments") + + # Inject context + self._inject_context(name, args, context) + + # Dedup within same batch + call_key = f"{name}:{json.dumps(args, sort_keys=True)}" + if call_key in seen_calls: + return ("cached", call_id, name, + {"success": True, "data": None, "cached": True, "duplicate": True}) + seen_calls.add(call_key) + + # History dedup + history_result = self._check_duplicate_in_history(name, args) + if history_result is not None: + return ("cached", call_id, name, {**history_result, "cached": True}) + + # Cache check + cache_key = self._make_cache_key(name, args) + cached_result = self._get_cached(cache_key) + if cached_result is not None: + return ("cached", call_id, name, {**cached_result, "cached": True}) + + return ("execute", call_id, name, args, cache_key) + + def _execute_and_record( + self, + name: str, + args: dict, + cache_key: str, + ) -> dict: + """Execute a tool, cache result, record history, and return raw result dict.""" + result = self._execute_tool(name, args) + if result.get("success"): + self._set_cache(cache_key, result) + self._call_history.append({ + "name": name, + "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), + "result": result, + }) + return result def process_tool_calls_parallel( self, @@ -81,10 +140,6 @@ class ToolExecutor: """ Process tool calls concurrently and return message list (ordered by input). - Identical logic to process_tool_calls but uses ThreadPoolExecutor so that - independent tool calls (e.g. reading 3 files, running 2 searches) execute - in parallel instead of sequentially. - Args: tool_calls: Tool call list returned by LLM context: Optional context info (user_id, project_id, etc.) @@ -98,80 +153,31 @@ class ToolExecutor: max_workers = min(max(max_workers, 1), 6) - # Phase 1: prepare each call (parse args, inject context, check dedup/cache) - # This phase is fast and sequential – it must be done before parallelism - # to avoid race conditions on seen_calls / _call_history / _cache. - prepared: List[Optional[tuple]] = [None] * len(tool_calls) - seen_calls: set = set() + # Phase 1: prepare (sequential – avoids race conditions on shared state) + prepared = [self._prepare_call(call, context, set()) for call in tool_calls] - for i, call in enumerate(tool_calls): - name = call["function"]["name"] - args_str = call["function"]["arguments"] - call_id = call["id"] - - # Parse JSON arguments - try: - args = json.loads(args_str) if isinstance(args_str, str) else args_str - except json.JSONDecodeError: - prepared[i] = self._create_error_result(call_id, name, "Invalid JSON arguments") - continue - - # Inject context into tool arguments - self._inject_context(name, args, context) - - # Dedup within same batch - call_key = f"{name}:{json.dumps(args, sort_keys=True)}" - if call_key in seen_calls: - prepared[i] = self._create_tool_result( - call_id, name, - {"success": True, "data": None, "cached": True, "duplicate": True} - ) - continue - seen_calls.add(call_key) - - # History dedup - history_result = self._check_duplicate_in_history(name, args) - if history_result is not None: - prepared[i] = self._create_tool_result(call_id, name, {**history_result, "cached": True}) - continue - - # Cache check - cache_key = self._make_cache_key(name, args) - cached_result = self._get_cached(cache_key) - if cached_result is not None: - prepared[i] = self._create_tool_result(call_id, name, {**cached_result, "cached": True}) - continue - - # Mark as needing actual execution - prepared[i] = ("execute", call_id, name, args, cache_key) - - # Separate pre-resolved results from tasks needing execution + # Phase 2: separate pre-resolved from tasks needing execution results: List[dict] = [None] * len(tool_calls) - exec_tasks: Dict[int, tuple] = {} # index -> (call_id, name, args, cache_key) + exec_tasks: Dict[int, tuple] = {} for i, item in enumerate(prepared): - if isinstance(item, dict): - results[i] = item - elif isinstance(item, tuple) and item[0] == "execute": + tag = item[0] + if tag == "error": + _, call_id, name, error_msg = item + results[i] = self._create_error_result(call_id, name, error_msg) + elif tag == "cached": + _, call_id, name, result_dict = item + results[i] = self._create_tool_result(call_id, name, result_dict) + else: # "execute" _, call_id, name, args, cache_key = item exec_tasks[i] = (call_id, name, args, cache_key) - # Phase 2: execute remaining calls in parallel + # Phase 3: execute remaining calls in parallel if exec_tasks: def _run(idx: int, call_id: str, name: str, args: dict, cache_key: str) -> tuple: t0 = time.time() - result = self._execute_tool(name, args) + result = self._execute_and_record(name, args, cache_key) elapsed = time.time() - t0 - - if result.get("success"): - self._set_cache(cache_key, result) - - self._call_history.append({ - "name": name, - "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), - "result": result, - }) - return idx, self._create_tool_result(call_id, name, result, execution_time=elapsed) with ThreadPoolExecutor(max_workers=max_workers) as pool: @@ -201,65 +207,22 @@ class ToolExecutor: Tool response message list, can be appended to messages """ results = [] - seen_calls = set() # Track calls within this batch + seen_calls: set = set() for call in tool_calls: - name = call["function"]["name"] - args_str = call["function"]["arguments"] - call_id = call["id"] + prepared = self._prepare_call(call, context, seen_calls) + tag = prepared[0] - try: - args = json.loads(args_str) if isinstance(args_str, str) else args_str - except json.JSONDecodeError: - results.append(self._create_error_result( - call_id, name, "Invalid JSON arguments" - )) - continue - - # Inject context into tool arguments - self._inject_context(name, args, context) - - # Check for duplicate within same batch - call_key = f"{name}:{json.dumps(args, sort_keys=True)}" - if call_key in seen_calls: - # Skip duplicate, but still return a result - results.append(self._create_tool_result( - call_id, name, - {"success": True, "data": None, "cached": True, "duplicate": True} - )) - continue - seen_calls.add(call_key) - - # Check history for previous call in this session - history_result = self._check_duplicate_in_history(name, args) - if history_result is not None: - result = {**history_result, "cached": True} + if tag == "error": + _, call_id, name, error_msg = prepared + results.append(self._create_error_result(call_id, name, error_msg)) + elif tag == "cached": + _, call_id, name, result_dict = prepared + results.append(self._create_tool_result(call_id, name, result_dict)) + else: # "execute" + _, call_id, name, args, cache_key = prepared + result = self._execute_and_record(name, args, cache_key) results.append(self._create_tool_result(call_id, name, result)) - continue - - # Check cache - cache_key = self._make_cache_key(name, args) - cached_result = self._get_cached(cache_key) - if cached_result is not None: - result = {**cached_result, "cached": True} - results.append(self._create_tool_result(call_id, name, result)) - continue - - # Execute tool with retry - result = self._execute_tool(name, args) - - # Cache the result (only cache successful results) - if result.get("success"): - self._set_cache(cache_key, result) - - # Add to history - self._call_history.append({ - "name": name, - "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), - "result": result - }) - - results.append(self._create_tool_result(call_id, name, result)) return results diff --git a/backend/utils/helpers.py b/backend/utils/helpers.py index a9761b4..6691002 100644 --- a/backend/utils/helpers.py +++ b/backend/utils/helpers.py @@ -133,6 +133,13 @@ def build_messages(conv, project_id=None): # Query messages directly to avoid detached instance warning messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all() for m in messages: + # Skip tool messages — they are ephemeral intermediate results, not + # meant to be replayed as conversation history (would violate the API + # protocol that requires tool messages to follow an assistant message + # with matching tool_calls). + if m.role == "tool": + continue + # Build full content from JSON structure full_content = m.content try: diff --git a/docs/Design.md b/docs/Design.md index f2e8ca8..0647108 100644 --- a/docs/Design.md +++ b/docs/Design.md @@ -266,8 +266,8 @@ classDiagram -ToolRegistry registry -dict _cache -list _call_history - +process_tool_calls(calls, context) list - +clear_history() void + +process_tool_calls(list, dict) list + +process_tool_calls_parallel(list, dict, int) list } ChatService --> LLMClient : 使用 @@ -295,18 +295,17 @@ classDiagram +register(ToolDefinition) void +get(str name) ToolDefinition? +list_all() list~dict~ - +list_by_category(str) list~dict~ +execute(str name, dict args) dict - +remove(str name) bool - +has(str name) bool } class ToolExecutor { -ToolRegistry registry + -bool enable_cache + -int cache_ttl -dict _cache -list _call_history +process_tool_calls(list, dict) list - +clear_history() void + +process_tool_calls_parallel(list, dict, int) list } class ToolResult { @@ -394,18 +393,19 @@ def validate_path_in_project(path: str, project_dir: Path) -> Path: 工具执行器自动为文件工具注入 `project_id`: ```python -# backend/tools/executor.py +# backend/tools/executor.py — _inject_context() -def process_tool_calls(self, tool_calls, context=None): - for call in tool_calls: - name = call["function"]["name"] - args = json.loads(call["function"]["arguments"]) - - # 自动注入 project_id - if context and name.startswith("file_") and "project_id" in context: - args["project_id"] = context["project_id"] - - result = self.registry.execute(name, args) +@staticmethod +def _inject_context(name: str, args: dict, context: Optional[dict]) -> None: + # file_* 工具: 注入 project_id + if name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + # agent 工具: 注入 _model 和 _project_id + if name == "multi_agent": + if "model" in context: + args["_model"] = context["model"] + if "project_id" in context: + args["_project_id"] = context["project_id"] ``` --- diff --git a/docs/ToolSystemDesign.md b/docs/ToolSystemDesign.md index b197dc7..1c76288 100644 --- a/docs/ToolSystemDesign.md +++ b/docs/ToolSystemDesign.md @@ -27,19 +27,20 @@ classDiagram +register(ToolDefinition tool) void +get(str name) ToolDefinition? +list_all() list~dict~ - +list_by_category(str category) list~dict~ +execute(str name, dict args) dict - +remove(str name) bool - +has(str name) bool } class ToolExecutor { -ToolRegistry registry + -bool enable_cache + -int cache_ttl -dict _cache -list _call_history +process_tool_calls(list tool_calls, dict context) list~dict~ - +build_request(list messages, str model, list tools, dict kwargs) dict - +clear_history() void + +process_tool_calls_parallel(list tool_calls, dict context, int max_workers) list~dict~ + -_prepare_call(dict call, dict context, set seen_calls) tuple + -_execute_and_record(str name, dict args, str cache_key) dict + -_inject_context(str name, dict args, dict context) void } class ToolResult { @@ -88,32 +89,26 @@ classDiagram ### context 参数 -`process_tool_calls()` 接受 `context` 参数,用于自动注入工具参数: +`process_tool_calls()` / `process_tool_calls_parallel()` 接受 `context` 参数,用于自动注入工具参数: ```python -# backend/tools/executor.py +# backend/tools/executor.py — _inject_context() -def process_tool_calls( - self, - tool_calls: List[dict], - context: Optional[dict] = None -) -> List[dict]: +@staticmethod +def _inject_context(name: str, args: dict, context: Optional[dict]) -> None: """ - Args: - tool_calls: LLM 返回的工具调用列表 - context: 上下文信息,支持: - - project_id: 自动注入到文件工具 + - file_* 工具: 注入 project_id + - agent 工具 (multi_agent): 注入 _model 和 _project_id """ - for call in tool_calls: - name = call["function"]["name"] - args = json.loads(call["function"]["arguments"]) - - # 自动注入 project_id 到文件工具 - if context: - if name.startswith("file_") and "project_id" in context: - args["project_id"] = context["project_id"] - - result = self.registry.execute(name, args) + if not context: + return + if name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + if name == "multi_agent": + if "model" in context: + args["_model"] = context["model"] + if "project_id" in context: + args["_project_id"] = context["project_id"] ``` ### 使用示例 @@ -122,12 +117,12 @@ def process_tool_calls( # backend/services/chat.py def stream_response(self, conv, tools_enabled=True, project_id=None): - # 构建上下文(优先使用请求传递的 project_id,否则回退到对话绑定的) - context = None + # 构建上下文(包含 model 和 project_id) + context = {"model": conv.model} if project_id: - context = {"project_id": project_id} + context["project_id"] = project_id elif conv.project_id: - context = {"project_id": conv.project_id} + context["project_id"] = conv.project_id # 处理工具调用时自动注入 tool_results = self.executor.process_tool_calls(tool_calls, context) @@ -250,6 +245,19 @@ file_read({"path": "src/main.py", "project_id": "xxx"}) |---------|------|------| | `get_weather` | 查询天气信息(模拟) | `city`: 城市名称 | +### 5.6 多智能体工具 (agent) + +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `multi_agent` | 派生子 Agent 并发执行任务(最多 5 个) | `tasks`: 任务数组(name, instruction, tools)
`_model`: 模型名称(自动注入)
`_project_id`: 项目 ID(自动注入) | + +**`multi_agent` 工作原理:** +1. 接收任务数组,每个任务指定 name、instruction 和可选的 tools 列表 +2. 为每个子 Agent 创建独立线程,各自拥有 LLM 对话循环(最多 3 轮迭代,4096 tokens) +3. 通过 Service Locator 获取 `llm_client` 实例 +4. 子 Agent 在 `app.app_context()` 中运行,可独立调用所有注册工具 +5. 返回 `{success, results: [{task_name, success, response/error}], total}` + --- ## 六、核心特性 @@ -285,7 +293,6 @@ def my_tool(arguments: dict) -> dict: - **批次内去重**:同一批次中相同工具+参数的调用会被跳过 - **历史去重**:同一会话内已调用过的工具会直接返回缓存结果 -- **自动清理**:新会话开始时调用 `clear_history()` 清理历史 ### 6.4 无自动重试 @@ -308,13 +315,45 @@ def my_tool(arguments: dict) -> dict: def init_tools() -> None: """初始化所有内置工具""" from backend.tools.builtin import ( - code, crawler, data, weather, file_ops + code, crawler, data, weather, file_ops, agent ) ``` --- -## 八、扩展新工具 +## 八、Service Locator + +工具系统提供 Service Locator 模式,允许工具访问共享服务(如 LLM 客户端): + +```python +# backend/tools/__init__.py + +_services: dict = {} + +def register_service(name: str, service) -> None: + """注册共享服务""" + _services[name] = service + +def get_service(name: str): + """获取已注册的服务,不存在则返回 None""" + return _services.get(name) +``` + +### 使用方式 + +```python +# 在应用初始化时注册(routes/__init__.py) +from backend.tools import register_service +register_service("llm_client", llm_client) + +# 在工具中使用(agent.py) +from backend.tools import get_service +llm_client = get_service("llm_client") +``` + +--- + +## 九、扩展新工具 ### 添加新工具