refactor: 重构multi agent 参数设置
This commit is contained in:
parent
3970c0b9a0
commit
cc639a979a
|
|
@ -39,8 +39,6 @@ max_iterations: 15
|
||||||
# Sub-agent settings (multi_agent tool)
|
# Sub-agent settings (multi_agent tool)
|
||||||
sub_agent:
|
sub_agent:
|
||||||
max_iterations: 3 # Max tool-call rounds per sub-agent
|
max_iterations: 3 # Max tool-call rounds per sub-agent
|
||||||
max_tokens: 4096 # Max tokens per LLM call inside a sub-agent
|
|
||||||
max_agents: 5 # Max number of concurrent sub-agents per request
|
|
||||||
max_concurrency: 3 # ThreadPoolExecutor max workers
|
max_concurrency: 3 # ThreadPoolExecutor max workers
|
||||||
|
|
||||||
# Available models
|
# Available models
|
||||||
|
|
|
||||||
|
|
@ -37,15 +37,11 @@ MAX_ITERATIONS = _cfg.get("max_iterations", 5)
|
||||||
# Max parallel workers for tool execution (ThreadPoolExecutor)
|
# Max parallel workers for tool execution (ThreadPoolExecutor)
|
||||||
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
|
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
|
||||||
|
|
||||||
# Max character length for a single tool result content (truncated if exceeded)
|
|
||||||
TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096)
|
|
||||||
|
|
||||||
# Sub-agent settings (multi_agent tool)
|
# Sub-agent settings (multi_agent tool)
|
||||||
_sa = _cfg.get("sub_agent", {})
|
_sa = _cfg.get("sub_agent", {})
|
||||||
SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
|
SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
|
||||||
SUB_AGENT_MAX_TOKENS = _sa.get("max_tokens", 4096)
|
|
||||||
SUB_AGENT_MAX_AGENTS = _sa.get("max_agents", 5)
|
|
||||||
SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
|
SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
|
||||||
|
SUB_AGENT_TIMEOUT = _sa.get("timeout", 60)
|
||||||
|
|
||||||
# Code execution settings
|
# Code execution settings
|
||||||
_ce = _cfg.get("code_execution", {})
|
_ce = _cfg.get("code_execution", {})
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from backend.utils.helpers import (
|
||||||
build_messages,
|
build_messages,
|
||||||
)
|
)
|
||||||
from backend.services.llm_client import LLMClient
|
from backend.services.llm_client import LLMClient
|
||||||
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH
|
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -70,7 +70,11 @@ class ChatService:
|
||||||
|
|
||||||
executor = ToolExecutor(registry=registry)
|
executor = ToolExecutor(registry=registry)
|
||||||
|
|
||||||
context = {"model": conv_model}
|
context = {
|
||||||
|
"model": conv_model,
|
||||||
|
"max_tokens": conv_max_tokens,
|
||||||
|
"temperature": conv_temperature,
|
||||||
|
}
|
||||||
if project_id:
|
if project_id:
|
||||||
context["project_id"] = project_id
|
context["project_id"] = project_id
|
||||||
elif conv.project_id:
|
elif conv.project_id:
|
||||||
|
|
@ -332,30 +336,6 @@ class ChatService:
|
||||||
sse_chunks,
|
sse_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _truncate_tool_results(self, tool_results):
|
|
||||||
"""Truncate oversized tool result content in-place and return the list."""
|
|
||||||
for tr in tool_results:
|
|
||||||
if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH:
|
|
||||||
try:
|
|
||||||
result_data = json.loads(tr["content"])
|
|
||||||
original = result_data
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
original = None
|
|
||||||
|
|
||||||
tr["content"] = json.dumps(
|
|
||||||
{"success": False, "error": "Tool result too large, truncated"},
|
|
||||||
ensure_ascii=False,
|
|
||||||
) if not original else json.dumps(
|
|
||||||
{
|
|
||||||
**original,
|
|
||||||
"truncated": True,
|
|
||||||
"_note": f"Content truncated, original length {len(tr['content'])} chars",
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
default=str,
|
|
||||||
)[:TOOL_RESULT_MAX_LENGTH]
|
|
||||||
return tool_results
|
|
||||||
|
|
||||||
def _execute_tools_safe(self, app, executor, tool_calls_list, context):
|
def _execute_tools_safe(self, app, executor, tool_calls_list, context):
|
||||||
"""Execute tool calls with top-level error wrapping.
|
"""Execute tool calls with top-level error wrapping.
|
||||||
|
|
||||||
|
|
@ -365,21 +345,17 @@ class ChatService:
|
||||||
try:
|
try:
|
||||||
if len(tool_calls_list) > 1:
|
if len(tool_calls_list) > 1:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
return self._truncate_tool_results(
|
return executor.process_tool_calls_parallel(
|
||||||
executor.process_tool_calls_parallel(
|
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
|
||||||
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
return self._truncate_tool_results(
|
return executor.process_tool_calls(
|
||||||
executor.process_tool_calls(
|
tool_calls_list, context
|
||||||
tool_calls_list, context
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error during tool execution")
|
logger.exception("Error during tool execution")
|
||||||
tool_results = [
|
return [
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tc["id"],
|
"tool_call_id": tc["id"],
|
||||||
|
|
@ -391,7 +367,6 @@ class ChatService:
|
||||||
}
|
}
|
||||||
for tc in tool_calls_list
|
for tc in tool_calls_list
|
||||||
]
|
]
|
||||||
return self._truncate_tool_results(tool_results)
|
|
||||||
|
|
||||||
def _save_message(
|
def _save_message(
|
||||||
self, app, conv_id, conv_model, msg_id,
|
self, app, conv_id, conv_model, msg_id,
|
||||||
|
|
|
||||||
|
|
@ -7,16 +7,15 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
from backend.tools import get_service
|
||||||
from backend.tools.factory import tool
|
from backend.tools.factory import tool
|
||||||
from backend.tools.core import registry
|
from backend.tools.core import registry
|
||||||
from backend.tools.executor import ToolExecutor
|
from backend.tools.executor import ToolExecutor
|
||||||
from backend.config import (
|
from backend.config import (
|
||||||
DEFAULT_MODEL,
|
DEFAULT_MODEL,
|
||||||
SUB_AGENT_MAX_ITERATIONS,
|
SUB_AGENT_MAX_ITERATIONS,
|
||||||
SUB_AGENT_MAX_TOKENS,
|
|
||||||
SUB_AGENT_MAX_AGENTS,
|
|
||||||
SUB_AGENT_MAX_CONCURRENCY,
|
SUB_AGENT_MAX_CONCURRENCY,
|
||||||
|
SUB_AGENT_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -62,6 +61,7 @@ def _run_sub_agent(
|
||||||
tool_names: Optional[List[str]],
|
tool_names: Optional[List[str]],
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
project_id: Optional[str],
|
project_id: Optional[str],
|
||||||
app: Any,
|
app: Any,
|
||||||
max_iterations: int = 3,
|
max_iterations: int = 3,
|
||||||
|
|
@ -71,7 +71,6 @@ def _run_sub_agent(
|
||||||
Each sub-agent gets its own ToolExecutor instance and runs a simplified
|
Each sub-agent gets its own ToolExecutor instance and runs a simplified
|
||||||
version of the main agent loop, limited to prevent runaway cost.
|
version of the main agent loop, limited to prevent runaway cost.
|
||||||
"""
|
"""
|
||||||
from backend.tools import get_service
|
|
||||||
|
|
||||||
llm_client = get_service("llm_client")
|
llm_client = get_service("llm_client")
|
||||||
if not llm_client:
|
if not llm_client:
|
||||||
|
|
@ -117,9 +116,9 @@ def _run_sub_agent(
|
||||||
# more tools.
|
# more tools.
|
||||||
tools=None if is_final else (tools if tools else None),
|
tools=None if is_final else (tools if tools else None),
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=min(max_tokens, 4096),
|
max_tokens=max_tokens,
|
||||||
temperature=0.7,
|
temperature=temperature,
|
||||||
timeout=60,
|
timeout=SUB_AGENT_TIMEOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
|
|
@ -247,8 +246,8 @@ def multi_agent(arguments: dict) -> dict:
|
||||||
|
|
||||||
tasks = arguments["tasks"]
|
tasks = arguments["tasks"]
|
||||||
|
|
||||||
if len(tasks) > SUB_AGENT_MAX_AGENTS:
|
if len(tasks) > 5:
|
||||||
return {"success": False, "error": f"Maximum {SUB_AGENT_MAX_AGENTS} concurrent agents allowed"}
|
return {"success": False, "error": "Maximum 5 concurrent agents allowed"}
|
||||||
|
|
||||||
# Get current conversation context for model/project info
|
# Get current conversation context for model/project info
|
||||||
app = current_app._get_current_object()
|
app = current_app._get_current_object()
|
||||||
|
|
@ -256,6 +255,8 @@ def multi_agent(arguments: dict) -> dict:
|
||||||
# Use injected model/project_id from executor context, fall back to defaults
|
# Use injected model/project_id from executor context, fall back to defaults
|
||||||
model = arguments.get("_model") or DEFAULT_MODEL
|
model = arguments.get("_model") or DEFAULT_MODEL
|
||||||
project_id = arguments.get("_project_id")
|
project_id = arguments.get("_project_id")
|
||||||
|
max_tokens = arguments.get("_max_tokens", 65536)
|
||||||
|
temperature = arguments.get("_temperature", 0.7)
|
||||||
|
|
||||||
# Execute agents concurrently
|
# Execute agents concurrently
|
||||||
concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
|
concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
|
||||||
|
|
@ -269,7 +270,8 @@ def multi_agent(arguments: dict) -> dict:
|
||||||
task["instruction"],
|
task["instruction"],
|
||||||
task.get("tools"),
|
task.get("tools"),
|
||||||
model,
|
model,
|
||||||
SUB_AGENT_MAX_TOKENS,
|
max_tokens,
|
||||||
|
temperature,
|
||||||
project_id,
|
project_id,
|
||||||
app,
|
app,
|
||||||
SUB_AGENT_MAX_ITERATIONS,
|
SUB_AGENT_MAX_ITERATIONS,
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,10 @@ class ToolExecutor:
|
||||||
args["_model"] = context["model"]
|
args["_model"] = context["model"]
|
||||||
if "project_id" in context:
|
if "project_id" in context:
|
||||||
args["_project_id"] = context["project_id"]
|
args["_project_id"] = context["project_id"]
|
||||||
|
if "max_tokens" in context:
|
||||||
|
args["_max_tokens"] = context["max_tokens"]
|
||||||
|
if "temperature" in context:
|
||||||
|
args["_temperature"] = context["temperature"]
|
||||||
|
|
||||||
def _prepare_call(
|
def _prepare_call(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1022,10 +1022,9 @@ frontend_port: 4000
|
||||||
max_iterations: 15
|
max_iterations: 15
|
||||||
|
|
||||||
# 子代理资源配置(multi_agent 工具)
|
# 子代理资源配置(multi_agent 工具)
|
||||||
|
# max_tokens 和 temperature 与主 Agent 共用,无需单独配置
|
||||||
sub_agent:
|
sub_agent:
|
||||||
max_iterations: 3 # 每个子代理的最大工具调用轮数
|
max_iterations: 3 # 每个子代理的最大工具调用轮数
|
||||||
max_tokens: 4096 # 每次调用的最大 token 数
|
|
||||||
max_agents: 5 # 每次请求最多派生的子代理数
|
|
||||||
max_concurrency: 3 # 并发线程数
|
max_concurrency: 3 # 并发线程数
|
||||||
|
|
||||||
# 可用模型列表(每个模型必须指定 api_url 和 api_key)
|
# 可用模型列表(每个模型必须指定 api_url 和 api_key)
|
||||||
|
|
|
||||||
|
|
@ -319,10 +319,11 @@ execute_python({
|
||||||
| 配置项 | 默认值 | 说明 |
|
| 配置项 | 默认值 | 说明 |
|
||||||
|--------|--------|------|
|
|--------|--------|------|
|
||||||
| `max_iterations` | 3 | 每个子代理的最大工具调用轮数 |
|
| `max_iterations` | 3 | 每个子代理的最大工具调用轮数 |
|
||||||
| `max_tokens` | 4096 | 每次调用的最大 token 数 |
|
|
||||||
| `max_agents` | 5 | 每次请求最多派生的子代理数 |
|
|
||||||
| `max_concurrency` | 3 | ThreadPoolExecutor 并发线程数 |
|
| `max_concurrency` | 3 | ThreadPoolExecutor 并发线程数 |
|
||||||
|
|
||||||
|
> - `max_tokens` 和 `temperature` 与主 Agent 共用,从对话配置中获取,无需单独配置。
|
||||||
|
> - 子代理禁止调用 `multi_agent` 工具,防止无限递归。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 六、核心特性
|
## 六、核心特性
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue