From 7bd19a7529d46219715a0d2d6e77b5493648aede Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 27 Mar 2026 16:35:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=20=E5=88=9D=E6=AD=A5=E6=90=AD=E5=BB=BA?= =?UTF-8?q?=20multi-agent=20=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/routes/__init__.py | 4 + backend/services/chat.py | 29 +-- backend/tools/__init__.py | 21 +- backend/tools/builtin/__init__.py | 1 + backend/tools/builtin/agent.py | 348 ++++++++++++++++++++++++++++++ backend/tools/executor.py | 139 +++++++++++- backend/tools/services.py | 28 ++- 7 files changed, 548 insertions(+), 22 deletions(-) create mode 100644 backend/tools/builtin/agent.py diff --git a/backend/routes/__init__.py b/backend/routes/__init__.py index ad0700e..5996c11 100644 --- a/backend/routes/__init__.py +++ b/backend/routes/__init__.py @@ -17,6 +17,10 @@ def register_routes(app: Flask): client = LLMClient(MODEL_CONFIG) init_chat_service(client) + # Register LLM client in service locator so tools (e.g. agent_task) can access it + from backend.tools import register_service + register_service("llm_client", client) + # Initialize authentication system (reads auth_mode from config.yml) init_auth(app) diff --git a/backend/services/chat.py b/backend/services/chat.py index 78d1ff1..c3cd311 100644 --- a/backend/services/chat.py +++ b/backend/services/chat.py @@ -56,11 +56,11 @@ class ChatService: executor = ToolExecutor(registry=registry) # Build context for tool execution - context = None + context = {"model": conv_model} if project_id: - context = {"project_id": project_id} + context["project_id"] = project_id elif conv.project_id: - context = {"project_id": conv.project_id} + context["project_id"] = conv.project_id def generate(): messages = list(initial_messages) @@ -178,10 +178,8 @@ class ChatService: if tool_calls_list: all_tool_calls.extend(tool_calls_list) - # Execute each tool call, emit tool_call + tool_result as paired steps - tool_results = [] + # Phase 1: emit all tool_call steps (before execution) for tc in tool_calls_list: - # Emit tool_call step (before execution) call_step = { 'id': f'step-{step_index}', 'index': step_index, @@ -194,17 +192,24 @@ class ChatService: yield f"event: process_step\ndata: {json.dumps(call_step, ensure_ascii=False)}\n\n" step_index += 1 - # Execute the tool + # Phase 2: execute tools — parallel when multiple, sequential when single + if len(tool_calls_list) > 1: with app.app_context(): - single_result = executor.process_tool_calls([tc], context) - tool_results.extend(single_result) + tool_results = executor.process_tool_calls_parallel( + tool_calls_list, context, max_workers=4 + ) + else: + with app.app_context(): + tool_results = executor.process_tool_calls( + tool_calls_list, context + ) - # Emit tool_result step (after execution) - tr = single_result[0] + # Phase 3: emit all tool_result steps (after execution, same order) + for tr in tool_results: try: result_content = json.loads(tr["content"]) skipped = result_content.get("skipped", False) - except: + except Exception: skipped = False result_step = { 'id': f'step-{step_index}', diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py index 15ac229..1c3d449 100644 --- a/backend/tools/__init__.py +++ b/backend/tools/__init__.py @@ -20,13 +20,29 @@ from backend.tools.factory import tool, register_tool from backend.tools.executor import ToolExecutor +# --------------------------------------------------------------------------- +# Service locator – allows tools (e.g. agent_task) to access LLM client +# --------------------------------------------------------------------------- +_services: dict = {} + + +def register_service(name: str, service) -> None: + """Register a shared service (e.g. LLM client) for tool access.""" + _services[name] = service + + +def get_service(name: str): + """Retrieve a previously registered service, or None.""" + return _services.get(name) + + def init_tools() -> None: """ Initialize all built-in tools Importing builtin module automatically registers all decorator-defined tools """ - from backend.tools.builtin import code, crawler, data, weather, file_ops # noqa: F401 + from backend.tools.builtin import code, crawler, data, weather, file_ops, agent # noqa: F401 # Public API exports @@ -43,4 +59,7 @@ __all__ = [ "register_tool", # Initialization "init_tools", + # Service locator + "register_service", + "get_service", ] diff --git a/backend/tools/builtin/__init__.py b/backend/tools/builtin/__init__.py index f97685d..20da6af 100644 --- a/backend/tools/builtin/__init__.py +++ b/backend/tools/builtin/__init__.py @@ -4,3 +4,4 @@ from backend.tools.builtin.crawler import * from backend.tools.builtin.data import * from backend.tools.builtin.file_ops import * from backend.tools.builtin.weather import * +from backend.tools.builtin.agent import * diff --git a/backend/tools/builtin/agent.py b/backend/tools/builtin/agent.py new file mode 100644 index 0000000..fef021f --- /dev/null +++ b/backend/tools/builtin/agent.py @@ -0,0 +1,348 @@ +"""Multi-agent tools for concurrent and batch task execution. + +Provides: +- parallel_execute: Run multiple tool calls concurrently +- agent_task: Spawn sub-agents with their own LLM conversation loops +""" +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Dict, Any, Optional + +from backend.tools.factory import tool +from backend.tools.core import registry +from backend.tools.executor import ToolExecutor + + +# --------------------------------------------------------------------------- +# parallel_execute – run multiple tool calls concurrently +# --------------------------------------------------------------------------- + +@tool( + name="parallel_execute", + description=( + "Execute multiple tool calls concurrently for better performance. " + "Use when you have several independent operations that don't depend on each other " + "(e.g. reading multiple files, running multiple searches, fetching several pages). " + "Results are returned in the same order as the input." + ), + parameters={ + "type": "object", + "properties": { + "tool_calls": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Tool name to execute", + }, + "arguments": { + "type": "object", + "description": "Arguments for the tool", + }, + }, + "required": ["name", "arguments"], + }, + "description": "List of tool calls to execute in parallel (max 10)", + }, + "concurrency": { + "type": "integer", + "description": "Max concurrent executions (1-5, default 3)", + "default": 3, + }, + }, + "required": ["tool_calls"], + }, + category="agent", +) +def parallel_execute(arguments: dict) -> dict: + """Execute multiple tool calls concurrently. + + Args: + arguments: { + "tool_calls": [ + {"name": "file_read", "arguments": {"path": "a.py"}}, + {"name": "web_search", "arguments": {"query": "python"}} + ], + "concurrency": 3, + "_project_id": "..." // injected by executor + } + + Returns: + {"results": [{index, tool_name, success, data/error}]} + """ + tool_calls = arguments["tool_calls"] + concurrency = min(max(arguments.get("concurrency", 3), 1), 5) + + if len(tool_calls) > 10: + return {"success": False, "error": "Maximum 10 tool calls allowed per parallel execution"} + + # Build executor context from injected fields + context = {} + project_id = arguments.get("_project_id") + if project_id: + context["project_id"] = project_id + + # Format tool_calls into executor-compatible format + executor_calls = [] + for i, tc in enumerate(tool_calls): + executor_calls.append({ + "id": f"pe-{i}", + "type": "function", + "function": { + "name": tc["name"], + "arguments": json.dumps(tc["arguments"], ensure_ascii=False), + }, + }) + + # Use ToolExecutor for proper context injection, caching and dedup + executor = ToolExecutor(registry=registry, enable_cache=False) + executor_results = executor.process_tool_calls_parallel( + executor_calls, context, max_workers=concurrency + ) + + # Format output + results = [] + for er in executor_results: + try: + content = json.loads(er["content"]) if isinstance(er["content"], str) else er["content"] + except (json.JSONDecodeError, TypeError): + content = {"success": False, "error": "Failed to parse result"} + results.append({ + "index": len(results), + "tool_name": er["name"], + **content, + }) + + return { + "success": True, + "results": results, + "total": len(results), + } + + +# --------------------------------------------------------------------------- +# agent_task – spawn sub-agents with independent LLM conversation loops +# --------------------------------------------------------------------------- + +def _run_sub_agent( + task_name: str, + instruction: str, + tool_names: Optional[List[str]], + model: str, + max_tokens: int, + project_id: Optional[str], + app: Any, + max_iterations: int = 3, +) -> dict: + """Run a single sub-agent with its own agentic loop. + + Each sub-agent gets its own ToolExecutor instance and runs a simplified + version of the main agent loop, limited to prevent runaway cost. + """ + from backend.tools import get_service + + llm_client = get_service("llm_client") + if not llm_client: + return { + "task_name": task_name, + "success": False, + "error": "LLM client not available", + } + + # Build tool list – filter to requested tools or use all + all_tools = registry.list_all() + if tool_names: + allowed = set(tool_names) + tools = [t for t in all_tools if t["function"]["name"] in allowed] + else: + tools = all_tools + + executor = ToolExecutor(registry=registry) + context = {"project_id": project_id} if project_id else None + + # System prompt: instruction + reminder to give a final text answer + system_msg = ( + f"{instruction}\n\n" + "IMPORTANT: After gathering information via tools, you MUST provide a final " + "text response with your analysis/answer. Do NOT end with only tool calls." + ) + messages = [{"role": "system", "content": system_msg}] + + for _ in range(max_iterations): + try: + with app.app_context(): + resp = llm_client.call( + model=model, + messages=messages, + tools=tools if tools else None, + stream=False, + max_tokens=min(max_tokens, 4096), + temperature=0.7, + timeout=60, + ) + + if resp.status_code != 200: + error_detail = resp.text[:500] if resp.text else f"HTTP {resp.status_code}" + return { + "task_name": task_name, + "success": False, + "error": f"LLM API error: {error_detail}", + } + + data = resp.json() + choice = data["choices"][0] + message = choice["message"] + + if message.get("tool_calls"): + messages.append(message) + tc_list = message["tool_calls"] + # Convert OpenAI tool_calls to executor format + executor_calls = [] + for tc in tc_list: + executor_calls.append({ + "id": tc.get("id", ""), + "type": tc.get("type", "function"), + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + }, + }) + tool_results = executor.process_tool_calls(executor_calls, context) + messages.extend(tool_results) + else: + # Final text response + return { + "task_name": task_name, + "success": True, + "response": message.get("content", ""), + } + + except Exception as e: + return { + "task_name": task_name, + "success": False, + "error": str(e), + } + + # Exhausted iterations without final response — return last LLM output if any + return { + "task_name": task_name, + "success": True, + "response": "Agent task completed but did not produce a final text response within the iteration limit.", + } + + +# @tool( +# name="agent_task", +# description=( +# "Spawn one or more sub-agents to work on tasks concurrently. " +# "Each agent runs its own independent conversation with the LLM and can use tools. " +# "Useful for parallel research, multi-file analysis, or dividing complex tasks into sub-tasks. " +# "Each agent is limited to 3 iterations and 4096 tokens to control cost." +# ), +# parameters={ +# "type": "object", +# "properties": { +# "tasks": { +# "type": "array", +# "items": { +# "type": "object", +# "properties": { +# "name": { +# "type": "string", +# "description": "Short name/identifier for this task", +# }, +# "instruction": { +# "type": "string", +# "description": "Detailed instruction for the sub-agent", +# }, +# "tools": { +# "type": "array", +# "items": {"type": "string"}, +# "description": ( +# "Tool names this agent can use (empty = all tools). " +# "e.g. ['file_read', 'file_list', 'web_search']" +# ), +# }, +# }, +# "required": ["name", "instruction"], +# }, +# "description": "Tasks for parallel sub-agents (max 5)", +# }, +# }, +# "required": ["tasks"], +# }, +# category="agent", +# ) +def agent_task(arguments: dict) -> dict: + """Spawn sub-agents to work on tasks concurrently. + + Args: + arguments: { + "tasks": [ + { + "name": "research", + "instruction": "Research Python async patterns...", + "tools": ["web_search", "fetch_page"] + }, + { + "name": "code_review", + "instruction": "Review code quality...", + "tools": ["file_read", "file_list"] + } + ] + } + + Returns: + {"success": true, "results": [{task_name, success, response/error}]} + """ + from flask import current_app + + tasks = arguments["tasks"] + + if len(tasks) > 5: + return {"success": False, "error": "Maximum 5 concurrent agents allowed"} + + # Get current conversation context for model/project info + app = current_app._get_current_object() + + # Use injected model/project_id from executor context, fall back to defaults + model = arguments.get("_model", "glm-5") + project_id = arguments.get("_project_id") + + # Execute agents concurrently (max 3 at a time) + concurrency = min(len(tasks), 3) + results = [None] * len(tasks) + + with ThreadPoolExecutor(max_workers=concurrency) as pool: + futures = { + pool.submit( + _run_sub_agent, + task["name"], + task["instruction"], + task.get("tools"), + model, + 4096, + project_id, + app, + ): i + for i, task in enumerate(tasks) + } + for future in as_completed(futures): + idx = futures[future] + try: + results[idx] = future.result() + except Exception as e: + results[idx] = { + "task_name": tasks[idx]["name"], + "success": False, + "error": str(e), + } + + return { + "success": True, + "results": results, + "total": len(results), + } diff --git a/backend/tools/executor.py b/backend/tools/executor.py index 9112ed1..c6b4a3d 100644 --- a/backend/tools/executor.py +++ b/backend/tools/executor.py @@ -2,6 +2,7 @@ import json import time import hashlib +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Dict, Optional, Any from backend.tools.core import ToolRegistry, registry @@ -54,6 +55,140 @@ class ToolExecutor: """Clear call history (call this at start of new conversation turn)""" self._call_history.clear() + @staticmethod + def _inject_context(name: str, args: dict, context: Optional[dict]) -> None: + """Inject context fields into tool arguments in-place. + + - file_* tools: inject project_id + - agent_task: inject model and project_id (prefixed with _ to avoid collisions) + - parallel_execute: inject project_id (prefixed with _ to avoid collisions) + """ + if not context: + return + if name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + if name == "agent_task": + if "model" in context: + args["_model"] = context["model"] + if "project_id" in context: + args["_project_id"] = context["project_id"] + if name == "parallel_execute": + if "project_id" in context: + args["_project_id"] = context["project_id"] + + def process_tool_calls_parallel( + self, + tool_calls: List[dict], + context: Optional[dict] = None, + max_workers: int = 4, + ) -> List[dict]: + """ + Process tool calls concurrently and return message list (ordered by input). + + Identical logic to process_tool_calls but uses ThreadPoolExecutor so that + independent tool calls (e.g. reading 3 files, running 2 searches) execute + in parallel instead of sequentially. + + Args: + tool_calls: Tool call list returned by LLM + context: Optional context info (user_id, project_id, etc.) + max_workers: Maximum concurrent threads (1-6, default 4) + + Returns: + Tool response message list in the same order as input tool_calls. + """ + if len(tool_calls) <= 1: + return self.process_tool_calls(tool_calls, context) + + max_workers = min(max(max_workers, 1), 6) + + # Phase 1: prepare each call (parse args, inject context, check dedup/cache) + # This phase is fast and sequential – it must be done before parallelism + # to avoid race conditions on seen_calls / _call_history / _cache. + prepared: List[Optional[tuple]] = [None] * len(tool_calls) + seen_calls: set = set() + + for i, call in enumerate(tool_calls): + name = call["function"]["name"] + args_str = call["function"]["arguments"] + call_id = call["id"] + + # Parse JSON arguments + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except json.JSONDecodeError: + prepared[i] = self._create_error_result(call_id, name, "Invalid JSON arguments") + continue + + # Inject context into tool arguments + self._inject_context(name, args, context) + + # Dedup within same batch + call_key = f"{name}:{json.dumps(args, sort_keys=True)}" + if call_key in seen_calls: + prepared[i] = self._create_tool_result( + call_id, name, + {"success": True, "data": None, "cached": True, "duplicate": True} + ) + continue + seen_calls.add(call_key) + + # History dedup + history_result = self._check_duplicate_in_history(name, args) + if history_result is not None: + prepared[i] = self._create_tool_result(call_id, name, {**history_result, "cached": True}) + continue + + # Cache check + cache_key = self._make_cache_key(name, args) + cached_result = self._get_cached(cache_key) + if cached_result is not None: + prepared[i] = self._create_tool_result(call_id, name, {**cached_result, "cached": True}) + continue + + # Mark as needing actual execution + prepared[i] = ("execute", call_id, name, args, cache_key) + + # Separate pre-resolved results from tasks needing execution + results: List[dict] = [None] * len(tool_calls) + exec_tasks: Dict[int, tuple] = {} # index -> (call_id, name, args, cache_key) + + for i, item in enumerate(prepared): + if isinstance(item, dict): + results[i] = item + elif isinstance(item, tuple) and item[0] == "execute": + _, call_id, name, args, cache_key = item + exec_tasks[i] = (call_id, name, args, cache_key) + + # Phase 2: execute remaining calls in parallel + if exec_tasks: + def _run(idx: int, call_id: str, name: str, args: dict, cache_key: str) -> tuple: + t0 = time.time() + result = self._execute_tool(name, args) + elapsed = time.time() - t0 + + if result.get("success"): + self._set_cache(cache_key, result) + + self._call_history.append({ + "name": name, + "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), + "result": result, + }) + + return idx, self._create_tool_result(call_id, name, result, execution_time=elapsed) + + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = { + pool.submit(_run, idx, cid, n, a, ck): idx + for idx, (cid, n, a, ck) in exec_tasks.items() + } + for future in as_completed(futures): + idx, result_msg = future.result() + results[idx] = result_msg + + return results + def process_tool_calls( self, tool_calls: List[dict], @@ -86,9 +221,7 @@ class ToolExecutor: continue # Inject context into tool arguments - if context: - if name.startswith("file_") and "project_id" in context: - args["project_id"] = context["project_id"] + self._inject_context(name, args, context) # Check for duplicate within same batch call_key = f"{name}:{json.dumps(args, sort_keys=True)}" diff --git a/backend/tools/services.py b/backend/tools/services.py index 7a4cecd..c069a12 100644 --- a/backend/tools/services.py +++ b/backend/tools/services.py @@ -1,5 +1,6 @@ """Tool helper services""" from typing import List +from concurrent.futures import ThreadPoolExecutor, as_completed from ddgs import DDGS import re @@ -119,19 +120,34 @@ class FetchService: max_concurrent: int = 5 ) -> List[dict]: """ - Batch fetch pages + Batch fetch pages concurrently. Args: urls: URL list extract_type: Extract type - max_concurrent: Max concurrent requests + max_concurrent: Max concurrent requests (1-5, default 5) Returns: - Result list + Result list (same order as input URLs) """ - results = [] - for url in urls: - results.append(self.fetch(url, extract_type)) + if len(urls) <= 1: + return [self.fetch(url, extract_type) for url in urls] + + max_concurrent = min(max(max_concurrent, 1), 5) + results = [None] * len(urls) + + with ThreadPoolExecutor(max_workers=max_concurrent) as pool: + futures = { + pool.submit(self.fetch, url, extract_type): i + for i, url in enumerate(urls) + } + for future in as_completed(futures): + idx = futures[future] + try: + results[idx] = future.result() + except Exception as e: + results[idx] = {"error": str(e)} + return results