refactor: 工具调用记录迁移至独立表并更新文档
This commit is contained in:
parent
362ab15338
commit
8a23b1cd00
|
|
@ -8,7 +8,7 @@ class User(db.Model):
|
|||
|
||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
||||
username = db.Column(db.String(50), unique=True, nullable=False)
|
||||
password = db.Column(db.String(255), nullable=False)
|
||||
password = db.Column(db.String(255), nullable=True) # Allow NULL for third-party login
|
||||
phone = db.Column(db.String(20))
|
||||
|
||||
conversations = db.relationship("Conversation", backref="user", lazy="dynamic",
|
||||
|
|
@ -20,7 +20,7 @@ class Conversation(db.Model):
|
|||
__tablename__ = "conversations"
|
||||
|
||||
id = db.Column(db.String(64), primary_key=True)
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False)
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = db.Column(db.String(255), nullable=False, default="")
|
||||
model = db.Column(db.String(64), nullable=False, default="glm-5")
|
||||
system_prompt = db.Column(db.Text, default="")
|
||||
|
|
@ -28,7 +28,8 @@ class Conversation(db.Model):
|
|||
max_tokens = db.Column(db.Integer, nullable=False, default=65536)
|
||||
thinking_enabled = db.Column(db.Boolean, default=False)
|
||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="dynamic",
|
||||
cascade="all, delete-orphan",
|
||||
|
|
@ -39,27 +40,48 @@ class Message(db.Model):
|
|||
__tablename__ = "messages"
|
||||
|
||||
id = db.Column(db.String(64), primary_key=True)
|
||||
conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), nullable=False)
|
||||
conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"),
|
||||
nullable=False, index=True)
|
||||
role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool
|
||||
content = db.Column(db.Text, default="")
|
||||
content = db.Column(LONGTEXT, default="") # LONGTEXT for long conversations
|
||||
token_count = db.Column(db.Integer, default=0)
|
||||
thinking_content = db.Column(db.Text, default="")
|
||||
thinking_content = db.Column(LONGTEXT, default="")
|
||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
# Tool call support
|
||||
tool_calls = db.Column(LONGTEXT) # JSON string: tool call requests (assistant messages)
|
||||
tool_call_id = db.Column(db.String(64)) # Tool call ID (tool messages)
|
||||
name = db.Column(db.String(64)) # Tool name (tool messages)
|
||||
# Tool call support - relation to ToolCall table
|
||||
tool_calls = db.relationship("ToolCall", backref="message", lazy="dynamic",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="ToolCall.call_index.asc()")
|
||||
|
||||
|
||||
class ToolCall(db.Model):
|
||||
"""Tool call record - separate table, follows database normalization"""
|
||||
__tablename__ = "tool_calls"
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
||||
message_id = db.Column(db.String(64), db.ForeignKey("messages.id"),
|
||||
nullable=False, index=True)
|
||||
call_id = db.Column(db.String(64), nullable=False) # Tool call ID
|
||||
call_index = db.Column(db.Integer, nullable=False, default=0) # Call order
|
||||
tool_name = db.Column(db.String(64), nullable=False) # Tool name
|
||||
arguments = db.Column(LONGTEXT, nullable=False) # Call arguments JSON
|
||||
result = db.Column(LONGTEXT) # Execution result JSON
|
||||
execution_time = db.Column(db.Float, default=0) # Execution time (seconds)
|
||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("ix_tool_calls_message_call", "message_id", "call_index"),
|
||||
)
|
||||
|
||||
|
||||
class TokenUsage(db.Model):
|
||||
__tablename__ = "token_usage"
|
||||
|
||||
id = db.Column(db.BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"), nullable=False)
|
||||
date = db.Column(db.Date, nullable=False) # 使用日期
|
||||
model = db.Column(db.String(64), nullable=False) # 模型名称
|
||||
user_id = db.Column(db.BigInteger, db.ForeignKey("users.id"),
|
||||
nullable=False, index=True)
|
||||
date = db.Column(db.Date, nullable=False, index=True)
|
||||
model = db.Column(db.String(64), nullable=False)
|
||||
prompt_tokens = db.Column(db.Integer, default=0)
|
||||
completion_tokens = db.Column(db.Integer, default=0)
|
||||
total_tokens = db.Column(db.Integer, default=0)
|
||||
|
|
@ -67,4 +89,5 @@ class TokenUsage(db.Model):
|
|||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "date", "model", name="uq_user_date_model"),
|
||||
db.Index("ix_token_usage_date_model", "date", "model"), # Composite index
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
import uuid
|
||||
from flask import current_app, Response
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message
|
||||
from backend.models import Conversation, Message, ToolCall
|
||||
from backend.tools import registry, ToolExecutor
|
||||
from backend.utils.helpers import (
|
||||
get_or_create_default_user,
|
||||
|
|
@ -60,8 +60,7 @@ class ChatService:
|
|||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)
|
||||
|
||||
# Create message
|
||||
msg = Message(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conv.id,
|
||||
|
|
@ -69,16 +68,18 @@ class ChatService:
|
|||
content=message.get("content", ""),
|
||||
token_count=completion_tokens,
|
||||
thinking_content=message.get("reasoning_content", ""),
|
||||
tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
|
||||
)
|
||||
db.session.add(msg)
|
||||
|
||||
# Create tool call records
|
||||
self._save_tool_calls(msg.id, all_tool_calls, all_tool_results)
|
||||
db.session.commit()
|
||||
|
||||
user = get_or_create_default_user()
|
||||
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
|
||||
|
||||
return ok({
|
||||
"message": to_dict(msg, thinking_content=msg.thinking_content or None),
|
||||
"message": self._message_to_dict(msg),
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
|
|
@ -194,8 +195,6 @@ class ChatService:
|
|||
continue
|
||||
|
||||
# No tool calls - finish
|
||||
merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)
|
||||
|
||||
with app.app_context():
|
||||
msg = Message(
|
||||
id=msg_id,
|
||||
|
|
@ -204,9 +203,11 @@ class ChatService:
|
|||
content=full_content,
|
||||
token_count=token_count,
|
||||
thinking_content=full_thinking,
|
||||
tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
|
||||
)
|
||||
db.session.add(msg)
|
||||
|
||||
# Create tool call records
|
||||
self._save_tool_calls(msg_id, all_tool_calls, all_tool_results)
|
||||
db.session.commit()
|
||||
|
||||
user = get_or_create_default_user()
|
||||
|
|
@ -223,6 +224,53 @@ class ChatService:
|
|||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
|
||||
)
|
||||
|
||||
def _save_tool_calls(self, message_id: str, tool_calls: list, tool_results: list) -> None:
    """Add one ToolCall row to the session for each tool call of a message.

    Results are paired with calls by position: ``tool_results[i]`` is taken as
    the result of ``tool_calls[i]``. A call without a result at its position is
    stored with ``result=None``. Rows are only added to the session here; the
    caller is responsible for committing.

    Args:
        message_id: ID of the message the tool calls belong to.
        tool_calls: Tool call dicts; each must carry ``function.name`` and
            ``function.arguments`` and may carry ``id``.
        tool_results: Tool result dicts whose ``content`` holds the serialized
            result of the call at the same index.
    """
    for index, call in enumerate(tool_calls):
        result_content = tool_results[index]["content"] if index < len(tool_results) else None

        # Best effort: pull execution_time out of the serialized result. Only
        # decode failures are tolerated (JSONDecodeError is a ValueError, a
        # non-string payload raises TypeError); anything else propagates.
        execution_time = 0
        if result_content:
            try:
                result_data = json.loads(result_content)
            except (TypeError, ValueError):
                result_data = None
            # Valid JSON is not necessarily an object (could be a list/number).
            if isinstance(result_data, dict):
                execution_time = result_data.get("execution_time", 0)

        function = call["function"]
        db.session.add(ToolCall(
            message_id=message_id,
            call_id=call.get("id", ""),
            call_index=index,
            tool_name=function["name"],
            arguments=function["arguments"],
            result=result_content,
            execution_time=execution_time,
        ))
