refactor: 工具调用记录迁移至独立表并更新文档

This commit is contained in:
ViperEkura 2026-03-25 00:22:32 +08:00
parent 362ab15338
commit 8a23b1cd00
5 changed files with 794 additions and 193 deletions

View File

@ -8,7 +8,7 @@ class User(db.Model):
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
username = db.Column(db.String(50), unique=True, nullable=False) username = db.Column(db.String(50), unique=True, nullable=False)
password = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login
phone = db.Column(db.String(20)) phone = db.Column(db.String(20))
conversations = db.relationship("Conversation", backref="user", lazy="dynamic", conversations = db.relationship("Conversation", backref="user", lazy="dynamic",
@ -20,7 +20,7 @@ class Conversation(db.Model):
__tablename__ = "conversations" __tablename__ = "conversations"
id = db.Column(db.String(64), primary_key=True) id = db.Column(db.String(64), primary_key=True)
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False) user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False, index=True)
title = db.Column(db.String(255), nullable=False, default="") title = db.Column(db.String(255), nullable=False, default="")
model = db.Column(db.String(64), nullable=False, default="glm-5") model = db.Column(db.String(64), nullable=False, default="glm-5")
system_prompt = db.Column(db.Text, default="") system_prompt = db.Column(db.Text, default="")
@ -28,7 +28,8 @@ class Conversation(db.Model):
max_tokens = db.Column(db.Integer, nullable=False, default=65536) max_tokens = db.Column(db.Integer, nullable=False, default=65536)
thinking_enabled = db.Column(db.Boolean, default=False) thinking_enabled = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc), index=True)
messages = db.relationship("Message", backref="conversation", lazy="dynamic", messages = db.relationship("Message", backref="conversation", lazy="dynamic",
cascade="all, delete-orphan", cascade="all, delete-orphan",
@ -39,27 +40,48 @@ class Message(db.Model):
__tablename__ = "messages" __tablename__ = "messages"
id = db.Column(db.String(64), primary_key=True) id = db.Column(db.String(64), primary_key=True)
conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), nullable=False) conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"),
nullable=False, index=True)
role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool
content = db.Column(db.Text, default="") content = db.Column(LONGTEXT, default="") # LONGTEXT for long conversations
token_count = db.Column(db.Integer, default=0) token_count = db.Column(db.Integer, default=0)
thinking_content = db.Column(db.Text, default="") thinking_content = db.Column(LONGTEXT, default="")
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True)
# Tool call support
tool_calls = db.Column(LONGTEXT) # JSON string: tool call requests (assistant messages) # Tool call support - relation to ToolCall table
tool_call_id = db.Column(db.String(64)) # Tool call ID (tool messages) tool_calls = db.relationship("ToolCall", backref="message", lazy="dynamic",
name = db.Column(db.String(64)) # Tool name (tool messages) cascade="all, delete-orphan",
order_by="ToolCall.call_index.asc()")
class ToolCall(db.Model):
"""Tool call record - separate table, follows database normalization"""
# One row per tool invocation attached to a message. Rows are reached through
# Message.tool_calls, which orders by call_index ascending and cascades
# "all, delete-orphan".
__tablename__ = "tool_calls"
# Surrogate auto-increment key; the id issued by the model is kept in call_id.
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
# Owning message (foreign key to messages.id).
# NOTE(review): index=True here overlaps with the composite index declared in
# __table_args__, whose leading column is also message_id - confirm whether
# the single-column index is still needed.
message_id = db.Column(db.String(64), db.ForeignKey("messages.id"),
nullable=False, index=True)
# ChatService._save_tool_calls fills this from the tool call's "id" field and
# falls back to "" when that field is absent.
call_id = db.Column(db.String(64), nullable=False) # Tool call ID
# Zero-based position of the call within its message (enumerate index on save).
call_index = db.Column(db.Integer, nullable=False, default=0) # Call order
tool_name = db.Column(db.String(64), nullable=False) # Tool name
# Stored exactly as received in tc["function"]["arguments"].
arguments = db.Column(LONGTEXT, nullable=False) # Call arguments JSON
# Nullable: left as NULL when no tool result exists at the same index.
result = db.Column(LONGTEXT) # Execution result JSON
# Taken from the "execution_time" key of the result JSON when it parses; 0 otherwise.
execution_time = db.Column(db.Float, default=0) # Execution time (seconds)
# NOTE(review): the statement below appears twice on one line because this view
# merges both columns of a side-by-side diff; the model presumably declares it once.
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
# Composite index on (message_id, call_index), matching the ordered
# per-message lookup done through Message.tool_calls.
__table_args__ = (
db.Index("ix_tool_calls_message_call", "message_id", "call_index"),
)
class TokenUsage(db.Model): class TokenUsage(db.Model):
__tablename__ = "token_usage" __tablename__ = "token_usage"
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False) user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"),
date = db.Column(db.Date, nullable=False) # 使用日期 nullable=False, index=True)
model = db.Column(db.String(64), nullable=False) # 模型名称 date = db.Column(db.Date, nullable=False, index=True)
model = db.Column(db.String(64), nullable=False)
prompt_tokens = db.Column(db.Integer, default=0) prompt_tokens = db.Column(db.Integer, default=0)
completion_tokens = db.Column(db.Integer, default=0) completion_tokens = db.Column(db.Integer, default=0)
total_tokens = db.Column(db.Integer, default=0) total_tokens = db.Column(db.Integer, default=0)
@ -67,4 +89,5 @@ class TokenUsage(db.Model):
__table_args__ = ( __table_args__ = (
db.UniqueConstraint("user_id", "date", "model", name="uq_user_date_model"), db.UniqueConstraint("user_id", "date", "model", name="uq_user_date_model"),
db.Index("ix_token_usage_date_model", "date", "model"), # Composite index
) )

View File

