From 8a23b1cd007dec3914fa11f728bc73e770367567 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 25 Mar 2026 00:22:32 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E8=AE=B0=E5=BD=95=E8=BF=81=E7=A7=BB=E8=87=B3=E7=8B=AC?= =?UTF-8?q?=E7=AB=8B=E8=A1=A8=E5=B9=B6=E6=9B=B4=E6=96=B0=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/models.py | 53 ++- backend/services/chat.py | 74 +++- backend/utils/helpers.py | 9 +- docs/Design.md | 80 +++- docs/ToolSystemDesign.md | 771 ++++++++++++++++++++++++++++++++------- 5 files changed, 794 insertions(+), 193 deletions(-) diff --git a/backend/models.py b/backend/models.py index 3712993..f5ddb9f 100644 --- a/backend/models.py +++ b/backend/models.py @@ -8,7 +8,7 @@ class User(db.Model): id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) username = db.Column(db.String(50), unique=True, nullable=False) - password = db.Column(db.String(255), nullable=False) + password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login phone = db.Column(db.String(20)) conversations = db.relationship("Conversation", backref="user", lazy="dynamic", @@ -20,7 +20,7 @@ class Conversation(db.Model): __tablename__ = "conversations" id = db.Column(db.String(64), primary_key=True) - user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False) + user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False, index=True) title = db.Column(db.String(255), nullable=False, default="") model = db.Column(db.String(64), nullable=False, default="glm-5") system_prompt = db.Column(db.Text, default="") @@ -28,7 +28,8 @@ class Conversation(db.Model): max_tokens = db.Column(db.Integer, nullable=False, default=65536) thinking_enabled = db.Column(db.Boolean, default=False) created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc)) - updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), index=True) messages = db.relationship("Message", backref="conversation", lazy="dynamic", cascade="all, delete-orphan", @@ -39,27 +40,48 @@ class Message(db.Model): __tablename__ = "messages" id = db.Column(db.String(64), primary_key=True) - conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), nullable=False) + conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), + nullable=False, index=True) role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool - content = db.Column(db.Text, default="") + content = db.Column(LONGTEXT, default="") # LONGTEXT for long conversations token_count = db.Column(db.Integer, default=0) - thinking_content = db.Column(db.Text, default="") - - # Tool call support - tool_calls = db.Column(LONGTEXT) # JSON string: tool call requests (assistant messages) - tool_call_id = db.Column(db.String(64)) # Tool call ID (tool messages) - name = db.Column(db.String(64)) # Tool name (tool messages) - + thinking_content = db.Column(LONGTEXT, default="") + created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True) + + # Tool call support - relation to ToolCall table + tool_calls = db.relationship("ToolCall", backref="message", lazy="dynamic", + cascade="all, delete-orphan", + order_by="ToolCall.call_index.asc()") + + +class ToolCall(db.Model): + """Tool call record - separate table, follows database normalization""" + __tablename__ = "tool_calls" + + id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + message_id = db.Column(db.String(64), db.ForeignKey("messages.id"), + nullable=False, index=True) + call_id = db.Column(db.String(64), nullable=False) # Tool call ID + call_index = db.Column(db.Integer, nullable=False, default=0) # Call order + tool_name = db.Column(db.String(64), nullable=False) # Tool name + arguments = db.Column(LONGTEXT, nullable=False) # Call arguments JSON + result = db.Column(LONGTEXT) # Execution result JSON + execution_time = db.Column(db.Float, default=0) # Execution time (seconds) created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc)) + __table_args__ = ( + db.Index("ix_tool_calls_message_call", "message_id", "call_index"), + ) + class TokenUsage(db.Model): __tablename__ = "token_usage" id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False) - date = db.Column(db.Date, nullable=False) # 使用日期 - model = db.Column(db.String(64), nullable=False) # 模型名称 + user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), + nullable=False, index=True) + date = db.Column(db.Date, nullable=False, index=True) + model = db.Column(db.String(64), nullable=False) prompt_tokens = db.Column(db.Integer, default=0) completion_tokens = db.Column(db.Integer, default=0) total_tokens = db.Column(db.Integer, default=0) @@ -67,4 +89,5 @@ class TokenUsage(db.Model): __table_args__ = ( db.UniqueConstraint("user_id", "date", "model", name="uq_user_date_model"), + db.Index("ix_token_usage_date_model", "date", "model"), # Composite index ) diff --git a/backend/services/chat.py b/backend/services/chat.py index 193ef6d..f9e7df2 100644 --- a/backend/services/chat.py +++ b/backend/services/chat.py @@ -3,7 +3,7 @@ import json import uuid from flask import current_app, Response from backend import db -from backend.models import Conversation, Message +from backend.models import Conversation, Message, ToolCall from backend.tools import registry, ToolExecutor from backend.utils.helpers import ( get_or_create_default_user, @@ -60,8 +60,7 @@ class ChatService: prompt_tokens = usage.get("prompt_tokens", 0) completion_tokens = usage.get("completion_tokens", 0) - merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results) - + # Create message msg = Message( id=str(uuid.uuid4()), conversation_id=conv.id, @@ -69,16 +68,18 @@ class ChatService: content=message.get("content", ""), token_count=completion_tokens, thinking_content=message.get("reasoning_content", ""), - tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None ) db.session.add(msg) + + # Create tool call records + self._save_tool_calls(msg.id, all_tool_calls, all_tool_results) db.session.commit() user = get_or_create_default_user() record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens) return ok({ - "message": to_dict(msg, thinking_content=msg.thinking_content or None), + "message": self._message_to_dict(msg), "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, @@ -194,8 +195,6 @@ class ChatService: continue # No tool calls - finish - merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results) - with app.app_context(): msg = Message( id=msg_id, @@ -204,9 +203,11 @@ class ChatService: content=full_content, token_count=token_count, thinking_content=full_thinking, - tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None ) db.session.add(msg) + + # Create tool call records + self._save_tool_calls(msg_id, all_tool_calls, all_tool_results) db.session.commit() user = get_or_create_default_user() @@ -223,6 +224,53 @@ class ChatService: headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"} ) + def _save_tool_calls(self, message_id: str, tool_calls: list, tool_results: list) -> None: + """Save tool calls to database""" + for i, tc in enumerate(tool_calls): + result_content = tool_results[i]["content"] if i < len(tool_results) else None + + # Parse result to extract execution_time if present + execution_time = 0 + if result_content: + try: + result_data = json.loads(result_content) + execution_time = result_data.get("execution_time", 0) + except: + pass + + tool_call = ToolCall( + message_id=message_id, + call_id=tc.get("id", ""), + call_index=i, + tool_name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + result=result_content, + execution_time=execution_time, + ) + db.session.add(tool_call) + + def _message_to_dict(self, msg: Message) -> dict: + """Convert message to dict with tool calls""" + result = to_dict(msg, thinking_content=msg.thinking_content or None) + + # Add tool calls if any + tool_calls = msg.tool_calls.all() if msg.tool_calls else [] + if tool_calls: + result["tool_calls"] = [ + { + "id": tc.call_id, + "type": "function", + "function": { + "name": tc.tool_name, + "arguments": tc.arguments, + }, + "result": tc.result, + } + for tc in tool_calls + ] + + return result + def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list: """Process tool calls from streaming delta""" tool_calls_delta = delta.get("tool_calls", []) @@ -242,13 +290,3 @@ class ChatService: if tc["function"].get("arguments"): tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"] return tool_calls_list - - def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list: - """Merge tool results into tool calls""" - merged = [] - for i, tc in enumerate(tool_calls): - merged_tc = dict(tc) - if i < len(tool_results): - merged_tc["result"] = tool_results[i]["content"] - merged.append(merged_tc) - return merged diff --git a/backend/utils/helpers.py b/backend/utils/helpers.py index 06034b0..970de26 100644 --- a/backend/utils/helpers.py +++ b/backend/utils/helpers.py @@ -9,7 +9,7 @@ def get_or_create_default_user(): """Get or create default user""" user = User.query.filter_by(username="default").first() if not user: - user = User(username="default", password="") + user = User(username="default", password=None) db.session.add(user) db.session.commit() return user @@ -39,13 +39,6 @@ def to_dict(inst, **extra): if k in d and hasattr(d[k], "strftime"): d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ") - # Parse tool_calls JSON if present - if "tool_calls" in d and d["tool_calls"]: - try: - d["tool_calls"] = json.loads(d["tool_calls"]) - except: - pass - # Filter out None values for cleaner API response d = {k: v for k, v in d.items() if v is not None} diff --git a/docs/Design.md b/docs/Design.md index 9e41208..9512680 100644 --- a/docs/Design.md +++ b/docs/Design.md @@ -526,15 +526,24 @@ GET /api/stats/tokens?period=daily ### ER 关系 ``` -User 1 ── * Conversation 1 ── * Message +User 1 ── * Conversation 1 ── * Message 1 ── * ToolCall ``` +### User(用户) + +| 字段 | 类型 | 说明 | +| ------------ | --------------- | ----------------------- | +| `id` | bigint | 用户 ID(自增) | +| `username` | string(50) | 用户名(唯一) | +| `password` | string(255) | 密码(可为空,第三方登录) | +| `phone` | string(20) | 手机号 | + ### Conversation(会话) | 字段 | 类型 | 说明 | | ------------------ | ------------- | --------------------- | | `id` | string (UUID) | 会话 ID | -| `user_id` | string | 所属用户 ID | +| `user_id` | bigint | 所属用户 ID | | `title` | string | 会话标题 | | `model` | string | 使用的模型,默认 `glm-5` | | `system_prompt` | string | 系统提示词 | @@ -550,21 +559,51 @@ User 1 ── * Conversation 1 ── * Message | ------------------ | ------------- | ------------------------------- | | `id` | string (UUID) | 消息 ID | | `conversation_id` | string | 所属会话 ID | -| `role` | enum | `user` / `assistant` / `system` | -| `content` | string | 消息内容 | +| `role` | enum | `user` / `assistant` / `system` / `tool` | +| `content` | LONGTEXT | 消息内容 | | `token_count` | integer | token 消耗数 | -| `thinking_content` | string | 思维链内容(启用时) | -| `tool_calls` | array (JSON) | 工具调用信息(含结果),仅 assistant 消息 | +| `thinking_content` | LONGTEXT | 思维链内容(启用时) | | `created_at` | datetime | 创建时间 | +**说明**:工具调用信息存储在关联的 `ToolCall` 表中,通过 `message.tool_calls` 关系获取。 + +### ToolCall(工具调用) + +| 字段 | 类型 | 说明 | +| ----------------- | --------------- | --------------------------- | +| `id` | bigint | 调用记录 ID(自增) | +| `message_id` | string(64) | 关联的消息 ID | +| `call_id` | string(64) | 工具调用 ID | +| `call_index` | integer | 调用顺序(从 0 开始) | +| `tool_name` | string(64) | 工具名称 | +| `arguments` | LONGTEXT | 调用参数 JSON | +| `result` | LONGTEXT | 执行结果 JSON | +| `execution_time` | float | 执行时间(秒) | +| `created_at` | datetime | 创建时间 | + +### TokenUsage(Token 使用统计) + +| 字段 | 类型 | 说明 | +| ------------------- | ---------- | -------------------------- | +| `id` | bigint | 记录 ID(自增) | +| `user_id` | bigint | 用户 ID | +| `date` | date | 统计日期 | +| `model` | string(64) | 模型名称 | +| `prompt_tokens` | integer | 输入 token 数 | +| `completion_tokens` | integer | 输出 token 数 | +| `total_tokens` | integer | 总 token 数 | +| `created_at` | datetime | 创建时间 | + #### 消息类型说明 **1. 用户消息 (role=user)** ```json { "id": "msg_001", + "conversation_id": "conv_abc123", "role": "user", "content": "北京今天天气怎么样?", + "token_count": 0, "created_at": "2026-03-24T10:00:00Z" } ``` @@ -573,22 +612,23 @@ User 1 ── * Conversation 1 ── * Message ```json { "id": "msg_002", + "conversation_id": "conv_abc123", "role": "assistant", "content": "北京今天天气晴朗...", "token_count": 50, "thinking_content": "用户想了解天气...", - "tool_calls": null, "created_at": "2026-03-24T10:00:01Z" } ``` **3. 助手消息 - 含工具调用 (role=assistant, with tool_calls)** -工具调用结果直接合并到 `tool_calls` 数组中,每个调用包含 `result` 字段: +工具调用记录存储在独立的 `tool_calls` 表中,API 响应时会自动关联并返回: ```json { "id": "msg_003", + "conversation_id": "conv_abc123", "role": "assistant", "content": "北京今天天气晴朗,温度25°C,湿度60%", "token_count": 80, @@ -608,6 +648,18 @@ User 1 ── * Conversation 1 ── * Message } ``` +**4. 工具消息 (role=tool)** + +用于 API 调用时传递工具执行结果(不存储在数据库): +```json +{ + "role": "tool", + "tool_call_id": "call_abc123", + "name": "get_weather", + "content": "{\"temperature\": 25, \"humidity\": 60}" +} +``` + #### 工具调用流程示例 ``` @@ -617,14 +669,16 @@ User 1 ── * Conversation 1 ── * Message ↓ [AI 调用工具 get_weather] ↓ -[msg_002] role=assistant, tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}] - content="北京今天天气晴朗,温度25°C..." +[msg_002] role=assistant, content="北京今天天气晴朗,温度25°C..." + tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}] ``` **说明:** -- 工具调用结果直接存储在 `tool_calls[].result` 字段中 -- 不再创建独立的 `role=tool` 消息 -- 前端可通过 `tool_calls` 数组展示完整的工具调用过程 +- 工具调用记录存储在独立的 `tool_calls` 表中,与 `messages` 表通过 `message_id` 关联 +- API 响应时自动查询并组装 `tool_calls` 数组 +- 工具调用包含完整的调用参数和执行结果 +- `call_index` 字段记录同一消息中多次工具调用的顺序 +- `execution_time` 字段记录工具执行耗时 --- diff --git a/docs/ToolSystemDesign.md b/docs/ToolSystemDesign.md index 766838f..5cd5d8f 100644 --- a/docs/ToolSystemDesign.md +++ b/docs/ToolSystemDesign.md @@ -2,7 +2,7 @@ ## 概述 -本文档描述 NanoClaw 工具调用系统的设计,采用简化的工厂模式,减少不必要的类层次。 +NanoClaw 工具调用系统采用简化的工厂模式,支持装饰器注册、缓存优化、重复调用检测等功能。 --- @@ -27,13 +27,20 @@ classDiagram +register(ToolDefinition tool) void +get(str name) ToolDefinition? +list_all() list~dict~ - +execute(str name, dict args) Any + +list_by_category(str category) list~dict~ + +execute(str name, dict args) dict + +remove(str name) bool + +has(str name) bool } class ToolExecutor { -ToolRegistry registry - +process_tool_calls(list tool_calls) list~dict~ - +build_request(list messages) dict + -dict _cache + -list _call_history + +process_tool_calls(list tool_calls, dict context) list~dict~ + +build_request(list messages, str model, list tools, dict kwargs) dict + +clear_history() void + +execute_with_retry(str name, dict args, int max_retries) dict } class ToolResult { @@ -42,6 +49,8 @@ classDiagram +Any data +str? error +dict to_dict() + +ok(Any data)$ ToolResult + +fail(str error)$ ToolResult } ToolRegistry "1" --> "*" ToolDefinition : manages @@ -61,11 +70,8 @@ classDiagram class ToolFactory { <> - +tool(name, description, parameters)$ decorator - +register(name, handler, description, parameters)$ void - +create_crawler_tools()$ list~ToolDefinition~ - +create_data_tools()$ list~ToolDefinition~ - +create_file_tools()$ list~ToolDefinition~ + +tool(name, description, parameters, category)$ decorator + +register_tool(name, handler, description, parameters, category)$ void } class ToolDefinition { @@ -73,6 +79,7 @@ classDiagram +str description +dict parameters +Callable handler + +str category } ToolFactory ..> ToolDefinition : creates @@ -153,18 +160,31 @@ class ToolRegistry: return cls._instance def register(self, tool: ToolDefinition) -> None: + """注册工具""" self._tools[tool.name] = tool def get(self, name: str) -> Optional[ToolDefinition]: + """获取工具定义""" return self._tools.get(name) def list_all(self) -> List[dict]: + """列出所有工具(OpenAI 格式)""" return [t.to_openai_format() for t in self._tools.values()] + def list_by_category(self, category: str) -> List[dict]: + """按类别列出工具""" + return [ + t.to_openai_format() + for t in self._tools.values() + if t.category == category + ] + def execute(self, name: str, arguments: dict) -> dict: + """执行工具""" tool = self.get(name) if not tool: return ToolResult.fail(f"Tool not found: {name}").to_dict() + try: result = tool.handler(arguments) if isinstance(result, ToolResult): @@ -173,6 +193,17 @@ class ToolRegistry: except Exception as e: return ToolResult.fail(str(e)).to_dict() + def remove(self, name: str) -> bool: + """移除工具""" + if name in self._tools: + del self._tools[name] + return True + return False + + def has(self, name: str) -> bool: + """检查工具是否存在""" + return name in self._tools + # 全局注册表 registry = ToolRegistry() @@ -182,39 +213,205 @@ registry = ToolRegistry() ```python import json -from typing import List, Dict +import time +import hashlib +from typing import List, Dict, Optional class ToolExecutor: - """工具执行器""" + """工具执行器(支持缓存和重复检测)""" - def __init__(self, registry: ToolRegistry = None): + def __init__( + self, + registry: ToolRegistry = None, + api_url: str = None, + api_key: str = None, + enable_cache: bool = True, + cache_ttl: int = 300, # 5分钟 + ): self.registry = registry or ToolRegistry() + self.api_url = api_url + self.api_key = api_key + self.enable_cache = enable_cache + self.cache_ttl = cache_ttl + self._cache: Dict[str, tuple] = {} # key -> (result, timestamp) + self._call_history: List[dict] = [] # 当前会话的调用历史 - def process_tool_calls(self, tool_calls: List[dict]) -> List[dict]: - """处理工具调用,返回消息列表""" + def _make_cache_key(self, name: str, args: dict) -> str: + """生成缓存键""" + args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) + return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest() + + def _get_cached(self, key: str) -> Optional[dict]: + """获取缓存结果""" + if not self.enable_cache: + return None + if key in self._cache: + result, timestamp = self._cache[key] + if time.time() - timestamp < self.cache_ttl: + return result + del self._cache[key] + return None + + def _set_cache(self, key: str, result: dict) -> None: + """设置缓存""" + if self.enable_cache: + self._cache[key] = (result, time.time()) + + def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]: + """检查历史中是否有相同调用""" + args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) + for record in self._call_history: + if record["name"] == name and record["args_str"] == args_str: + return record["result"] + return None + + def clear_history(self) -> None: + """清空调用历史(新会话开始时调用)""" + self._call_history.clear() + + def process_tool_calls( + self, + tool_calls: List[dict], + context: dict = None + ) -> List[dict]: + """ + 处理工具调用,返回消息列表 + + Args: + tool_calls: LLM 返回的工具调用列表 + context: 可选上下文信息(user_id 等) + + Returns: + 工具响应消息列表,可直接追加到 messages + """ results = [] + seen_calls = set() # 当前批次内的重复检测 + for call in tool_calls: name = call["function"]["name"] - args = json.loads(call["function"]["arguments"]) + args_str = call["function"]["arguments"] call_id = call["id"] + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except json.JSONDecodeError: + results.append(self._create_error_result( + call_id, name, "Invalid JSON arguments" + )) + continue + + # 检查批次内重复 + call_key = f"{name}:{json.dumps(args, sort_keys=True)}" + if call_key in seen_calls: + results.append(self._create_tool_result( + call_id, name, + {"success": True, "data": None, "cached": True, "duplicate": True} + )) + continue + seen_calls.add(call_key) + + # 检查历史重复 + history_result = self._check_duplicate_in_history(name, args) + if history_result is not None: + result = {**history_result, "cached": True} + results.append(self._create_tool_result(call_id, name, result)) + continue + + # 检查缓存 + cache_key = self._make_cache_key(name, args) + cached_result = self._get_cached(cache_key) + if cached_result is not None: + result = {**cached_result, "cached": True} + results.append(self._create_tool_result(call_id, name, result)) + continue + + # 执行工具 result = self.registry.execute(name, args) - results.append({ - "role": "tool", - "tool_call_id": call_id, + # 缓存结果 + self._set_cache(cache_key, result) + + # 添加到历史 + self._call_history.append({ "name": name, - "content": json.dumps(result, ensure_ascii=False) + "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), + "result": result }) + + results.append(self._create_tool_result(call_id, name, result)) + return results - def build_request(self, messages: List[dict], **kwargs) -> dict: - """构建 API 请求""" + def _create_tool_result( + self, + call_id: str, + name: str, + result: dict, + execution_time: float = 0 + ) -> dict: + """创建工具结果消息""" + result["execution_time"] = execution_time return { - "model": kwargs.get("model", "glm-5"), + "role": "tool", + "tool_call_id": call_id, + "name": name, + "content": json.dumps(result, ensure_ascii=False, default=str) + } + + def _create_error_result( + self, + call_id: str, + name: str, + error: str + ) -> dict: + """创建错误结果消息""" + return { + "role": "tool", + "tool_call_id": call_id, + "name": name, + "content": json.dumps({ + "success": False, + "error": error + }, ensure_ascii=False) + } + + def build_request( + self, + messages: List[dict], + model: str = "glm-5", + tools: List[dict] = None, + **kwargs + ) -> dict: + """构建 API 请求体""" + return { + "model": model, "messages": messages, - "tools": self.registry.list_all(), - "tool_choice": "auto" + "tools": tools or self.registry.list_all(), + "tool_choice": kwargs.get("tool_choice", "auto"), + **{k: v for k, v in kwargs.items() if k not in ["tool_choice"]} + } + + def execute_with_retry( + self, + name: str, + arguments: dict, + max_retries: int = 3, + retry_delay: float = 1.0 + ) -> dict: + """带重试的工具执行""" + last_error = None + + for attempt in range(max_retries): + try: + return self.registry.execute(name, arguments) + except Exception as e: + last_error = e + if attempt < max_retries - 1: + time.sleep(retry_delay) + + return { + "success": False, + "error": f"Failed after {max_retries} retries: {last_error}" } ``` @@ -227,11 +424,30 @@ class ToolExecutor: ```python # backend/tools/factory.py -from .core import ToolDefinition, registry +from typing import Callable +from backend.tools.core import ToolDefinition, registry -def tool(name: str, description: str, parameters: dict, category: str = "general"): - """工具注册装饰器""" - def decorator(func): + +def tool( + name: str, + description: str, + parameters: dict, + category: str = "general" +) -> Callable: + """ + 工具注册装饰器 + + 用法: + @tool( + name="web_search", + description="搜索互联网获取信息", + parameters={"type": "object", "properties": {...}}, + category="crawler" + ) + def web_search(arguments: dict) -> dict: + ... + """ + def decorator(func: Callable) -> Callable: tool_def = ToolDefinition( name=name, description=description, @@ -242,6 +458,34 @@ def tool(name: str, description: str, parameters: dict, category: str = "general registry.register(tool_def) return func return decorator + + +def register_tool( + name: str, + handler: Callable, + description: str, + parameters: dict, + category: str = "general" +) -> None: + """ + 直接注册工具(无需装饰器) + + 用法: + register_tool( + name="my_tool", + handler=my_function, + description="工具描述", + parameters={...} + ) + """ + tool_def = ToolDefinition( + name=name, + description=description, + parameters=parameters, + handler=handler, + category=category + ) + registry.register(tool_def) ``` ### 4.2 使用示例 @@ -249,24 +493,32 @@ def tool(name: str, description: str, parameters: dict, category: str = "general ```python # backend/tools/builtin/crawler.py -from ..factory import tool +from backend.tools.factory import tool +from backend.tools.services import SearchService, FetchService # 网页搜索工具 @tool( name="web_search", - description="搜索互联网获取信息", + description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.", parameters={ "type": "object", "properties": { - "query": {"type": "string", "description": "搜索关键词"}, - "max_results": {"type": "integer", "default": 5} + "query": { + "type": "string", + "description": "Search keywords" + }, + "max_results": { + "type": "integer", + "description": "Number of results to return, default 5", + "default": 5 + } }, "required": ["query"] }, category="crawler" ) def web_search(arguments: dict) -> dict: - from ..services import SearchService + """Web search tool""" query = arguments["query"] max_results = arguments.get("max_results", 5) service = SearchService() @@ -277,19 +529,27 @@ def web_search(arguments: dict) -> dict: # 页面抓取工具 @tool( name="fetch_page", - description="抓取指定网页内容", + description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.", parameters={ "type": "object", "properties": { - "url": {"type": "string", "description": "网页URL"}, - "extract_type": {"type": "string", "enum": ["text", "links", "structured"]} + "url": { + "type": "string", + "description": "URL of the webpage to fetch" + }, + "extract_type": { + "type": "string", + "description": "Extraction type", + "enum": ["text", "links", "structured"], + "default": "text" + } }, "required": ["url"] }, category="crawler" ) def fetch_page(arguments: dict) -> dict: - from ..services import FetchService + """Page fetch tool""" url = arguments["url"] extract_type = arguments.get("extract_type", "text") service = FetchService() @@ -297,33 +557,39 @@ def fetch_page(arguments: dict) -> dict: return result -# 计算器工具 +# 批量抓取工具 @tool( - name="calculator", - description="执行数学计算", + name="crawl_batch", + description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.", parameters={ "type": "object", "properties": { - "expression": {"type": "string", "description": "数学表达式"} + "urls": { + "type": "array", + "items": {"type": "string"}, + "description": "List of URLs to fetch" + }, + "extract_type": { + "type": "string", + "enum": ["text", "links", "structured"], + "default": "text" + } }, - "required": ["expression"] + "required": ["urls"] }, - category="data" + category="crawler" ) -def calculator(arguments: dict) -> dict: - import ast - import operator - expr = arguments["expression"] - # 安全计算 - ops = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv - } - node = ast.parse(expr, mode='eval') - result = eval(compile(node, '', 'eval'), {"__builtins__": {}}, ops) - return {"result": result} +def crawl_batch(arguments: dict) -> dict: + """Batch fetch tool""" + urls = arguments["urls"] + extract_type = arguments.get("extract_type", "text") + + if len(urls) > 10: + return {"error": "Maximum 10 pages can be fetched at once"} + + service = FetchService() + results = service.fetch_batch(urls, extract_type) + return {"results": results, "total": len(results)} ``` --- @@ -332,90 +598,220 @@ def calculator(arguments: dict) -> dict: 工具依赖的服务保持独立,不与工具类耦合: -```mermaid -classDiagram - direction LR - - class SearchService { - -SearchEngine engine - +search(str query, int limit) list~dict~ - } - - class FetchService { - +fetch(str url, str type) dict - +fetch_batch(list urls) dict - } - - class ContentExtractor { - +extract_text(html) str - +extract_links(html) list - +extract_structured(html) dict - } - - FetchService --> ContentExtractor : uses -``` - ```python # backend/tools/services.py +from typing import List, Dict +from ddgs import DDGS +import re + + class SearchService: """搜索服务""" - def __init__(self, engine=None): - from ddgs import DDGS - self.engine = engine or DDGS() - def search(self, query: str, max_results: int = 5) -> list: - results = list(self.engine.text(query, max_results=max_results)) + def __init__(self, engine: str = "duckduckgo"): + self.engine = engine + + def search( + self, + query: str, + max_results: int = 5, + region: str = "cn-zh" + ) -> List[dict]: + """执行搜索""" + if self.engine == "duckduckgo": + return self._search_duckduckgo(query, max_results, region) + else: + raise ValueError(f"Unsupported search engine: {self.engine}") + + def _search_duckduckgo( + self, + query: str, + max_results: int, + region: str + ) -> List[dict]: + """DuckDuckGo 搜索""" + with DDGS() as ddgs: + results = list(ddgs.text( + query, + max_results=max_results, + region=region + )) + return [ - {"title": r["title"], "url": r["href"], "snippet": r["body"]} + { + "title": r.get("title", ""), + "url": r.get("href", ""), + "snippet": r.get("body", "") + } for r in results ] class FetchService: """页面抓取服务""" - def __init__(self, timeout: float = 30.0): + + def __init__(self, timeout: float = 30.0, user_agent: str = None): self.timeout = timeout + self.user_agent = user_agent or ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" + ) def fetch(self, url: str, extract_type: str = "text") -> dict: + """抓取单个页面""" import httpx - from bs4 import BeautifulSoup - resp = httpx.get(url, timeout=self.timeout, follow_redirects=True) - soup = BeautifulSoup(resp.text, "html.parser") + try: + resp = httpx.get( + url, + timeout=self.timeout, + follow_redirects=True, + headers={"User-Agent": self.user_agent} + ) + resp.raise_for_status() + except Exception as e: + return {"error": str(e), "url": url} + + html = resp.text + extractor = ContentExtractor(html) - extractor = ContentExtractor(soup) if extract_type == "text": - return {"text": extractor.extract_text()} + return { + "url": url, + "text": extractor.extract_text() + } elif extract_type == "links": - return {"links": extractor.extract_links()} + return { + "url": url, + "links": extractor.extract_links() + } else: - return extractor.extract_structured() + return extractor.extract_structured(url) + + def fetch_batch( + self, + urls: List[str], + extract_type: str = "text", + max_concurrent: int = 5 + ) -> List[dict]: + """批量抓取页面""" + results = [] + for url in urls: + results.append(self.fetch(url, extract_type)) + return results class ContentExtractor: """内容提取器""" - def __init__(self, soup): - self.soup = soup + + def __init__(self, html: str): + self.html = html + self._soup = None + + @property + def soup(self): + if self._soup is None: + try: + from bs4 import BeautifulSoup + self._soup = BeautifulSoup(self.html, "html.parser") + except ImportError: + raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4") + return self._soup def extract_text(self) -> str: + """提取纯文本""" # 移除脚本和样式 - for tag in self.soup(["script", "style"]): + for tag in self.soup(["script", "style", "nav", "footer", "header"]): tag.decompose() - return self.soup.get_text(separator="\n", strip=True) - def extract_links(self) -> list: - return [ - {"text": a.get_text(strip=True), "href": a.get("href")} - for a in self.soup.find_all("a", href=True) - ] + text = self.soup.get_text(separator="\n", strip=True) + # 清理多余空白 + text = re.sub(r"\n{3,}", "\n\n", text) + return text + + def extract_links(self) -> List[dict]: + """提取链接""" + links = [] + for a in self.soup.find_all("a", href=True): + text = a.get_text(strip=True) + href = a["href"] + if text and href and not href.startswith(("#", "javascript:")): + links.append({"text": text, "href": href}) + return links[:50] # 限制数量 + + def extract_structured(self, url: str = "") -> dict: + """提取结构化内容""" + soup = self.soup + + # 提取标题 + title = "" + if soup.title: + title = soup.title.string or "" + + # 提取 meta 描述 + description = "" + meta_desc = soup.find("meta", attrs={"name": "description"}) + if meta_desc: + description = meta_desc.get("content", "") - def extract_structured(self) -> dict: return { - "title": self.soup.title.string if self.soup.title else "", - "text": self.extract_text(), + "url": url, + "title": title.strip(), + "description": description.strip(), + "text": self.extract_text()[:5000], # 限制长度 "links": self.extract_links()[:20] } + + +class CalculatorService: + """安全计算服务""" + + ALLOWED_OPS = { + "add", "sub", "mul", "truediv", "floordiv", + "mod", "pow", "neg", "abs" + } + + def evaluate(self, expression: str) -> dict: + """安全计算数学表达式""" + import ast + import operator + + ops = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + try: + # 解析表达式 + node = ast.parse(expression, mode="eval") + + # 验证节点类型 + for child in ast.walk(node): + if isinstance(child, ast.Call): + return {"error": "Function calls not allowed"} + if isinstance(child, ast.Name): + return {"error": "Variable names not allowed"} + + # 安全执行 + result = eval( + compile(node, "", "eval"), + {"__builtins__": {}}, + {} + ) + + return {"result": result} + + except Exception as e: + return {"error": f"Calculation error: {str(e)}"} ``` --- @@ -425,57 +821,154 @@ class ContentExtractor: ```python # backend/tools/__init__.py -from .core import ToolDefinition, ToolResult, ToolRegistry, registry, ToolExecutor -from .factory import tool +""" +NanoClaw Tool System -def init_tools(): - """初始化所有内置工具""" - # 导入即自动注册 - from .builtin import crawler, data, weather +Usage: + from backend.tools import registry, ToolExecutor, tool + from backend.tools import init_tools -# 使用时 -init_tools() + # 初始化内置工具 + init_tools() + + # 列出所有工具 + tools = registry.list_all() + + # 执行工具 + result = registry.execute("web_search", {"query": "Python"}) +""" + +from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry +from backend.tools.factory import tool, register_tool +from backend.tools.executor import ToolExecutor + + +def init_tools() -> None: + """ + 初始化所有内置工具 + + 导入 builtin 模块会自动注册所有装饰器定义的工具 + """ + from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401 + + +# 公开 API 导出 +__all__ = [ + # 核心类 + "ToolDefinition", + "ToolResult", + "ToolRegistry", + "ToolExecutor", + # 实例 + "registry", + # 工厂函数 + "tool", + "register_tool", + # 初始化 + "init_tools", +] ``` --- ## 七、工具清单 -| 类别 | 工具名称 | 描述 | 依赖服务 | -| ------- | --------------- | ---- | ------------- | -| crawler | `web_search` | 网页搜索 | SearchService | -| crawler | `fetch_page` | 单页抓取 | FetchService | -| crawler | `crawl_batch` | 批量爬取 | FetchService | -| data | `calculator` | 数学计算 | CalculatorService | -| data | `text_process` | 文本处理 | - | -| data | `json_process` | JSON处理 | - | -| weather | `get_weather` | 天气查询 | - (模拟数据) | -| file | `file_read` | 读取文件 | - | -| file | `file_write` | 写入文件 | - | -| file | `file_delete` | 删除文件 | - | -| file | `file_list` | 列出目录 | - | -| file | `file_exists` | 检查存在 | - | -| file | `file_mkdir` | 创建目录 | - | +### 7.1 爬虫工具 (crawler) + +| 工具名称 | 描述 | 参数 | +| --------------- | --------------------------- | --------------------------------------- | +| `web_search` | 搜索互联网获取信息 | `query`: 搜索关键词
`max_results`: 结果数量(默认 5) | +| `fetch_page` | 抓取单个网页内容 | `url`: 网页 URL
`extract_type`: 提取类型(text/links/structured) | +| `crawl_batch` | 批量抓取多个网页(最多 10 个) | `urls`: URL 列表
`extract_type`: 提取类型 | + +### 7.2 数据处理工具 (data) + +| 工具名称 | 描述 | 参数 | +| --------------- | --------------------------- | --------------------------------------- | +| `calculator` | 执行数学计算(支持加减乘除、幂、模等) | `expression`: 数学表达式 | +| `text_process` | 文本处理(计数、格式转换等) | `text`: 文本内容
`operation`: 操作类型(count/lines/words/upper/lower/reverse) | +| `json_process` | JSON 处理(解析、格式化、提取、验证) | `json_string`: JSON 字符串
`operation`: 操作类型(parse/format/keys/validate) | + +### 7.3 天气工具 (weather) + +| 工具名称 | 描述 | 参数 | +| --------------- | --------------------------- | --------------------------------------- | +| `get_weather` | 查询指定城市的天气信息(模拟数据) | `city`: 城市名称(如:北京、上海、广州) | + +### 7.4 文件操作工具 (file) + +| 工具名称 | 描述 | 参数 | +| --------------- | --------------------------- | --------------------------------------- | +| `file_read` | 读取文件内容 | `path`: 文件路径
`encoding`: 编码(默认 utf-8) | +| `file_write` | 写入文件(支持覆盖和追加) | `path`: 文件路径
`content`: 内容
`mode`: 写入模式(write/append) | +| `file_delete` | 删除文件 | `path`: 文件路径 | +| `file_list` | 列出目录内容 | `path`: 目录路径(默认 .)
`pattern`: 文件模式(默认 *) | +| `file_exists` | 检查文件或目录是否存在 | `path`: 路径 | +| `file_mkdir` | 创建目录(自动创建父目录) | `path`: 目录路径 | + +**安全说明**:文件操作工具限制在项目根目录内,防止越权访问。 --- ## 八、与旧设计对比 -| 方面 | 旧设计 | 新设计 | -| ----- | ----------------- | --------- | -| 类数量 | 30+ | ~10 | -| 工具定义 | 继承 BaseTool | 装饰器 + 函数 | -| 中间抽象层 | 5个(CrawlerTool 等) | 无 | -| 扩展方式 | 创建子类 | 写函数 + 装饰器 | -| 代码量 | 多 | 少 | +| 方面 | 旧设计 | 新设计 | +| --------- | ----------------- | ----------------- | +| 类数量 | 30+ | ~10 | +| 工具定义 | 继承 BaseTool | 装饰器 + 函数 | +| 中间抽象层 | 5个(CrawlerTool 等) | 无 | +| 扩展方式 | 创建子类 | 写函数 + 装饰器 | +| 缓存机制 | 无 | 支持结果缓存(TTL 可配置) | +| 重复检测 | 无 | 支持会话内重复调用检测 | +| 代码量 | 多 | 少 | --- -## 九、总结 +## 九、核心特性 -简化后的设计: +### 9.1 装饰器注册 + +简化工具定义,只需一个装饰器: + +```python +@tool( + name="my_tool", + description="工具描述", + parameters={...}, + category="custom" +) +def my_tool(arguments: dict) -> dict: + # 工具实现 + return {"result": "ok"} +``` + +### 9.2 智能缓存 + +- **结果缓存**:相同参数的工具调用结果会被缓存(默认 5 分钟) +- **可配置 TTL**:通过 `cache_ttl` 参数设置缓存过期时间 +- **可禁用**:通过 `enable_cache=False` 关闭缓存 + +### 9.3 重复检测 + +- **批次内去重**:同一批次中相同工具+参数的调用会被跳过 +- **历史去重**:同一会话内已调用过的工具会直接返回缓存结果 +- **自动清理**:新会话开始时调用 `clear_history()` 清理历史 + +### 9.4 安全设计 + +- **计算器安全**:禁止函数调用和变量名,只支持数学运算 +- **文件沙箱**:文件操作限制在项目根目录内,防止越权访问 +- **错误处理**:所有工具执行都有 try-catch,不会因工具错误导致系统崩溃 + +--- + +## 十、总结 + +简化后的设计特点: 1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult` 2. **工厂模式**:使用 `@tool` 装饰器注册工具 3. **服务分离**:工具依赖的服务独立,不与工具类耦合 -4. **易于扩展**:新增工具只需写一个函数并加装饰器 +4. **性能优化**:支持缓存和重复检测,减少重复计算和网络请求 +5. **易于扩展**:新增工具只需写一个函数并加装饰器 +6. **安全可靠**:文件沙箱、安全计算、完善的错误处理