|
||||
|
||||
def _message_to_dict(self, msg: Message) -> dict:
    """Serialize a message and attach its tool calls, if it has any.

    The base dict comes from ``to_dict``; a ``tool_calls`` key is added only
    when the message has at least one related ToolCall row.
    """
    payload = to_dict(msg, thinking_content=msg.thinking_content or None)

    # Load the related ToolCall rows through the message relationship.
    records = msg.tool_calls.all() if msg.tool_calls else []
    if not records:
        return payload

    serialized = []
    for record in records:
        serialized.append({
            "id": record.call_id,
            "type": "function",
            "function": {
                "name": record.tool_name,
                "arguments": record.arguments,
            },
            "result": record.result,
        })
    payload["tool_calls"] = serialized

    return payload
|
||||
|
||||
def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
|
||||
"""Process tool calls from streaming delta"""
|
||||
tool_calls_delta = delta.get("tool_calls", [])
|
||||
|
|
@ -242,13 +290,3 @@ class ChatService:
|
|||
if tc["function"].get("arguments"):
|
||||
tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
|
||||
return tool_calls_list
|
||||
|
||||
def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list:
|
||||
"""Merge tool results into tool calls"""
|
||||
merged = []
|
||||
for i, tc in enumerate(tool_calls):
|
||||
merged_tc = dict(tc)
|
||||
if i < len(tool_results):
|
||||
merged_tc["result"] = tool_results[i]["content"]
|
||||
merged.append(merged_tc)
|
||||
return merged
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def get_or_create_default_user():
|
|||
"""Get or create default user"""
|
||||
user = User.query.filter_by(username="default").first()
|
||||
if not user:
|
||||
user = User(username="default", password="")
|
||||
user = User(username="default", password=None)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
return user
|
||||
|
|
@ -39,13 +39,6 @@ def to_dict(inst, **extra):
|
|||
if k in d and hasattr(d[k], "strftime"):
|
||||
d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# Parse tool_calls JSON if present
|
||||
if "tool_calls" in d and d["tool_calls"]:
|
||||
try:
|
||||
d["tool_calls"] = json.loads(d["tool_calls"])
|
||||
except:
|
||||
pass
|
||||
|
||||
# Filter out None values for cleaner API response
|
||||
d = {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
|
|
|
|||
|
|
@ -526,15 +526,24 @@ GET /api/stats/tokens?period=daily
|
|||
### ER 关系
|
||||
|
||||
```
|
||||
User 1 ── * Conversation 1 ── * Message
|
||||
User 1 ── * Conversation 1 ── * Message 1 ── * ToolCall
|
||||
```
|
||||
|
||||
### User(用户)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
| ------------ | --------------- | ----------------------- |
|
||||
| `id` | bigint | 用户 ID(自增) |
|
||||
| `username` | string(50) | 用户名(唯一) |
|
||||
| `password` | string(255) | 密码(可为空,第三方登录) |
|
||||
| `phone` | string(20) | 手机号 |
|
||||
|
||||
### Conversation(会话)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
| ------------------ | ------------- | --------------------- |
|
||||
| `id` | string (UUID) | 会话 ID |
|
||||
| `user_id` | string | 所属用户 ID |
|
||||
| `user_id` | bigint | 所属用户 ID |
|
||||
| `title` | string | 会话标题 |
|
||||
| `model` | string | 使用的模型,默认 `glm-5` |
|
||||
| `system_prompt` | string | 系统提示词 |
|
||||
|
|
@ -550,11 +559,39 @@ User 1 ── * Conversation 1 ── * Message
|
|||
| ------------------ | ------------- | ------------------------------- |
|
||||
| `id` | string (UUID) | 消息 ID |
|
||||
| `conversation_id` | string | 所属会话 ID |
|
||||
| `role` | enum | `user` / `assistant` / `system` |
|
||||
| `content` | string | 消息内容 |
|
||||
| `role` | enum | `user` / `assistant` / `system` / `tool` |
|
||||
| `content` | LONGTEXT | 消息内容 |
|
||||
| `token_count` | integer | token 消耗数 |
|
||||
| `thinking_content` | string | 思维链内容(启用时) |
|
||||
| `tool_calls` | array (JSON) | 工具调用信息(含结果),仅 assistant 消息 |
|
||||
| `thinking_content` | LONGTEXT | 思维链内容(启用时) |
|
||||
| `created_at` | datetime | 创建时间 |
|
||||
|
||||
**说明**:工具调用信息存储在关联的 `ToolCall` 表中,通过 `message.tool_calls` 关系获取。
|
||||
|
||||
### ToolCall(工具调用)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
| ----------------- | --------------- | --------------------------- |
|
||||
| `id` | bigint | 调用记录 ID(自增) |
|
||||
| `message_id` | string(64) | 关联的消息 ID |
|
||||
| `call_id` | string(64) | 工具调用 ID |
|
||||
| `call_index` | integer | 调用顺序(从 0 开始) |
|
||||
| `tool_name` | string(64) | 工具名称 |
|
||||
| `arguments` | LONGTEXT | 调用参数 JSON |
|
||||
| `result` | LONGTEXT | 执行结果 JSON |
|
||||
| `execution_time` | float | 执行时间(秒) |
|
||||
| `created_at` | datetime | 创建时间 |
|
||||
|
||||
### TokenUsage(Token 使用统计)
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
| ------------------- | ---------- | -------------------------- |
|
||||
| `id` | bigint | 记录 ID(自增) |
|
||||
| `user_id` | bigint | 用户 ID |
|
||||
| `date` | date | 统计日期 |
|
||||
| `model` | string(64) | 模型名称 |
|
||||
| `prompt_tokens` | integer | 输入 token 数 |
|
||||
| `completion_tokens` | integer | 输出 token 数 |
|
||||
| `total_tokens` | integer | 总 token 数 |
|
||||
| `created_at` | datetime | 创建时间 |
|
||||
|
||||
#### 消息类型说明
|
||||
|
|
@ -563,8 +600,10 @@ User 1 ── * Conversation 1 ── * Message
|
|||
```json
|
||||
{
|
||||
"id": "msg_001",
|
||||
"conversation_id": "conv_abc123",
|
||||
"role": "user",
|
||||
"content": "北京今天天气怎么样?",
|
||||
"token_count": 0,
|
||||
"created_at": "2026-03-24T10:00:00Z"
|
||||
}
|
||||
```
|
||||
|
|
@ -573,22 +612,23 @@ User 1 ── * Conversation 1 ── * Message
|
|||
```json
|
||||
{
|
||||
"id": "msg_002",
|
||||
"conversation_id": "conv_abc123",
|
||||
"role": "assistant",
|
||||
"content": "北京今天天气晴朗...",
|
||||
"token_count": 50,
|
||||
"thinking_content": "用户想了解天气...",
|
||||
"tool_calls": null,
|
||||
"created_at": "2026-03-24T10:00:01Z"
|
||||
}
|
||||
```
|
||||
|
||||
**3. 助手消息 - 含工具调用 (role=assistant, with tool_calls)**
|
||||
|
||||
工具调用结果直接合并到 `tool_calls` 数组中,每个调用包含 `result` 字段:
|
||||
工具调用记录存储在独立的 `tool_calls` 表中,API 响应时会自动关联并返回:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg_003",
|
||||
"conversation_id": "conv_abc123",
|
||||
"role": "assistant",
|
||||
"content": "北京今天天气晴朗,温度25°C,湿度60%",
|
||||
"token_count": 80,
|
||||
|
|
@ -608,6 +648,18 @@ User 1 ── * Conversation 1 ── * Message
|
|||
}
|
||||
```
|
||||
|
||||
**4. 工具消息 (role=tool)**
|
||||
|
||||
用于 API 调用时传递工具执行结果(不存储在数据库):
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc123",
|
||||
"name": "get_weather",
|
||||
"content": "{\"temperature\": 25, \"humidity\": 60}"
|
||||
}
|
||||
```
|
||||
|
||||
#### 工具调用流程示例
|
||||
|
||||
```
|
||||
|
|
@ -617,14 +669,16 @@ User 1 ── * Conversation 1 ── * Message
|
|||
↓
|
||||
[AI 调用工具 get_weather]
|
||||
↓
|
||||
[msg_002] role=assistant, tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}]
|
||||
content="北京今天天气晴朗,温度25°C..."
|
||||
[msg_002] role=assistant, content="北京今天天气晴朗,温度25°C..."
|
||||
tool_calls=[{get_weather, args:{"city":"北京"}, result="{...}"}]
|
||||
```
|
||||
|
||||
**说明:**
|
||||
- 工具调用结果直接存储在 `tool_calls[].result` 字段中
|
||||
- 不再创建独立的 `role=tool` 消息
|
||||
- 前端可通过 `tool_calls` 数组展示完整的工具调用过程
|
||||
- 工具调用记录存储在独立的 `tool_calls` 表中,与 `messages` 表通过 `message_id` 关联
|
||||
- API 响应时自动查询并组装 `tool_calls` 数组
|
||||
- 工具调用包含完整的调用参数和执行结果
|
||||
- `call_index` 字段记录同一消息中多次工具调用的顺序
|
||||
- `execution_time` 字段记录工具执行耗时
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
## 概述
|
||||
|
||||
本文档描述 NanoClaw 工具调用系统的设计,采用简化的工厂模式,减少不必要的类层次。
|
||||
NanoClaw 工具调用系统采用简化的工厂模式,支持装饰器注册、缓存优化、重复调用检测等功能。
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -27,13 +27,20 @@ classDiagram
|
|||
+register(ToolDefinition tool) void
|
||||
+get(str name) ToolDefinition?
|
||||
+list_all() list~dict~
|
||||
+execute(str name, dict args) Any
|
||||
+list_by_category(str category) list~dict~
|
||||
+execute(str name, dict args) dict
|
||||
+remove(str name) bool
|
||||
+has(str name) bool
|
||||
}
|
||||
|
||||
class ToolExecutor {
|
||||
-ToolRegistry registry
|
||||
+process_tool_calls(list tool_calls) list~dict~
|
||||
+build_request(list messages) dict
|
||||
-dict _cache
|
||||
-list _call_history
|
||||
+process_tool_calls(list tool_calls, dict context) list~dict~
|
||||
+build_request(list messages, str model, list tools, dict kwargs) dict
|
||||
+clear_history() void
|
||||
+execute_with_retry(str name, dict args, int max_retries) dict
|
||||
}
|
||||
|
||||
class ToolResult {
|
||||
|
|
@ -42,6 +49,8 @@ classDiagram
|
|||
+Any data
|
||||
+str? error
|
||||
+dict to_dict()
|
||||
+ok(Any data)$ ToolResult
|
||||
+fail(str error)$ ToolResult
|
||||
}
|
||||
|
||||
ToolRegistry "1" --> "*" ToolDefinition : manages
|
||||
|
|
@ -61,11 +70,8 @@ classDiagram
|
|||
|
||||
class ToolFactory {
|
||||
<<module>>
|
||||
+tool(name, description, parameters)$ decorator
|
||||
+register(name, handler, description, parameters)$ void
|
||||
+create_crawler_tools()$ list~ToolDefinition~
|
||||
+create_data_tools()$ list~ToolDefinition~
|
||||
+create_file_tools()$ list~ToolDefinition~
|
||||
+tool(name, description, parameters, category)$ decorator
|
||||
+register_tool(name, handler, description, parameters, category)$ void
|
||||
}
|
||||
|
||||
class ToolDefinition {
|
||||
|
|
@ -73,6 +79,7 @@ classDiagram
|
|||
+str description
|
||||
+dict parameters
|
||||
+Callable handler
|
||||
+str category
|
||||
}
|
||||
|
||||
ToolFactory ..> ToolDefinition : creates
|
||||
|
|
@ -153,18 +160,31 @@ class ToolRegistry:
|
|||
return cls._instance
|
||||
|
||||
def register(self, tool: ToolDefinition) -> None:
|
||||
"""注册工具"""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||
"""获取工具定义"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_all(self) -> List[dict]:
|
||||
"""列出所有工具(OpenAI 格式)"""
|
||||
return [t.to_openai_format() for t in self._tools.values()]
|
||||
|
||||
def list_by_category(self, category: str) -> List[dict]:
    """Return the OpenAI-format definitions of the tools in one category."""
    matching = []
    for definition in self._tools.values():
        if definition.category == category:
            matching.append(definition.to_openai_format())
    return matching
|
||||
|
||||
def execute(self, name: str, arguments: dict) -> dict:
|
||||
"""执行工具"""
|
||||
tool = self.get(name)
|
||||
if not tool:
|
||||
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||
|
||||
try:
|
||||
result = tool.handler(arguments)
|
||||
if isinstance(result, ToolResult):
|
||||
|
|
@ -173,6 +193,17 @@ class ToolRegistry:
|
|||
except Exception as e:
|
||||
return ToolResult.fail(str(e)).to_dict()
|
||||
|
||||
def remove(self, name: str) -> bool:
    """Unregister a tool; return True if it was registered, False otherwise."""
    try:
        del self._tools[name]
    except KeyError:
        return False
    return True
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""检查工具是否存在"""
|
||||
return name in self._tools
|
||||
|
||||
|
||||
# 全局注册表
|
||||
registry = ToolRegistry()
|
||||
|
|
@ -182,39 +213,205 @@ registry = ToolRegistry()
|
|||
|
||||
```python
|
||||
import json
|
||||
from typing import List, Dict
|
||||
import time
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
class ToolExecutor:
|
||||
"""工具执行器"""
|
||||
"""工具执行器(支持缓存和重复检测)"""
|
||||
|
||||
def __init__(self, registry: ToolRegistry = None):
|
||||
def __init__(
|
||||
self,
|
||||
registry: ToolRegistry = None,
|
||||
api_url: str = None,
|
||||
api_key: str = None,
|
||||
enable_cache: bool = True,
|
||||
cache_ttl: int = 300, # 5分钟
|
||||
):
|
||||
self.registry = registry or ToolRegistry()
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self._cache: Dict[str, tuple] = {} # key -> (result, timestamp)
|
||||
self._call_history: List[dict] = [] # 当前会话的调用历史
|
||||
|
||||
def process_tool_calls(self, tool_calls: List[dict]) -> List[dict]:
|
||||
"""处理工具调用,返回消息列表"""
|
||||
def _make_cache_key(self, name: str, args: dict) -> str:
    """Build the cache key for a tool call.

    The key is the MD5 hex digest of ``"<name>:<args as JSON>"``, with the
    arguments serialized with sorted keys so key order does not matter.
    """
    serialized = json.dumps(args, sort_keys=True, ensure_ascii=False)
    digest = hashlib.md5()
    digest.update(f"{name}:{serialized}".encode())
    return digest.hexdigest()
|
||||
|
||||
def _get_cached(self, key: str) -> Optional[dict]:
    """Return the cached result for ``key`` if it is still fresh.

    Returns None when caching is disabled, when the key is not cached, or when
    the entry is at least ``cache_ttl`` seconds old; an expired entry is also
    removed from the cache.
    """
    if not self.enable_cache:
        return None

    try:
        cached_value, stored_at = self._cache[key]
    except KeyError:
        return None

    if time.time() - stored_at < self.cache_ttl:
        return cached_value

    # Entry has outlived its TTL: evict it and report a miss.
    del self._cache[key]
    return None
|
||||
|
||||
def _set_cache(self, key: str, result: dict) -> None:
|
||||
"""设置缓存"""
|
||||
if self.enable_cache:
|
||||
self._cache[key] = (result, time.time())
|
||||
|
||||
def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]:
    """Return the recorded result of an identical earlier call, or None.

    A history record matches when its tool name is equal and its stored
    ``args_str`` equals the sorted-key JSON serialization of ``args``. The
    first matching record wins.
    """
    fingerprint = json.dumps(args, sort_keys=True, ensure_ascii=False)
    matches = (
        entry["result"]
        for entry in self._call_history
        if entry["name"] == name and entry["args_str"] == fingerprint
    )
    return next(matches, None)
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""清空调用历史(新会话开始时调用)"""
|
||||
self._call_history.clear()
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
tool_calls: List[dict],
|
||||
context: dict = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
处理工具调用,返回消息列表
|
||||
|
||||
Args:
|
||||
tool_calls: LLM 返回的工具调用列表
|
||||
context: 可选上下文信息(user_id 等)
|
||||
|
||||
Returns:
|
||||
工具响应消息列表,可直接追加到 messages
|
||||
"""
|
||||
results = []
|
||||
seen_calls = set() # 当前批次内的重复检测
|
||||
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args = json.loads(call["function"]["arguments"])
|
||||
args_str = call["function"]["arguments"]
|
||||
call_id = call["id"]
|
||||
|
||||
try:
|
||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||
except json.JSONDecodeError:
|
||||
results.append(self._create_error_result(
|
||||
call_id, name, "Invalid JSON arguments"
|
||||
))
|
||||
continue
|
||||
|
||||
# 检查批次内重复
|
||||
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
|
||||
if call_key in seen_calls:
|
||||
results.append(self._create_tool_result(
|
||||
call_id, name,
|
||||
{"success": True, "data": None, "cached": True, "duplicate": True}
|
||||
))
|
||||
continue
|
||||
seen_calls.add(call_key)
|
||||
|
||||
# 检查历史重复
|
||||
history_result = self._check_duplicate_in_history(name, args)
|
||||
if history_result is not None:
|
||||
result = {**history_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# 检查缓存
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
cached_result = self._get_cached(cache_key)
|
||||
if cached_result is not None:
|
||||
result = {**cached_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# 执行工具
|
||||
result = self.registry.execute(name, args)
|
||||
|
||||
results.append({
|
||||
# 缓存结果
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
# 添加到历史
|
||||
self._call_history.append({
|
||||
"name": name,
|
||||
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
|
||||
"result": result
|
||||
})
|
||||
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
||||
return results
|
||||
|
||||
def _create_tool_result(
    self,
    call_id: str,
    name: str,
    result: dict,
    execution_time: float = 0
) -> dict:
    """Build the ``role="tool"`` message that reports a tool result.

    Args:
        call_id: ID of the tool call this message answers.
        name: Name of the tool that was called.
        result: Result dict to serialize into the message content.
        execution_time: Value stored under ``execution_time`` in the content.

    Returns:
        A message dict with ``role``, ``tool_call_id``, ``name`` and a JSON
        ``content`` string (non-serializable values are stringified).
    """
    # Serialize a copy: the same dict object may also be held elsewhere (e.g.
    # stored as a cached result or history record), so it must not be mutated.
    payload = {**result, "execution_time": execution_time}
    return {
        "role": "tool",
        "tool_call_id": call_id,
        "name": name,
        "content": json.dumps(payload, ensure_ascii=False, default=str)
    }
|
||||
|
||||
def build_request(self, messages: List[dict], **kwargs) -> dict:
|
||||
"""构建 API 请求"""
|
||||
def _create_error_result(
|
||||
self,
|
||||
call_id: str,
|
||||
name: str,
|
||||
error: str
|
||||
) -> dict:
|
||||
"""创建错误结果消息"""
|
||||
return {
|
||||
"model": kwargs.get("model", "glm-5"),
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": json.dumps({
|
||||
"success": False,
|
||||
"error": error
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
|
||||
def build_request(
|
||||
self,
|
||||
messages: List[dict],
|
||||
model: str = "glm-5",
|
||||
tools: List[dict] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""构建 API 请求体"""
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": self.registry.list_all(),
|
||||
"tool_choice": "auto"
|
||||
"tools": tools or self.registry.list_all(),
|
||||
"tool_choice": kwargs.get("tool_choice", "auto"),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
|
||||
}
|
||||
|
||||
def execute_with_retry(
    self,
    name: str,
    arguments: dict,
    max_retries: int = 3,
    retry_delay: float = 1.0
) -> dict:
    """Execute a tool, retrying when the attempt fails.

    An attempt counts as failed when ``registry.execute`` raises or when it
    returns a dict whose ``success`` is False. The second case matters because
    the registry converts handler exceptions into failure dicts, so waiting
    for an exception alone would never trigger a retry.

    Args:
        name: Name of the tool to execute.
        arguments: Arguments passed to the tool handler.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The first non-failed result, or a ``{"success": False, "error": ...}``
        dict describing the last failure once all attempts are used up.
    """
    # An unregistered tool cannot start working between attempts; return the
    # registry's own failure result without sleeping through retries.
    if self.registry.get(name) is None:
        return self.registry.execute(name, arguments)

    last_error = None

    for attempt in range(max_retries):
        try:
            result = self.registry.execute(name, arguments)
        except Exception as e:
            last_error = e
        else:
            failed = isinstance(result, dict) and result.get("success") is False
            if not failed:
                return result
            last_error = result.get("error")

        # No point in sleeping after the final attempt.
        if attempt < max_retries - 1:
            time.sleep(retry_delay)

    return {
        "success": False,
        "error": f"Failed after {max_retries} retries: {last_error}"
    }
|
||||
```
|
||||
|
||||
|
|
@ -227,11 +424,30 @@ class ToolExecutor:
|
|||
```python
|
||||
# backend/tools/factory.py
|
||||
|
||||
from .core import ToolDefinition, registry
|
||||
from typing import Callable
|
||||
from backend.tools.core import ToolDefinition, registry
|
||||
|
||||
def tool(name: str, description: str, parameters: dict, category: str = "general"):
|
||||
"""工具注册装饰器"""
|
||||
def decorator(func):
|
||||
|
||||
def tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict,
|
||||
category: str = "general"
|
||||
) -> Callable:
|
||||
"""
|
||||
工具注册装饰器
|
||||
|
||||
用法:
|
||||
@tool(
|
||||
name="web_search",
|
||||
description="搜索互联网获取信息",
|
||||
parameters={"type": "object", "properties": {...}},
|
||||
category="crawler"
|
||||
)
|
||||
def web_search(arguments: dict) -> dict:
|
||||
...
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
|
|
@ -242,6 +458,34 @@ def tool(name: str, description: str, parameters: dict, category: str = "general
|
|||
registry.register(tool_def)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_tool(
|
||||
name: str,
|
||||
handler: Callable,
|
||||
description: str,
|
||||
parameters: dict,
|
||||
category: str = "general"
|
||||
) -> None:
|
||||
"""
|
||||
直接注册工具(无需装饰器)
|
||||
|
||||
用法:
|
||||
register_tool(
|
||||
name="my_tool",
|
||||
handler=my_function,
|
||||
description="工具描述",
|
||||
parameters={...}
|
||||
)
|
||||
"""
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
handler=handler,
|
||||
category=category
|
||||
)
|
||||
registry.register(tool_def)
|
||||
```
|
||||
|
||||
### 4.2 使用示例
|
||||
|
|
@ -249,24 +493,32 @@ def tool(name: str, description: str, parameters: dict, category: str = "general
|
|||
```python
|
||||
# backend/tools/builtin/crawler.py
|
||||
|
||||
from ..factory import tool
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.services import SearchService, FetchService
|
||||
|
||||
# 网页搜索工具
|
||||
@tool(
|
||||
name="web_search",
|
||||
description="搜索互联网获取信息",
|
||||
description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"},
|
||||
"max_results": {"type": "integer", "default": 5}
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search keywords"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return, default 5",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def web_search(arguments: dict) -> dict:
|
||||
from ..services import SearchService
|
||||
"""Web search tool"""
|
||||
query = arguments["query"]
|
||||
max_results = arguments.get("max_results", 5)
|
||||
service = SearchService()
|
||||
|
|
@ -277,19 +529,27 @@ def web_search(arguments: dict) -> dict:
|
|||
# 页面抓取工具
|
||||
@tool(
|
||||
name="fetch_page",
|
||||
description="抓取指定网页内容",
|
||||
description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "网页URL"},
|
||||
"extract_type": {"type": "string", "enum": ["text", "links", "structured"]}
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to fetch"
|
||||
},
|
||||
"extract_type": {
|
||||
"type": "string",
|
||||
"description": "Extraction type",
|
||||
"enum": ["text", "links", "structured"],
|
||||
"default": "text"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def fetch_page(arguments: dict) -> dict:
|
||||
from ..services import FetchService
|
||||
"""Page fetch tool"""
|
||||
url = arguments["url"]
|
||||
extract_type = arguments.get("extract_type", "text")
|
||||
service = FetchService()
|
||||
|
|
@ -297,33 +557,39 @@ def fetch_page(arguments: dict) -> dict:
|
|||
return result
|
||||
|
||||
|
||||
# 计算器工具
|
||||
# 批量抓取工具
|
||||
@tool(
|
||||
name="calculator",
|
||||
description="执行数学计算",
|
||||
name="crawl_batch",
|
||||
description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式"}
|
||||
"urls": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of URLs to fetch"
|
||||
},
|
||||
"required": ["expression"]
|
||||
},
|
||||
category="data"
|
||||
)
|
||||
def calculator(arguments: dict) -> dict:
|
||||
import ast
|
||||
import operator
|
||||
expr = arguments["expression"]
|
||||
# 安全计算
|
||||
ops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv
|
||||
"extract_type": {
|
||||
"type": "string",
|
||||
"enum": ["text", "links", "structured"],
|
||||
"default": "text"
|
||||
}
|
||||
node = ast.parse(expr, mode='eval')
|
||||
result = eval(compile(node, '<string>', 'eval'), {"__builtins__": {}}, ops)
|
||||
return {"result": result}
|
||||
},
|
||||
"required": ["urls"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def crawl_batch(arguments: dict) -> dict:
    """Fetch several pages in one call.

    Reads ``urls`` (required) and ``extract_type`` (default ``"text"``) from
    ``arguments``. More than 10 URLs is rejected with an error dict; otherwise
    the batch is fetched and returned together with its size.
    """
    page_urls = arguments["urls"]
    mode = arguments.get("extract_type", "text")

    too_many = len(page_urls) > 10
    if too_many:
        return {"error": "Maximum 10 pages can be fetched at once"}

    fetched = FetchService().fetch_batch(page_urls, mode)
    return {"results": fetched, "total": len(fetched)}
|
||||
```
|
||||
|
||||
---
|
||||
|
|
@ -332,90 +598,220 @@ def calculator(arguments: dict) -> dict:
|
|||
|
||||
工具依赖的服务保持独立,不与工具类耦合:
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
direction LR
|
||||
|
||||
class SearchService {
|
||||
-SearchEngine engine
|
||||
+search(str query, int limit) list~dict~
|
||||
}
|
||||
|
||||
class FetchService {
|
||||
+fetch(str url, str type) dict
|
||||
+fetch_batch(list urls, str type) list~dict~
|
||||
}
|
||||
|
||||
class ContentExtractor {
|
||||
+extract_text() str
|
||||
+extract_links() list~dict~
|
||||
+extract_structured(str url) dict
|
||||
}
|
||||
|
||||
FetchService --> ContentExtractor : uses
|
||||
```
|
||||
|
||||
```python
|
||||
# backend/tools/services.py
|
||||
|
||||
from typing import List, Dict
|
||||
from ddgs import DDGS
|
||||
import re
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""搜索服务"""
|
||||
def __init__(self, engine=None):
|
||||
from ddgs import DDGS
|
||||
self.engine = engine or DDGS()
|
||||
|
||||
def search(self, query: str, max_results: int = 5) -> list:
|
||||
results = list(self.engine.text(query, max_results=max_results))
|
||||
def __init__(self, engine: str = "duckduckgo"):
|
||||
self.engine = engine
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "cn-zh"
|
||||
) -> List[dict]:
|
||||
"""执行搜索"""
|
||||
if self.engine == "duckduckgo":
|
||||
return self._search_duckduckgo(query, max_results, region)
|
||||
else:
|
||||
raise ValueError(f"Unsupported search engine: {self.engine}")
|
||||
|
||||
def _search_duckduckgo(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int,
|
||||
region: str
|
||||
) -> List[dict]:
|
||||
"""DuckDuckGo 搜索"""
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(
|
||||
query,
|
||||
max_results=max_results,
|
||||
region=region
|
||||
))
|
||||
|
||||
return [
|
||||
{"title": r["title"], "url": r["href"], "snippet": r["body"]}
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("href", ""),
|
||||
"snippet": r.get("body", "")
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
class FetchService:
|
||||
"""页面抓取服务"""
|
||||
def __init__(self, timeout: float = 30.0):
|
||||
|
||||
def __init__(self, timeout: float = 30.0, user_agent: str = None):
|
||||
self.timeout = timeout
|
||||
self.user_agent = user_agent or (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
def fetch(self, url: str, extract_type: str = "text") -> dict:
|
||||
"""抓取单个页面"""
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
resp = httpx.get(url, timeout=self.timeout, follow_redirects=True)
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": self.user_agent}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except Exception as e:
|
||||
return {"error": str(e), "url": url}
|
||||
|
||||
html = resp.text
|
||||
extractor = ContentExtractor(html)
|
||||
|
||||
extractor = ContentExtractor(soup)
|
||||
if extract_type == "text":
|
||||
return {"text": extractor.extract_text()}
|
||||
return {
|
||||
"url": url,
|
||||
"text": extractor.extract_text()
|
||||
}
|
||||
elif extract_type == "links":
|
||||
return {"links": extractor.extract_links()}
|
||||
return {
|
||||
"url": url,
|
||||
"links": extractor.extract_links()
|
||||
}
|
||||
else:
|
||||
return extractor.extract_structured()
|
||||
return extractor.extract_structured(url)
|
||||
|
||||
def fetch_batch(
|
||||
self,
|
||||
urls: List[str],
|
||||
extract_type: str = "text",
|
||||
max_concurrent: int = 5
|
||||
) -> List[dict]:
|
||||
"""批量抓取页面"""
|
||||
results = []
|
||||
for url in urls:
|
||||
results.append(self.fetch(url, extract_type))
|
||||
return results
|
||||
|
||||
|
||||
class ContentExtractor:
|
||||
"""内容提取器"""
|
||||
def __init__(self, soup):
|
||||
self.soup = soup
|
||||
|
||||
def __init__(self, html: str):
|
||||
self.html = html
|
||||
self._soup = None
|
||||
|
||||
@property
|
||||
def soup(self):
|
||||
if self._soup is None:
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
self._soup = BeautifulSoup(self.html, "html.parser")
|
||||
except ImportError:
|
||||
raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4")
|
||||
return self._soup
|
||||
|
||||
def extract_text(self) -> str:
|
||||
"""提取纯文本"""
|
||||
# 移除脚本和样式
|
||||
for tag in self.soup(["script", "style"]):
|
||||
for tag in self.soup(["script", "style", "nav", "footer", "header"]):
|
||||
tag.decompose()
|
||||
return self.soup.get_text(separator="\n", strip=True)
|
||||
|
||||
def extract_links(self) -> list:
|
||||
return [
|
||||
{"text": a.get_text(strip=True), "href": a.get("href")}
|
||||
for a in self.soup.find_all("a", href=True)
|
||||
]
|
||||
text = self.soup.get_text(separator="\n", strip=True)
|
||||
# 清理多余空白
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text
|
||||
|
||||
def extract_links(self) -> List[dict]:
|
||||
"""提取链接"""
|
||||
links = []
|
||||
for a in self.soup.find_all("a", href=True):
|
||||
text = a.get_text(strip=True)
|
||||
href = a["href"]
|
||||
if text and href and not href.startswith(("#", "javascript:")):
|
||||
links.append({"text": text, "href": href})
|
||||
return links[:50] # 限制数量
|
||||
|
||||
def extract_structured(self, url: str = "") -> dict:
|
||||
"""提取结构化内容"""
|
||||
soup = self.soup
|
||||
|
||||
# 提取标题
|
||||
title = ""
|
||||
if soup.title:
|
||||
title = soup.title.string or ""
|
||||
|
||||
# 提取 meta 描述
|
||||
description = ""
|
||||
meta_desc = soup.find("meta", attrs={"name": "description"})
|
||||
if meta_desc:
|
||||
description = meta_desc.get("content", "")
|
||||
|
||||
def extract_structured(self) -> dict:
|
||||
return {
|
||||
"title": self.soup.title.string if self.soup.title else "",
|
||||
"text": self.extract_text(),
|
||||
"url": url,
|
||||
"title": title.strip(),
|
||||
"description": description.strip(),
|
||||
"text": self.extract_text()[:5000], # 限制长度
|
||||
"links": self.extract_links()[:20]
|
||||
}
|
||||
|
||||
|
||||
class CalculatorService:
|
||||
"""安全计算服务"""
|
||||
|
||||
ALLOWED_OPS = {
|
||||
"add", "sub", "mul", "truediv", "floordiv",
|
||||
"mod", "pow", "neg", "abs"
|
||||
}
|
||||
|
||||
def evaluate(self, expression: str) -> dict:
|
||||
"""安全计算数学表达式"""
|
||||
import ast
|
||||
import operator
|
||||
|
||||
ops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.USub: operator.neg,
|
||||
ast.UAdd: operator.pos,
|
||||
}
|
||||
|
||||
try:
|
||||
# 解析表达式
|
||||
node = ast.parse(expression, mode="eval")
|
||||
|
||||
# 验证节点类型
|
||||
for child in ast.walk(node):
|
||||
if isinstance(child, ast.Call):
|
||||
return {"error": "Function calls not allowed"}
|
||||
if isinstance(child, ast.Name):
|
||||
return {"error": "Variable names not allowed"}
|
||||
|
||||
# 安全执行
|
||||
result = eval(
|
||||
compile(node, "<string>", "eval"),
|
||||
{"__builtins__": {}},
|
||||
{}
|
||||
)
|
||||
|
||||
return {"result": result}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Calculation error: {str(e)}"}
|
||||
```
|
||||
|
||||
---
|
||||
|
|
@ -425,57 +821,154 @@ class ContentExtractor:
|
|||
```python
|
||||
# backend/tools/__init__.py
|
||||
|
||||
from .core import ToolDefinition, ToolResult, ToolRegistry, registry, ToolExecutor
|
||||
from .factory import tool
|
||||
"""
|
||||
NanoClaw Tool System
|
||||
|
||||
def init_tools():
|
||||
"""初始化所有内置工具"""
|
||||
# 导入即自动注册
|
||||
from .builtin import crawler, data, weather
|
||||
Usage:
|
||||
from backend.tools import registry, ToolExecutor, tool
|
||||
from backend.tools import init_tools
|
||||
|
||||
# 使用时
|
||||
init_tools()
|
||||
# 初始化内置工具
|
||||
init_tools()
|
||||
|
||||
# 列出所有工具
|
||||
tools = registry.list_all()
|
||||
|
||||
# 执行工具
|
||||
result = registry.execute("web_search", {"query": "Python"})
|
||||
"""
|
||||
|
||||
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||
from backend.tools.factory import tool, register_tool
|
||||
from backend.tools.executor import ToolExecutor
|
||||
|
||||
|
||||
def init_tools() -> None:
|
||||
"""
|
||||
初始化所有内置工具
|
||||
|
||||
导入 builtin 模块会自动注册所有装饰器定义的工具
|
||||
"""
|
||||
from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401
|
||||
|
||||
|
||||
# 公开 API 导出
|
||||
__all__ = [
|
||||
# 核心类
|
||||
"ToolDefinition",
|
||||
"ToolResult",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
# 实例
|
||||
"registry",
|
||||
# 工厂函数
|
||||
"tool",
|
||||
"register_tool",
|
||||
# 初始化
|
||||
"init_tools",
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、工具清单
|
||||
|
||||
| 类别 | 工具名称 | 描述 | 依赖服务 |
|
||||
| ------- | --------------- | ---- | ------------- |
|
||||
| crawler | `web_search` | 网页搜索 | SearchService |
|
||||
| crawler | `fetch_page` | 单页抓取 | FetchService |
|
||||
| crawler | `crawl_batch` | 批量爬取 | FetchService |
|
||||
| data | `calculator` | 数学计算 | CalculatorService |
|
||||
| data | `text_process` | 文本处理 | - |
|
||||
| data | `json_process` | JSON处理 | - |
|
||||
| weather | `get_weather` | 天气查询 | - (模拟数据) |
|
||||
| file | `file_read` | 读取文件 | - |
|
||||
| file | `file_write` | 写入文件 | - |
|
||||
| file | `file_delete` | 删除文件 | - |
|
||||
| file | `file_list` | 列出目录 | - |
|
||||
| file | `file_exists` | 检查存在 | - |
|
||||
| file | `file_mkdir` | 创建目录 | - |
|
||||
### 7.1 爬虫工具 (crawler)
|
||||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
| --------------- | --------------------------- | --------------------------------------- |
|
||||
| `web_search` | 搜索互联网获取信息 | `query`: 搜索关键词<br>`max_results`: 结果数量(默认 5) |
|
||||
| `fetch_page` | 抓取单个网页内容 | `url`: 网页 URL<br>`extract_type`: 提取类型(text/links/structured) |
|
||||
| `crawl_batch` | 批量抓取多个网页(最多 10 个) | `urls`: URL 列表<br>`extract_type`: 提取类型 |
|
||||
|
||||
### 7.2 数据处理工具 (data)
|
||||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
| --------------- | --------------------------- | --------------------------------------- |
|
||||
| `calculator` | 执行数学计算(支持加减乘除、幂、模等) | `expression`: 数学表达式 |
|
||||
| `text_process` | 文本处理(计数、格式转换等) | `text`: 文本内容<br>`operation`: 操作类型(count/lines/words/upper/lower/reverse) |
|
||||
| `json_process` | JSON 处理(解析、格式化、提取、验证) | `json_string`: JSON 字符串<br>`operation`: 操作类型(parse/format/keys/validate) |
|
||||
|
||||
### 7.3 天气工具 (weather)
|
||||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
| --------------- | --------------------------- | --------------------------------------- |
|
||||
| `get_weather` | 查询指定城市的天气信息(模拟数据) | `city`: 城市名称(如:北京、上海、广州) |
|
||||
|
||||
### 7.4 文件操作工具 (file)
|
||||
|
||||
| 工具名称 | 描述 | 参数 |
|
||||
| --------------- | --------------------------- | --------------------------------------- |
|
||||
| `file_read` | 读取文件内容 | `path`: 文件路径<br>`encoding`: 编码(默认 utf-8) |
|
||||
| `file_write` | 写入文件(支持覆盖和追加) | `path`: 文件路径<br>`content`: 内容<br>`mode`: 写入模式(write/append) |
|
||||
| `file_delete` | 删除文件 | `path`: 文件路径 |
|
||||
| `file_list` | 列出目录内容 | `path`: 目录路径(默认 .)<br>`pattern`: 文件模式(默认 *) |
|
||||
| `file_exists` | 检查文件或目录是否存在 | `path`: 路径 |
|
||||
| `file_mkdir` | 创建目录(自动创建父目录) | `path`: 目录路径 |
|
||||
|
||||
**安全说明**:文件操作工具限制在项目根目录内,防止越权访问。
|
||||
|
||||
---
|
||||
|
||||
## 八、与旧设计对比
|
||||
|
||||
| 方面 | 旧设计 | 新设计 |
|
||||
| ----- | ----------------- | --------- |
|
||||
| --------- | ----------------- | ----------------- |
|
||||
| 类数量 | 30+ | ~10 |
|
||||
| 工具定义 | 继承 BaseTool | 装饰器 + 函数 |
|
||||
| 中间抽象层 | 5个(CrawlerTool 等) | 无 |
|
||||
| 扩展方式 | 创建子类 | 写函数 + 装饰器 |
|
||||
| 缓存机制 | 无 | 支持结果缓存(TTL 可配置) |
|
||||
| 重复检测 | 无 | 支持会话内重复调用检测 |
|
||||
| 代码量 | 多 | 少 |
|
||||
|
||||
---
|
||||
|
||||
## 九、总结
|
||||
## 九、核心特性
|
||||
|
||||
简化后的设计:
|
||||
### 9.1 装饰器注册
|
||||
|
||||
简化工具定义,只需一个装饰器:
|
||||
|
||||
```python
|
||||
@tool(
|
||||
name="my_tool",
|
||||
description="工具描述",
|
||||
parameters={...},
|
||||
category="custom"
|
||||
)
|
||||
def my_tool(arguments: dict) -> dict:
|
||||
# 工具实现
|
||||
return {"result": "ok"}
|
||||
```
|
||||
|
||||
### 9.2 智能缓存
|
||||
|
||||
- **结果缓存**:相同参数的工具调用结果会被缓存(默认 5 分钟)
|
||||
- **可配置 TTL**:通过 `cache_ttl` 参数设置缓存过期时间
|
||||
- **可禁用**:通过 `enable_cache=False` 关闭缓存
|
||||
|
||||
### 9.3 重复检测
|
||||
|
||||
- **批次内去重**:同一批次中相同工具+参数的调用会被跳过
|
||||
- **历史去重**:同一会话内已执行过的相同工具+参数调用会直接返回缓存结果
|
||||
- **自动清理**:新会话开始时调用 `clear_history()` 清理历史
|
||||
|
||||
### 9.4 安全设计
|
||||
|
||||
- **计算器安全**:禁止函数调用和变量名,只支持数学运算
|
||||
- **文件沙箱**:文件操作限制在项目根目录内,防止越权访问
|
||||
- **错误处理**:所有工具执行都有 try/except 保护,不会因工具错误导致系统崩溃
|
||||
|
||||
---
|
||||
|
||||
## 十、总结
|
||||
|
||||
简化后的设计特点:
|
||||
|
||||
1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult`
|
||||
2. **工厂模式**:使用 `@tool` 装饰器注册工具
|
||||
3. **服务分离**:工具依赖的服务独立,不与工具类耦合
|
||||
4. **易于扩展**:新增工具只需写一个函数并加装饰器
|
||||
4. **性能优化**:支持缓存和重复检测,减少重复计算和网络请求
|
||||
5. **易于扩展**:新增工具只需写一个函数并加装饰器
|
||||
6. **安全可靠**:文件沙箱、安全计算、完善的错误处理
|
||||
|
|
|
|||
Loading…
Reference in New Issue