@ -3,7 +3,7 @@ import json
import uuid import uuid
from flask import current_app, Response from flask import current_app, Response
from backend import db from backend import db
from backend.models import Conversation, Message from backend.models import Conversation, Message, ToolCall
from backend.tools import registry, ToolExecutor from backend.tools import registry, ToolExecutor
from backend.utils.helpers import ( from backend.utils.helpers import (
get_or_create_default_user, get_or_create_default_user,
@ -60,8 +60,7 @@ class ChatService:
prompt_tokens = usage.get("prompt_tokens", 0) prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0) completion_tokens = usage.get("completion_tokens", 0)
merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results) # Create message
msg = Message( msg = Message(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
conversation_id=conv.id, conversation_id=conv.id,
@ -69,16 +68,18 @@ class ChatService:
content=message.get("content", ""), content=message.get("content", ""),
token_count=completion_tokens, token_count=completion_tokens,
thinking_content=message.get("reasoning_content", ""), thinking_content=message.get("reasoning_content", ""),
tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
) )
db.session.add(msg) db.session.add(msg)
# Create tool call records
self._save_tool_calls(msg.id, all_tool_calls, all_tool_results)
db.session.commit() db.session.commit()
user = get_or_create_default_user() user = get_or_create_default_user()
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens) record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
return ok({ return ok({
"message": to_dict(msg, thinking_content=msg.thinking_content or None), "message": self._message_to_dict(msg),
"usage": { "usage": {
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens, "completion_tokens": completion_tokens,
@ -194,8 +195,6 @@ class ChatService:
continue continue
# No tool calls - finish # No tool calls - finish
merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)
with app.app_context(): with app.app_context():
msg = Message( msg = Message(
id=msg_id, id=msg_id,
@ -204,9 +203,11 @@ class ChatService:
content=full_content, content=full_content,
token_count=token_count, token_count=token_count,
thinking_content=full_thinking, thinking_content=full_thinking,
tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
) )
db.session.add(msg) db.session.add(msg)
# Create tool call records
self._save_tool_calls(msg_id, all_tool_calls, all_tool_results)
db.session.commit() db.session.commit()
user = get_or_create_default_user() user = get_or_create_default_user()
@ -223,6 +224,53 @@ class ChatService:
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"} headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
) )
def _save_tool_calls(self, message_id: str, tool_calls: list, tool_results: list) -> None:
    """Stage one ToolCall row per tool call on the current DB session.

    Tool calls and results are paired by position: ``tool_results[i]`` is
    treated as the result of ``tool_calls[i]``. A call with no result at its
    index is stored with ``result=None``. Nothing is committed here; the
    caller commits the session.

    Args:
        message_id: ID of the assistant message that owns the calls.
        tool_calls: Tool call dicts, each with a ``"function"`` dict holding
            ``"name"`` and ``"arguments"``, and optionally an ``"id"``.
        tool_results: Tool result message dicts, each with a ``"content"`` key.
    """
    for i, tc in enumerate(tool_calls):
        result_content = tool_results[i]["content"] if i < len(tool_results) else None

        # The execution time is best-effort metadata embedded in the result
        # JSON. Only parse errors are tolerated; anything else must surface
        # instead of being silently swallowed.
        execution_time = 0
        if result_content:
            try:
                result_data = json.loads(result_content)
            except (TypeError, ValueError):
                # Not JSON (json.JSONDecodeError is a ValueError) or not text.
                result_data = None
            # Valid JSON may still be a list or scalar, which has no .get().
            if isinstance(result_data, dict):
                reported = result_data.get("execution_time", 0)
                # Only real numbers are safe to store in the Float column.
                if isinstance(reported, (int, float)) and not isinstance(reported, bool):
                    execution_time = reported

        tool_call = ToolCall(
            message_id=message_id,
            call_id=tc.get("id", ""),
            call_index=i,
            tool_name=tc["function"]["name"],
            arguments=tc["function"]["arguments"],
            result=result_content,
            execution_time=execution_time,
        )
        db.session.add(tool_call)
def _message_to_dict(self, msg: Message) -> dict:
    """Serialize a message and attach its tool calls when it has any.

    The base payload comes from ``to_dict``; ``thinking_content`` is passed
    as ``None`` when it is empty. A ``"tool_calls"`` key is added only if at
    least one related ToolCall row exists.
    """
    payload = to_dict(msg, thinking_content=msg.thinking_content or None)

    records = msg.tool_calls.all() if msg.tool_calls else []
    if not records:
        return payload

    serialized = []
    for record in records:
        function_part = {
            "name": record.tool_name,
            "arguments": record.arguments,
        }
        serialized.append({
            "id": record.call_id,
            "type": "function",
            "function": function_part,
            "result": record.result,
        })

    payload["tool_calls"] = serialized
    return payload
def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list: def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
"""Process tool calls from streaming delta""" """Process tool calls from streaming delta"""
tool_calls_delta = delta.get("tool_calls", []) tool_calls_delta = delta.get("tool_calls", [])
@ -242,13 +290,3 @@ class ChatService:
if tc["function"].get("arguments"): if tc["function"].get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"] tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
return tool_calls_list return tool_calls_list
def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list:
"""Merge tool results into tool calls"""
merged = []
for i, tc in enumerate(tool_calls):
merged_tc = dict(tc)
if i < len(tool_results):
merged_tc["result"] = tool_results[i]["content"]
merged.append(merged_tc)
return merged

View File

@ -9,7 +9,7 @@ def get_or_create_default_user():
"""Get or create default user""" """Get or create default user"""
user = User.query.filter_by(username="default").first() user = User.query.filter_by(username="default").first()
if not user: if not user:
user = User(username="default", password="") user = User(username="default", password=None)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
return user return user
@ -39,13 +39,6 @@ def to_dict(inst, **extra):
if k in d and hasattr(d[k], "strftime"): if k in d and hasattr(d[k], "strftime"):
d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ") d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")
# Parse tool_calls JSON if present
if "tool_calls" in d and d["tool_calls"]:
try:
d["tool_calls"] = json.loads(d["tool_calls"])
except:
pass
# Filter out None values for cleaner API response # Filter out None values for cleaner API response
d = {k: v for k, v in d.items() if v is not None} d = {k: v for k, v in d.items() if v is not None}

View File

@ -526,15 +526,24 @@ GET /api/stats/tokens?period=daily
### ER 关系 ### ER 关系
``` ```
User 1 ── * Conversation 1 ── * Message User 1 ── * Conversation 1 ── * Message 1 ── * ToolCall
``` ```
### User(用户)
| 字段 | 类型 | 说明 |
| ------------ | --------------- | ----------------------- |
| `id` | bigint | 用户 ID(自增) |
| `username` | string(50) | 用户名(唯一) |
| `password` | string(255) | 密码(可为空,第三方登录) |
| `phone` | string(20) | 手机号 |
### Conversation会话 ### Conversation会话
| 字段 | 类型 | 说明 | | 字段 | 类型 | 说明 |
| ------------------ | ------------- | --------------------- | | ------------------ | ------------- | --------------------- |
| `id` | string (UUID) | 会话 ID | | `id` | string (UUID) | 会话 ID |
| `user_id` | string | 所属用户 ID | | `user_id` | bigint | 所属用户 ID |
| `title` | string | 会话标题 | | `title` | string | 会话标题 |
| `model` | string | 使用的模型,默认 `glm-5` | | `model` | string | 使用的模型,默认 `glm-5` |
| `system_prompt` | string | 系统提示词 | | `system_prompt` | string | 系统提示词 |
@ -550,21 +559,51 @@ User 1 ── * Conversation 1 ── * Message
| ------------------ | ------------- | ------------------------------- | | ------------------ | ------------- | ------------------------------- |
| `id` | string (UUID) | 消息 ID | | `id` | string (UUID) | 消息 ID |
| `conversation_id` | string | 所属会话 ID | | `conversation_id` | string | 所属会话 ID |
| `role` | enum | `user` / `assistant` / `system` | | `role` | enum | `user` / `assistant` / `system` / `tool` |
| `content` | string | 消息内容 | | `content` | LONGTEXT | 消息内容 |
| `token_count` | integer | token 消耗数 | | `token_count` | integer | token 消耗数 |
| `thinking_content` | string | 思维链内容(启用时) | | `thinking_content` | LONGTEXT | 思维链内容(启用时) |
| `tool_calls` | array (JSON) | 工具调用信息(含结果),仅 assistant 消息 |
| `created_at` | datetime | 创建时间 | | `created_at` | datetime | 创建时间 |
**说明**:工具调用信息存储在关联的 `ToolCall` 表中,通过 `message.tool_calls` 关系获取。
### ToolCall工具调用
| 字段 | 类型 | 说明 |
| ----------------- | --------------- | --------------------------- |
| `id` | bigint | 调用记录 ID(自增) |
| `message_id` | string(64) | 关联的消息 ID |
| `call_id` | string(64) | 工具调用 ID |
| `call_index` | integer | 调用顺序(从 0 开始) |
| `tool_name` | string(64) | 工具名称 |
| `arguments` | LONGTEXT | 调用参数 JSON |
| `result` | LONGTEXT | 执行结果 JSON |
| `execution_time` | float | 执行时间(秒) |
| `created_at` | datetime | 创建时间 |
### TokenUsage(Token 使用统计)
| 字段 | 类型 | 说明 |
| ------------------- | ---------- | -------------------------- |
| `id` | bigint | 记录 ID(自增) |
| `user_id` | bigint | 用户 ID |
| `date` | date | 统计日期 |
| `model` | string(64) | 模型名称 |
| `prompt_tokens` | integer | 输入 token 数 |
| `completion_tokens` | integer | 输出 token 数 |
| `total_tokens` | integer | 总 token 数 |
| `created_at` | datetime | 创建时间 |
#### 消息类型说明 #### 消息类型说明
**1. 用户消息 (role=user)** **1. 用户消息 (role=user)**
```json ```json
{ {
"id": "msg_001", "id": "msg_001",
"conversation_id": "conv_abc123",
"role": "user", "role": "user",
"content": "北京今天天气怎么样?", "content": "北京今天天气怎么样?",
"token_count": 0,
"created_at": "2026-03-24T10:00:00Z" "created_at": "2026-03-24T10:00:00Z"
} }
``` ```
@ -573,22 +612,23 @@ User 1 ── * Conversation 1 ── * Message
```json ```json
{ {
"id": "msg_002", "id": "msg_002",
"conversation_id": "conv_abc123",
"role": "assistant", "role": "assistant",
"content": "北京今天天气晴朗...", "content": "北京今天天气晴朗...",
"token_count": 50, "token_count": 50,
"thinking_content": "用户想了解天气...", "thinking_content": "用户想了解天气...",
"tool_calls": null,
"created_at": "2026-03-24T10:00:01Z" "created_at": "2026-03-24T10:00:01Z"
} }
``` ```
**3. 助手消息 - 含工具调用 (role=assistant, with tool_calls)** **3. 助手消息 - 含工具调用 (role=assistant, with tool_calls)**
工具调用结果直接合并到 `tool_calls` 数组中,每个调用包含 `result` 字段 工具调用记录存储在独立的 `tool_calls` 表中API 响应时会自动关联并返回
```json ```json
{ {
"id": "msg_003", "id": "msg_003",
"conversation_id": "conv_abc123",
"role": "assistant", "role": "assistant",
"content": "北京今天天气晴朗温度25°C湿度60%", "content": "北京今天天气晴朗温度25°C湿度60%",
"token_count": 80, "token_count": 80,
@ -608,6 +648,18 @@ User 1 ── * Conversation 1 ── * Message
} }
``` ```
**4. 工具消息 (role=tool)**
用于 API 调用时传递工具执行结果(不存储在数据库):
```json
{
"role": "tool",
"tool_call_id": "call_abc123",
"name": "get_weather",
"content": "{\"temperature\": 25, \"humidity\": 60}"
}
```
#### 工具调用流程示例 #### 工具调用流程示例
``` ```
@ -617,14 +669,16 @@ User 1 ── * Conversation 1 ── * Message
[AI 调用工具 get_weather] [AI 调用工具 get_weather]
[msg_002] role=assistant, tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}] [msg_002] role=assistant, content="北京今天天气晴朗温度25°C..."
content="北京今天天气晴朗温度25°C..." tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}]
``` ```
**说明:** **说明:**
- 工具调用结果直接存储在 `tool_calls[].result` 字段中 - 工具调用记录存储在独立的 `tool_calls` 表中,与 `messages` 表通过 `message_id` 关联
- 不再创建独立的 `role=tool` 消息 - API 响应时自动查询并组装 `tool_calls` 数组
- 前端可通过 `tool_calls` 数组展示完整的工具调用过程 - 工具调用包含完整的调用参数和执行结果
- `call_index` 字段记录同一消息中多次工具调用的顺序
- `execution_time` 字段记录工具执行耗时
--- ---

View File

@ -2,7 +2,7 @@
## 概述 ## 概述
本文档描述 NanoClaw 工具调用系统的设计,采用简化的工厂模式,减少不必要的类层次 NanoClaw 工具调用系统采用简化的工厂模式,支持装饰器注册、缓存优化、重复调用检测等功能
--- ---
@ -27,13 +27,20 @@ classDiagram
+register(ToolDefinition tool) void +register(ToolDefinition tool) void
+get(str name) ToolDefinition? +get(str name) ToolDefinition?
+list_all() list~dict~ +list_all() list~dict~
+execute(str name, dict args) Any +list_by_category(str category) list~dict~
+execute(str name, dict args) dict
+remove(str name) bool
+has(str name) bool
} }
class ToolExecutor { class ToolExecutor {
-ToolRegistry registry -ToolRegistry registry
+process_tool_calls(list tool_calls) list~dict~ -dict _cache
+build_request(list messages) dict -list _call_history
+process_tool_calls(list tool_calls, dict context) list~dict~
+build_request(list messages, str model, list tools, dict kwargs) dict
+clear_history() void
+execute_with_retry(str name, dict args, int max_retries) dict
} }
class ToolResult { class ToolResult {
@ -42,6 +49,8 @@ classDiagram
+Any data +Any data
+str? error +str? error
+dict to_dict() +dict to_dict()
+ok(Any data)$ ToolResult
+fail(str error)$ ToolResult
} }
ToolRegistry "1" --> "*" ToolDefinition : manages ToolRegistry "1" --> "*" ToolDefinition : manages
@ -61,11 +70,8 @@ classDiagram
class ToolFactory { class ToolFactory {
<<module>> <<module>>
+tool(name, description, parameters)$ decorator +tool(name, description, parameters, category)$ decorator
+register(name, handler, description, parameters)$ void +register_tool(name, handler, description, parameters, category)$ void
+create_crawler_tools()$ list~ToolDefinition~
+create_data_tools()$ list~ToolDefinition~
+create_file_tools()$ list~ToolDefinition~
} }
class ToolDefinition { class ToolDefinition {
@ -73,6 +79,7 @@ classDiagram
+str description +str description
+dict parameters +dict parameters
+Callable handler +Callable handler
+str category
} }
ToolFactory ..> ToolDefinition : creates ToolFactory ..> ToolDefinition : creates
@ -153,18 +160,31 @@ class ToolRegistry:
return cls._instance return cls._instance
def register(self, tool: ToolDefinition) -> None: def register(self, tool: ToolDefinition) -> None:
"""注册工具"""
self._tools[tool.name] = tool self._tools[tool.name] = tool
def get(self, name: str) -> Optional[ToolDefinition]: def get(self, name: str) -> Optional[ToolDefinition]:
"""获取工具定义"""
return self._tools.get(name) return self._tools.get(name)
def list_all(self) -> List[dict]: def list_all(self) -> List[dict]:
"""列出所有工具OpenAI 格式)"""
return [t.to_openai_format() for t in self._tools.values()] return [t.to_openai_format() for t in self._tools.values()]
def list_by_category(self, category: str) -> List[dict]:
"""按类别列出工具"""
return [
t.to_openai_format()
for t in self._tools.values()
if t.category == category
]
def execute(self, name: str, arguments: dict) -> dict: def execute(self, name: str, arguments: dict) -> dict:
"""执行工具"""
tool = self.get(name) tool = self.get(name)
if not tool: if not tool:
return ToolResult.fail(f"Tool not found: {name}").to_dict() return ToolResult.fail(f"Tool not found: {name}").to_dict()
try: try:
result = tool.handler(arguments) result = tool.handler(arguments)
if isinstance(result, ToolResult): if isinstance(result, ToolResult):
@ -173,6 +193,17 @@ class ToolRegistry:
except Exception as e: except Exception as e:
return ToolResult.fail(str(e)).to_dict() return ToolResult.fail(str(e)).to_dict()
def remove(self, name: str) -> bool:
"""移除工具"""
if name in self._tools:
del self._tools[name]
return True
return False
def has(self, name: str) -> bool:
"""检查工具是否存在"""
return name in self._tools
# 全局注册表 # 全局注册表
registry = ToolRegistry() registry = ToolRegistry()
@ -182,39 +213,205 @@ registry = ToolRegistry()
```python ```python
import json import json
from typing import List, Dict import time
import hashlib
from typing import List, Dict, Optional
class ToolExecutor: class ToolExecutor:
"""工具执行器""" """工具执行器(支持缓存和重复检测)"""
def __init__(self, registry: ToolRegistry = None): def __init__(
self,
registry: ToolRegistry = None,
api_url: str = None,
api_key: str = None,
enable_cache: bool = True,
cache_ttl: int = 300, # 5分钟
):
self.registry = registry or ToolRegistry() self.registry = registry or ToolRegistry()
self.api_url = api_url
self.api_key = api_key
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self._cache: Dict[str, tuple] = {} # key -> (result, timestamp)
self._call_history: List[dict] = [] # 当前会话的调用历史
def process_tool_calls(self, tool_calls: List[dict]) -> List[dict]: def _make_cache_key(self, name: str, args: dict) -> str:
"""处理工具调用,返回消息列表""" """生成缓存键"""
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()
def _get_cached(self, key: str) -> Optional[dict]:
"""获取缓存结果"""
if not self.enable_cache:
return None
if key in self._cache:
result, timestamp = self._cache[key]
if time.time() - timestamp < self.cache_ttl:
return result
del self._cache[key]
return None
def _set_cache(self, key: str, result: dict) -> None:
"""设置缓存"""
if self.enable_cache:
self._cache[key] = (result, time.time())
def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]:
"""检查历史中是否有相同调用"""
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
for record in self._call_history:
if record["name"] == name and record["args_str"] == args_str:
return record["result"]
return None
def clear_history(self) -> None:
"""清空调用历史(新会话开始时调用)"""
self._call_history.clear()
def process_tool_calls(
self,
tool_calls: List[dict],
context: dict = None
) -> List[dict]:
"""
处理工具调用,返回消息列表
Args:
tool_calls: LLM 返回的工具调用列表
context: 可选上下文信息user_id 等)
Returns:
工具响应消息列表,可直接追加到 messages
"""
results = [] results = []
seen_calls = set() # 当前批次内的重复检测
for call in tool_calls: for call in tool_calls:
name = call["function"]["name"] name = call["function"]["name"]
args = json.loads(call["function"]["arguments"]) args_str = call["function"]["arguments"]
call_id = call["id"] call_id = call["id"]
try:
args = json.loads(args_str) if isinstance(args_str, str) else args_str
except json.JSONDecodeError:
results.append(self._create_error_result(
call_id, name, "Invalid JSON arguments"
))
continue
# 检查批次内重复
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
if call_key in seen_calls:
results.append(self._create_tool_result(
call_id, name,
{"success": True, "data": None, "cached": True, "duplicate": True}
))
continue
seen_calls.add(call_key)
# 检查历史重复
history_result = self._check_duplicate_in_history(name, args)
if history_result is not None:
result = {**history_result, "cached": True}
results.append(self._create_tool_result(call_id, name, result))
continue
# 检查缓存
cache_key = self._make_cache_key(name, args)
cached_result = self._get_cached(cache_key)
if cached_result is not None:
result = {**cached_result, "cached": True}
results.append(self._create_tool_result(call_id, name, result))
continue
# 执行工具
result = self.registry.execute(name, args) result = self.registry.execute(name, args)
results.append({ # 缓存结果
"role": "tool", self._set_cache(cache_key, result)
"tool_call_id": call_id,
# 添加到历史
self._call_history.append({
"name": name, "name": name,
"content": json.dumps(result, ensure_ascii=False) "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
"result": result
}) })
results.append(self._create_tool_result(call_id, name, result))
return results return results
def build_request(self, messages: List[dict], **kwargs) -> dict: def _create_tool_result(
"""构建 API 请求""" self,
call_id: str,
name: str,
result: dict,
execution_time: float = 0
) -> dict:
"""创建工具结果消息"""
result["execution_time"] = execution_time
return { return {
"model": kwargs.get("model", "glm-5"), "role": "tool",
"tool_call_id": call_id,
"name": name,
"content": json.dumps(result, ensure_ascii=False, default=str)
}
def _create_error_result(
self,
call_id: str,
name: str,
error: str
) -> dict:
"""创建错误结果消息"""
return {
"role": "tool",
"tool_call_id": call_id,
"name": name,
"content": json.dumps({
"success": False,
"error": error
}, ensure_ascii=False)
}
def build_request(
self,
messages: List[dict],
model: str = "glm-5",
tools: List[dict] = None,
**kwargs
) -> dict:
"""构建 API 请求体"""
return {
"model": model,
"messages": messages, "messages": messages,
"tools": self.registry.list_all(), "tools": tools or self.registry.list_all(),
"tool_choice": "auto" "tool_choice": kwargs.get("tool_choice", "auto"),
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
}
def execute_with_retry(
    self,
    name: str,
    arguments: dict,
    max_retries: int = 3,
    retry_delay: float = 1.0
) -> dict:
    """Execute a tool through the registry, retrying failed attempts.

    The registry reports handler errors as a result dict with
    ``"success": False`` rather than raising, so an attempt counts as failed
    either when ``registry.execute`` raises or when it returns such a dict.

    Args:
        name: Name of the registered tool.
        arguments: Arguments passed to the tool handler.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The first non-failed result, or a failure dict carrying the last
        error once every attempt has failed.
    """
    last_error = None

    for attempt in range(max_retries):
        try:
            result = self.registry.execute(name, arguments)
        except Exception as e:
            # Kept for registries that raise instead of returning a failure.
            last_error = e
        else:
            failed = isinstance(result, dict) and result.get("success") is False
            if not failed:
                return result
            last_error = result.get("error")

        # No point sleeping after the final attempt.
        if attempt < max_retries - 1:
            time.sleep(retry_delay)

    return {
        "success": False,
        "error": f"Failed after {max_retries} retries: {last_error}"
    }
``` ```
@ -227,11 +424,30 @@ class ToolExecutor:
```python ```python
# backend/tools/factory.py # backend/tools/factory.py
from .core import ToolDefinition, registry from typing import Callable
from backend.tools.core import ToolDefinition, registry
def tool(name: str, description: str, parameters: dict, category: str = "general"):
"""工具注册装饰器""" def tool(
def decorator(func): name: str,
description: str,
parameters: dict,
category: str = "general"
) -> Callable:
"""
工具注册装饰器
用法:
@tool(
name="web_search",
description="搜索互联网获取信息",
parameters={"type": "object", "properties": {...}},
category="crawler"
)
def web_search(arguments: dict) -> dict:
...
"""
def decorator(func: Callable) -> Callable:
tool_def = ToolDefinition( tool_def = ToolDefinition(
name=name, name=name,
description=description, description=description,
@ -242,6 +458,34 @@ def tool(name: str, description: str, parameters: dict, category: str = "general
registry.register(tool_def) registry.register(tool_def)
return func return func
return decorator return decorator
def register_tool(
name: str,
handler: Callable,
description: str,
parameters: dict,
category: str = "general"
) -> None:
"""
直接注册工具(无需装饰器)
用法:
register_tool(
name="my_tool",
handler=my_function,
description="工具描述",
parameters={...}
)
"""
tool_def = ToolDefinition(
name=name,
description=description,
parameters=parameters,
handler=handler,
category=category
)
registry.register(tool_def)
``` ```
### 4.2 使用示例 ### 4.2 使用示例
@ -249,24 +493,32 @@ def tool(name: str, description: str, parameters: dict, category: str = "general
```python ```python
# backend/tools/builtin/crawler.py # backend/tools/builtin/crawler.py
from ..factory import tool from backend.tools.factory import tool
from backend.tools.services import SearchService, FetchService
# 网页搜索工具 # 网页搜索工具
@tool( @tool(
name="web_search", name="web_search",
description="搜索互联网获取信息", description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.",
parameters={ parameters={
"type": "object", "type": "object",
"properties": { "properties": {
"query": {"type": "string", "description": "搜索关键词"}, "query": {
"max_results": {"type": "integer", "default": 5} "type": "string",
"description": "Search keywords"
},
"max_results": {
"type": "integer",
"description": "Number of results to return, default 5",
"default": 5
}
}, },
"required": ["query"] "required": ["query"]
}, },
category="crawler" category="crawler"
) )
def web_search(arguments: dict) -> dict: def web_search(arguments: dict) -> dict:
from ..services import SearchService """Web search tool"""
query = arguments["query"] query = arguments["query"]
max_results = arguments.get("max_results", 5) max_results = arguments.get("max_results", 5)
service = SearchService() service = SearchService()
@ -277,19 +529,27 @@ def web_search(arguments: dict) -> dict:
# 页面抓取工具 # 页面抓取工具
@tool( @tool(
name="fetch_page", name="fetch_page",
description="抓取指定网页内容", description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.",
parameters={ parameters={
"type": "object", "type": "object",
"properties": { "properties": {
"url": {"type": "string", "description": "网页URL"}, "url": {
"extract_type": {"type": "string", "enum": ["text", "links", "structured"]} "type": "string",
"description": "URL of the webpage to fetch"
},
"extract_type": {
"type": "string",
"description": "Extraction type",
"enum": ["text", "links", "structured"],
"default": "text"
}
}, },
"required": ["url"] "required": ["url"]
}, },
category="crawler" category="crawler"
) )
def fetch_page(arguments: dict) -> dict: def fetch_page(arguments: dict) -> dict:
from ..services import FetchService """Page fetch tool"""
url = arguments["url"] url = arguments["url"]
extract_type = arguments.get("extract_type", "text") extract_type = arguments.get("extract_type", "text")
service = FetchService() service = FetchService()
@ -297,33 +557,39 @@ def fetch_page(arguments: dict) -> dict:
return result return result
# 计算器工具 # 批量抓取工具
@tool( @tool(
name="calculator", name="crawl_batch",
description="执行数学计算", description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.",
parameters={ parameters={
"type": "object", "type": "object",
"properties": { "properties": {
"expression": {"type": "string", "description": "数学表达式"} "urls": {
"type": "array",
"items": {"type": "string"},
"description": "List of URLs to fetch"
},
"extract_type": {
"type": "string",
"enum": ["text", "links", "structured"],
"default": "text"
}
}, },
"required": ["expression"] "required": ["urls"]
}, },
category="data" category="crawler"
) )
def calculator(arguments: dict) -> dict: def crawl_batch(arguments: dict) -> dict:
import ast """Batch fetch tool"""
import operator urls = arguments["urls"]
expr = arguments["expression"] extract_type = arguments.get("extract_type", "text")
# 安全计算
ops = { if len(urls) > 10:
ast.Add: operator.add, return {"error": "Maximum 10 pages can be fetched at once"}
ast.Sub: operator.sub,
ast.Mult: operator.mul, service = FetchService()
ast.Div: operator.truediv results = service.fetch_batch(urls, extract_type)
} return {"results": results, "total": len(results)}
node = ast.parse(expr, mode='eval')
result = eval(compile(node, '<string>', 'eval'), {"__builtins__": {}}, ops)
return {"result": result}
``` ```
--- ---
@ -332,90 +598,220 @@ def calculator(arguments: dict) -> dict:
工具依赖的服务保持独立,不与工具类耦合: 工具依赖的服务保持独立,不与工具类耦合:
```mermaid
classDiagram
direction LR
class SearchService {
-SearchEngine engine
+search(str query, int limit) list~dict~
}
class FetchService {
+fetch(str url, str type) dict
+fetch_batch(list urls) dict
}
class ContentExtractor {
+extract_text(html) str
+extract_links(html) list
+extract_structured(html) dict
}
FetchService --> ContentExtractor : uses
```
```python ```python
# backend/tools/services.py # backend/tools/services.py
from typing import List, Dict
from ddgs import DDGS
import re
class SearchService: class SearchService:
"""搜索服务""" """搜索服务"""
def __init__(self, engine=None):
from ddgs import DDGS
self.engine = engine or DDGS()
def search(self, query: str, max_results: int = 5) -> list: def __init__(self, engine: str = "duckduckgo"):
results = list(self.engine.text(query, max_results=max_results)) self.engine = engine
def search(
self,
query: str,
max_results: int = 5,
region: str = "cn-zh"
) -> List[dict]:
"""执行搜索"""
if self.engine == "duckduckgo":
return self._search_duckduckgo(query, max_results, region)
else:
raise ValueError(f"Unsupported search engine: {self.engine}")
def _search_duckduckgo(
self,
query: str,
max_results: int,
region: str
) -> List[dict]:
"""DuckDuckGo 搜索"""
with DDGS() as ddgs:
results = list(ddgs.text(
query,
max_results=max_results,
region=region
))
return [ return [
{"title": r["title"], "url": r["href"], "snippet": r["body"]} {
"title": r.get("title", ""),
"url": r.get("href", ""),
"snippet": r.get("body", "")
}
for r in results for r in results
] ]
class FetchService: class FetchService:
"""页面抓取服务""" """页面抓取服务"""
def __init__(self, timeout: float = 30.0):
def __init__(self, timeout: float = 30.0, user_agent: str = None):
self.timeout = timeout self.timeout = timeout
self.user_agent = user_agent or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
)
def fetch(self, url: str, extract_type: str = "text") -> dict: def fetch(self, url: str, extract_type: str = "text") -> dict:
"""抓取单个页面"""
import httpx import httpx
from bs4 import BeautifulSoup
resp = httpx.get(url, timeout=self.timeout, follow_redirects=True) try:
soup = BeautifulSoup(resp.text, "html.parser") resp = httpx.get(
url,
timeout=self.timeout,
follow_redirects=True,
headers={"User-Agent": self.user_agent}
)
resp.raise_for_status()
except Exception as e:
return {"error": str(e), "url": url}
html = resp.text
extractor = ContentExtractor(html)
extractor = ContentExtractor(soup)
if extract_type == "text": if extract_type == "text":
return {"text": extractor.extract_text()} return {
"url": url,
"text": extractor.extract_text()
}
elif extract_type == "links": elif extract_type == "links":
return {"links": extractor.extract_links()} return {
"url": url,
"links": extractor.extract_links()
}
else: else:
return extractor.extract_structured() return extractor.extract_structured(url)
def fetch_batch(
    self,
    urls: List[str],
    extract_type: str = "text",
    max_concurrent: int = 5
) -> List[dict]:
    """Fetch several pages and return one result dict per URL, in input order.

    Each URL is passed to ``self.fetch(url, extract_type)``; whatever that
    call returns (including its ``{"error": ..., "url": ...}`` dicts) is
    placed in the output list at the same position as the URL.

    Args:
        urls: URLs to fetch. An empty list yields an empty list.
        extract_type: Forwarded unchanged to ``fetch`` for every URL.
        max_concurrent: Upper bound on simultaneous fetches. Values below 1
            are treated as 1 (sequential) instead of raising.

    Returns:
        A list with one ``fetch`` result per input URL, in input order.
    """
    # Local import keeps this block self-contained; nothing else here needs it.
    from concurrent.futures import ThreadPoolExecutor

    pending = list(urls)
    if not pending:
        return []

    # Never start more workers than there are URLs, and never fewer than one
    # (ThreadPoolExecutor rejects max_workers <= 0).
    workers = max(1, min(max_concurrent, len(pending)))
    if workers == 1:
        return [self.fetch(url, extract_type) for url in pending]

    # Executor.map yields results in submission order, so the output lines up
    # with ``urls`` even when the downloads finish out of order.
    # NOTE(review): assumes ``fetch`` only reads instance state, as the
    # version shown alongside this method does - confirm before sharing a
    # stateful client between threads.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda url: self.fetch(url, extract_type), pending))
class ContentExtractor: class ContentExtractor:
"""内容提取器""" """内容提取器"""
def __init__(self, soup):
self.soup = soup def __init__(self, html: str):
self.html = html
self._soup = None
@property
def soup(self):
if self._soup is None:
try:
from bs4 import BeautifulSoup
self._soup = BeautifulSoup(self.html, "html.parser")
except ImportError:
raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4")
return self._soup
def extract_text(self) -> str: def extract_text(self) -> str:
"""提取纯文本"""
# 移除脚本和样式 # 移除脚本和样式
for tag in self.soup(["script", "style"]): for tag in self.soup(["script", "style", "nav", "footer", "header"]):
tag.decompose() tag.decompose()
return self.soup.get_text(separator="\n", strip=True)
def extract_links(self) -> list: text = self.soup.get_text(separator="\n", strip=True)
return [ # 清理多余空白
{"text": a.get_text(strip=True), "href": a.get("href")} text = re.sub(r"\n{3,}", "\n\n", text)
for a in self.soup.find_all("a", href=True) return text
]
def extract_links(self) -> List[dict]:
"""提取链接"""
links = []
for a in self.soup.find_all("a", href=True):
text = a.get_text(strip=True)
href = a["href"]
if text and href and not href.startswith(("#", "javascript:")):
links.append({"text": text, "href": href})
return links[:50] # 限制数量
def extract_structured(self, url: str = "") -> dict:
"""提取结构化内容"""
soup = self.soup
# 提取标题
title = ""
if soup.title:
title = soup.title.string or ""
# 提取 meta 描述
description = ""
meta_desc = soup.find("meta", attrs={"name": "description"})
if meta_desc:
description = meta_desc.get("content", "")
def extract_structured(self) -> dict:
return { return {
"title": self.soup.title.string if self.soup.title else "", "url": url,
"text": self.extract_text(), "title": title.strip(),
"description": description.strip(),
"text": self.extract_text()[:5000], # 限制长度
"links": self.extract_links()[:20] "links": self.extract_links()[:20]
} }
class CalculatorService:
    """Evaluate arithmetic expressions by walking the AST, never via ``eval``.

    Only numeric literals combined with ``+ - * / // % **`` and unary
    ``+``/``-`` are accepted. Every other syntax element (strings, attribute
    access, subscripts, lambdas, comparisons, ...) is rejected with an error
    result instead of being executed.
    """

    # Informational list of operation names; ``evaluate`` does not consult it.
    ALLOWED_OPS = {
        "add", "sub", "mul", "truediv", "floordiv",
        "mod", "pow", "neg", "abs"
    }

    # Integer powers whose result would need more than about this many bits
    # are refused, so inputs such as ``9 ** 9 ** 9`` cannot stall the process.
    MAX_RESULT_BITS = 100_000

    def evaluate(self, expression: str) -> dict:
        """Compute ``expression`` and report the outcome as a dict.

        Args:
            expression: Arithmetic expression in Python syntax.

        Returns:
            ``{"result": value}`` on success, otherwise ``{"error": message}``.
            No exception derived from ``Exception`` escapes this method.
        """
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {
            ast.USub: operator.neg,
            ast.UAdd: operator.pos,
        }
        max_bits = self.MAX_RESULT_BITS

        def compute(node):
            """Recursively evaluate one whitelisted AST node."""
            if isinstance(node, ast.Constant):
                if not isinstance(node.value, (int, float, complex)):
                    raise ValueError("Only numeric literals are allowed")
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = compute(node.left)
                right = compute(node.right)
                if (
                    isinstance(node.op, ast.Pow)
                    and isinstance(left, int)
                    and isinstance(right, int)
                    and right > 0
                    and left.bit_length() * right > max_bits
                ):
                    raise ValueError("Exponent too large")
                return binary_ops[type(node.op)](left, right)
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](compute(node.operand))
            raise ValueError(
                f"Unsupported expression element: {type(node).__name__}"
            )

        try:
            tree = ast.parse(expression, mode="eval")

            # Keep the two dedicated messages callers may already match on.
            for child in ast.walk(tree):
                if isinstance(child, ast.Call):
                    return {"error": "Function calls not allowed"}
                if isinstance(child, ast.Name):
                    return {"error": "Variable names not allowed"}

            return {"result": compute(tree.body)}
        except Exception as e:
            # Tool boundary: report every failure (syntax error, division by
            # zero, overflow, rejected node) as data rather than raising.
            return {"error": f"Calculation error: {str(e)}"}
``` ```
--- ---
@ -425,57 +821,154 @@ class ContentExtractor:
```python ```python
# backend/tools/__init__.py # backend/tools/__init__.py
from .core import ToolDefinition, ToolResult, ToolRegistry, registry, ToolExecutor """
from .factory import tool NanoClaw Tool System
def init_tools(): Usage:
"""初始化所有内置工具""" from backend.tools import registry, ToolExecutor, tool
# 导入即自动注册 from backend.tools import init_tools
from .builtin import crawler, data, weather
# 使用时 # 初始化内置工具
init_tools() init_tools()
# 列出所有工具
tools = registry.list_all()
# 执行工具
result = registry.execute("web_search", {"query": "Python"})
"""
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
from backend.tools.factory import tool, register_tool
from backend.tools.executor import ToolExecutor
def init_tools() -> None:
"""
初始化所有内置工具
导入 builtin 模块会自动注册所有装饰器定义的工具
"""
from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401
# 公开 API 导出
__all__ = [
# 核心类
"ToolDefinition",
"ToolResult",
"ToolRegistry",
"ToolExecutor",
# 实例
"registry",
# 工厂函数
"tool",
"register_tool",
# 初始化
"init_tools",
]
``` ```
--- ---
## 七、工具清单 ## 七、工具清单
| 类别 | 工具名称 | 描述 | 依赖服务 | ### 7.1 爬虫工具 (crawler)
| ------- | --------------- | ---- | ------------- |
| crawler | `web_search` | 网页搜索 | SearchService | | 工具名称 | 描述 | 参数 |
| crawler | `fetch_page` | 单页抓取 | FetchService | | --------------- | --------------------------- | --------------------------------------- |
| crawler | `crawl_batch` | 批量爬取 | FetchService | | `web_search` | 搜索互联网获取信息 | `query`: 搜索关键词<br>`max_results`: 结果数量(默认 5) |
| data | `calculator` | 数学计算 | CalculatorService | | `fetch_page` | 抓取单个网页内容 | `url`: 网页 URL<br>`extract_type`: 提取类型(text/links/structured) |
| data | `text_process` | 文本处理 | - | | `crawl_batch` | 批量抓取多个网页(最多 10 个) | `urls`: URL 列表<br>`extract_type`: 提取类型 |
| data | `json_process` | JSON处理 | - |
| weather | `get_weather` | 天气查询 | - (模拟数据) | ### 7.2 数据处理工具 (data)
| file | `file_read` | 读取文件 | - |
| file | `file_write` | 写入文件 | - | | 工具名称 | 描述 | 参数 |
| file | `file_delete` | 删除文件 | - | | --------------- | --------------------------- | --------------------------------------- |
| file | `file_list` | 列出目录 | - | | `calculator` | 执行数学计算(支持加减乘除、幂、模等) | `expression`: 数学表达式 |
| file | `file_exists` | 检查存在 | - | | `text_process` | 文本处理(计数、格式转换等) | `text`: 文本内容<br>`operation`: 操作类型(count/lines/words/upper/lower/reverse) |
| file | `file_mkdir` | 创建目录 | - | | `json_process` | JSON 处理(解析、格式化、提取、验证) | `json_string`: JSON 字符串<br>`operation`: 操作类型(parse/format/keys/validate) |
### 7.3 天气工具 (weather)
| 工具名称 | 描述 | 参数 |
| --------------- | --------------------------- | --------------------------------------- |
| `get_weather` | 查询指定城市的天气信息(模拟数据) | `city`: 城市名称(如:北京、上海、广州) |
### 7.4 文件操作工具 (file)
| 工具名称 | 描述 | 参数 |
| --------------- | --------------------------- | --------------------------------------- |
| `file_read` | 读取文件内容 | `path`: 文件路径<br>`encoding`: 编码(默认 utf-8) |
| `file_write` | 写入文件(支持覆盖和追加) | `path`: 文件路径<br>`content`: 内容<br>`mode`: 写入模式(write/append) |
| `file_delete` | 删除文件 | `path`: 文件路径 |
| `file_list` | 列出目录内容 | `path`: 目录路径(默认 .)<br>`pattern`: 文件模式(默认 *) |
| `file_exists` | 检查文件或目录是否存在 | `path`: 路径 |
| `file_mkdir` | 创建目录(自动创建父目录) | `path`: 目录路径 |
**安全说明**:文件操作工具限制在项目根目录内,防止越权访问。
--- ---
## 八、与旧设计对比 ## 八、与旧设计对比
| 方面 | 旧设计 | 新设计 | | 方面 | 旧设计 | 新设计 |
| ----- | ----------------- | --------- | | --------- | ----------------- | ----------------- |
| 类数量 | 30+ | ~10 | | 类数量 | 30+ | ~10 |
| 工具定义 | 继承 BaseTool | 装饰器 + 函数 | | 工具定义 | 继承 BaseTool | 装饰器 + 函数 |
| 中间抽象层 | 5个(CrawlerTool 等) | 无 | | 中间抽象层 | 5个(CrawlerTool 等) | 无 |
| 扩展方式 | 创建子类 | 写函数 + 装饰器 | | 扩展方式 | 创建子类 | 写函数 + 装饰器 |
| 代码量 | 多 | 少 | | 缓存机制 | 无 | 支持结果缓存(TTL 可配置) |
| 重复检测 | 无 | 支持会话内重复调用检测 |
| 代码量 | 多 | 少 |
--- ---
## 九、总结 ## 九、核心特性
简化后的设计: ### 9.1 装饰器注册
简化工具定义,只需一个装饰器:
```python
@tool(
name="my_tool",
description="工具描述",
parameters={...},
category="custom"
)
def my_tool(arguments: dict) -> dict:
# 工具实现
return {"result": "ok"}
```
### 9.2 智能缓存
- **结果缓存**:相同参数的工具调用结果会被缓存(默认 5 分钟)
- **可配置 TTL**:通过 `cache_ttl` 参数设置缓存过期时间
- **可禁用**:通过 `enable_cache=False` 关闭缓存
### 9.3 重复检测
- **批次内去重**:同一批次中相同工具+参数的调用会被跳过
- **历史去重**:同一会话内已调用过的工具会直接返回缓存结果
- **自动清理**:新会话开始时调用 `clear_history()` 清理历史
### 9.4 安全设计
- **计算器安全**:禁止函数调用和变量名,只支持数学运算
- **文件沙箱**:文件操作限制在项目根目录内,防止越权访问
- **错误处理**:所有工具执行都有 try-catch,不会因工具错误导致系统崩溃
---
## 十、总结
简化后的设计特点:
1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult` 1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult`
2. **工厂模式**:使用 `@tool` 装饰器注册工具 2. **工厂模式**:使用 `@tool` 装饰器注册工具
3. **服务分离**:工具依赖的服务独立,不与工具类耦合 3. **服务分离**:工具依赖的服务独立,不与工具类耦合
4. **易于扩展**:新增工具只需写一个函数并加装饰器 4. **性能优化**:支持缓存和重复检测,减少重复计算和网络请求
5. **易于扩展**:新增工具只需写一个函数并加装饰器
6. **安全可靠**:文件沙箱、安全计算、完善的错误处理