feat: 初步完成工具调用设计
This commit is contained in:
parent
8639860fb9
commit
e77fd71aa7
|
|
@ -4,6 +4,7 @@ from flask import Flask
|
|||
from flask_sqlalchemy import SQLAlchemy
|
||||
from pathlib import Path
|
||||
|
||||
# Initialize db BEFORE importing models/routes that depend on it
|
||||
db = SQLAlchemy()
|
||||
CONFIG_PATH = Path(__file__).parent.parent / "config.yml"
|
||||
|
||||
|
|
@ -26,9 +27,13 @@ def create_app():
|
|||
|
||||
db.init_app(app)
|
||||
|
||||
# Import after db is initialized
|
||||
from .models import User, Conversation, Message, TokenUsage
|
||||
from .routes import register_routes
|
||||
from .tools import init_tools
|
||||
|
||||
register_routes(app)
|
||||
init_tools()
|
||||
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from flask import request, jsonify, Response, Blueprint, current_app
|
|||
from . import db
|
||||
from .models import Conversation, Message, User, TokenUsage
|
||||
from . import load_config
|
||||
from .tools import registry, ToolExecutor
|
||||
|
||||
bp = Blueprint("api", __name__)
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ def to_dict(inst, **extra):
|
|||
|
||||
|
||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
||||
"""记录 token 使用量"""
|
||||
"""Record token usage"""
|
||||
from datetime import date
|
||||
today = date.today()
|
||||
usage = TokenUsage.query.filter_by(
|
||||
|
|
@ -75,10 +76,13 @@ def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
|||
|
||||
|
||||
def build_glm_messages(conv):
|
||||
"""Build messages list for GLM API from conversation"""
|
||||
msgs = []
|
||||
if conv.system_prompt:
|
||||
msgs.append({"role": "system", "content": conv.system_prompt})
|
||||
for m in conv.messages:
|
||||
# Query messages directly to avoid detached instance warning
|
||||
messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all()
|
||||
for m in messages:
|
||||
msgs.append({"role": m.role, "content": m.content})
|
||||
return msgs
|
||||
|
||||
|
|
@ -87,15 +91,27 @@ def build_glm_messages(conv):
|
|||
|
||||
@bp.route("/api/models", methods=["GET"])
|
||||
def list_models():
|
||||
"""获取可用模型列表"""
|
||||
"""Get available model list"""
|
||||
return ok(MODELS)
|
||||
|
||||
|
||||
# -- Tools API --------------------------------------------
|
||||
|
||||
@bp.route("/api/tools", methods=["GET"])
|
||||
def list_tools():
|
||||
"""Get available tool list"""
|
||||
tools = registry.list_all()
|
||||
return ok({
|
||||
"tools": tools,
|
||||
"total": len(tools)
|
||||
})
|
||||
|
||||
|
||||
# -- Token Usage Statistics --------------------------------
|
||||
|
||||
@bp.route("/api/stats/tokens", methods=["GET"])
|
||||
def token_stats():
|
||||
"""获取 token 使用统计"""
|
||||
"""Get token usage statistics"""
|
||||
from sqlalchemy import func
|
||||
from datetime import date, timedelta
|
||||
|
||||
|
|
@ -105,7 +121,7 @@ def token_stats():
|
|||
today = date.today()
|
||||
|
||||
if period == "daily":
|
||||
# 今日统计
|
||||
# Today's statistics
|
||||
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
|
||||
result = {
|
||||
"period": "daily",
|
||||
|
|
@ -116,7 +132,7 @@ def token_stats():
|
|||
"by_model": {s.model: {"prompt": s.prompt_tokens, "completion": s.completion_tokens, "total": s.total_tokens} for s in stats}
|
||||
}
|
||||
elif period == "weekly":
|
||||
# 本周统计 (最近7天)
|
||||
# Weekly statistics (last 7 days)
|
||||
start_date = today - timedelta(days=6)
|
||||
stats = TokenUsage.query.filter(
|
||||
TokenUsage.user_id == user.id,
|
||||
|
|
@ -133,7 +149,7 @@ def token_stats():
|
|||
daily_data[d]["completion"] += s.completion_tokens
|
||||
daily_data[d]["total"] += s.total_tokens
|
||||
|
||||
# 填充没有数据的日期
|
||||
# Fill missing dates
|
||||
for i in range(7):
|
||||
d = (today - timedelta(days=6-i)).isoformat()
|
||||
if d not in daily_data:
|
||||
|
|
@ -149,7 +165,7 @@ def token_stats():
|
|||
"daily": daily_data
|
||||
}
|
||||
elif period == "monthly":
|
||||
# 本月统计 (最近30天)
|
||||
# Monthly statistics (last 30 days)
|
||||
start_date = today - timedelta(days=29)
|
||||
stats = TokenUsage.query.filter(
|
||||
TokenUsage.user_id == user.id,
|
||||
|
|
@ -166,7 +182,7 @@ def token_stats():
|
|||
daily_data[d]["completion"] += s.completion_tokens
|
||||
daily_data[d]["total"] += s.total_tokens
|
||||
|
||||
# 填充没有数据的日期
|
||||
# Fill missing dates
|
||||
for i in range(30):
|
||||
d = (today - timedelta(days=29-i)).isoformat()
|
||||
if d not in daily_data:
|
||||
|
|
@ -301,15 +317,19 @@ def delete_message(conv_id, msg_id):
|
|||
|
||||
# -- Chat Completion ----------------------------------
|
||||
|
||||
def _call_glm(conv, stream=False):
|
||||
def _call_glm(conv, stream=False, tools=None, messages=None):
|
||||
"""Call GLM API"""
|
||||
body = {
|
||||
"model": conv.model,
|
||||
"messages": build_glm_messages(conv),
|
||||
"messages": messages if messages is not None else build_glm_messages(conv),
|
||||
"max_tokens": conv.max_tokens,
|
||||
"temperature": conv.temperature,
|
||||
}
|
||||
if conv.thinking_enabled:
|
||||
body["thinking"] = {"type": "enabled"}
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
body["tool_choice"] = "auto"
|
||||
if stream:
|
||||
body["stream"] = True
|
||||
return requests.post(
|
||||
|
|
@ -320,101 +340,192 @@ def _call_glm(conv, stream=False):
|
|||
|
||||
|
||||
def _sync_response(conv):
|
||||
try:
|
||||
resp = _call_glm(conv)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
except Exception as e:
|
||||
return err(500, f"upstream error: {e}")
|
||||
|
||||
choice = result["choices"][0]
|
||||
usage = result.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
msg = Message(
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, role="assistant",
|
||||
content=choice["message"]["content"],
|
||||
token_count=completion_tokens,
|
||||
thinking_content=choice["message"].get("reasoning_content", ""),
|
||||
)
|
||||
db.session.add(msg)
|
||||
db.session.commit()
|
||||
|
||||
# 记录 token 使用
|
||||
user = get_or_create_default_user()
|
||||
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
|
||||
|
||||
return ok({
|
||||
"message": to_dict(msg, thinking_content=msg.thinking_content or None),
|
||||
"usage": {"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": usage.get("total_tokens", 0)},
|
||||
})
|
||||
|
||||
|
||||
def _stream_response(conv):
|
||||
conv_id = conv.id
|
||||
conv_model = conv.model
|
||||
app = current_app._get_current_object()
|
||||
|
||||
def generate():
|
||||
full_content = ""
|
||||
full_thinking = ""
|
||||
token_count = 0
|
||||
prompt_tokens = 0
|
||||
msg_id = str(uuid.uuid4())
|
||||
"""Sync response with tool call support"""
|
||||
executor = ToolExecutor(registry=registry)
|
||||
tools = registry.list_all()
|
||||
messages = build_glm_messages(conv)
|
||||
max_iterations = 5 # Max tool call iterations
|
||||
|
||||
for _ in range(max_iterations):
|
||||
try:
|
||||
with app.app_context():
|
||||
active_conv = db.session.get(Conversation, conv_id)
|
||||
resp = _call_glm(active_conv, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode("utf-8")
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
text = delta.get("content", "")
|
||||
if reasoning:
|
||||
full_thinking += reasoning
|
||||
yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"
|
||||
if text:
|
||||
full_content += text
|
||||
yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"
|
||||
usage = chunk.get("usage", {})
|
||||
if usage:
|
||||
token_count = usage.get("completion_tokens", 0)
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
resp = _call_glm(conv, tools=tools if tools else None, messages=messages)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
except Exception as e:
|
||||
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
return err(500, f"upstream error: {e}")
|
||||
|
||||
choice = result["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
# If no tool calls, return final result
|
||||
if not message.get("tool_calls"):
|
||||
usage = result.get("usage", {})
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# 流式结束后最后写入数据库
|
||||
with app.app_context():
|
||||
msg = Message(
|
||||
id=msg_id, conversation_id=conv_id, role="assistant",
|
||||
content=full_content, token_count=token_count, thinking_content=full_thinking,
|
||||
id=str(uuid.uuid4()), conversation_id=conv.id, role="assistant",
|
||||
content=message.get("content", ""),
|
||||
token_count=completion_tokens,
|
||||
thinking_content=message.get("reasoning_content", ""),
|
||||
)
|
||||
db.session.add(msg)
|
||||
db.session.commit()
|
||||
|
||||
# 记录 token 使用
|
||||
user = get_or_create_default_user()
|
||||
record_token_usage(user.id, conv_model, prompt_tokens, token_count)
|
||||
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
|
||||
|
||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
|
||||
return ok({
|
||||
"message": to_dict(msg, thinking_content=msg.thinking_content or None),
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": usage.get("total_tokens", 0)
|
||||
},
|
||||
})
|
||||
|
||||
# Process tool calls
|
||||
tool_calls = message["tool_calls"]
|
||||
messages.append(message)
|
||||
|
||||
# Execute tools and add results
|
||||
tool_results = executor.process_tool_calls(tool_calls)
|
||||
messages.extend(tool_results)
|
||||
|
||||
# Save tool call records to database
|
||||
for i, call in enumerate(tool_calls):
|
||||
tool_msg = Message(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conv.id,
|
||||
role="tool",
|
||||
content=tool_results[i]["content"]
|
||||
)
|
||||
db.session.add(tool_msg)
|
||||
db.session.commit()
|
||||
|
||||
return err(500, "exceeded maximum tool call iterations")
|
||||
|
||||
|
||||
def _stream_response(conv):
|
||||
"""Stream response with tool call support"""
|
||||
conv_id = conv.id
|
||||
conv_model = conv.model
|
||||
app = current_app._get_current_object()
|
||||
executor = ToolExecutor(registry=registry)
|
||||
tools = registry.list_all()
|
||||
# Build messages BEFORE entering generator (in request context)
|
||||
initial_messages = build_glm_messages(conv)
|
||||
|
||||
def generate():
|
||||
messages = list(initial_messages) # Copy to avoid mutation
|
||||
max_iterations = 5
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
full_content = ""
|
||||
full_thinking = ""
|
||||
token_count = 0
|
||||
prompt_tokens = 0
|
||||
msg_id = str(uuid.uuid4())
|
||||
tool_calls_list = []
|
||||
current_tool_call = None
|
||||
|
||||
try:
|
||||
with app.app_context():
|
||||
active_conv = db.session.get(Conversation, conv_id)
|
||||
resp = _call_glm(active_conv, stream=True, tools=tools if tools else None, messages=messages)
|
||||
resp.raise_for_status()
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode("utf-8")
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
|
||||
# Process thinking chain
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
full_thinking += reasoning
|
||||
yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Process text content
|
||||
text = delta.get("content", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Process tool calls
|
||||
tool_calls_delta = delta.get("tool_calls", [])
|
||||
for tc in tool_calls_delta:
|
||||
idx = tc.get("index", 0)
|
||||
if idx >= len(tool_calls_list):
|
||||
tool_calls_list.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": tc.get("type", "function"),
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
if tc.get("id"):
|
||||
tool_calls_list[idx]["id"] = tc["id"]
|
||||
if tc.get("function"):
|
||||
if tc["function"].get("name"):
|
||||
tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
|
||||
if tc["function"].get("arguments"):
|
||||
tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
|
||||
|
||||
usage = chunk.get("usage", {})
|
||||
if usage:
|
||||
token_count = usage.get("completion_tokens", 0)
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
|
||||
except Exception as e:
|
||||
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# If tool calls exist, execute and continue loop
|
||||
if tool_calls_list:
|
||||
# Send tool call info
|
||||
yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Execute tools
|
||||
tool_results = executor.process_tool_calls(tool_calls_list)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": full_content or None,
|
||||
"tool_calls": tool_calls_list
|
||||
})
|
||||
messages.extend(tool_results)
|
||||
|
||||
# Send tool results
|
||||
for tr in tool_results:
|
||||
yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"
|
||||
|
||||
continue
|
||||
|
||||
# No tool calls, finish
|
||||
with app.app_context():
|
||||
msg = Message(
|
||||
id=msg_id, conversation_id=conv_id, role="assistant",
|
||||
content=full_content, token_count=token_count, thinking_content=full_thinking,
|
||||
)
|
||||
db.session.add(msg)
|
||||
db.session.commit()
|
||||
|
||||
user = get_or_create_default_user()
|
||||
record_token_usage(user.id, conv_model, prompt_tokens, token_count)
|
||||
|
||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
|
||||
return
|
||||
|
||||
yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return Response(generate(), mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
"""
|
||||
NanoClaw Tool System
|
||||
|
||||
Usage:
|
||||
from backend.tools import registry, ToolExecutor, tool
|
||||
from backend.tools import init_tools
|
||||
|
||||
# Initialize built-in tools
|
||||
init_tools()
|
||||
|
||||
# List all tools
|
||||
tools = registry.list_all()
|
||||
|
||||
# Execute a tool
|
||||
result = registry.execute("web_search", {"query": "Python"})
|
||||
"""
|
||||
|
||||
from .core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||
from .factory import tool, register_tool
|
||||
from .executor import ToolExecutor
|
||||
|
||||
|
||||
def init_tools() -> None:
|
||||
"""
|
||||
Initialize all built-in tools
|
||||
|
||||
Importing builtin module automatically registers all decorator-defined tools
|
||||
"""
|
||||
from .builtin import crawler, data # noqa: F401
|
||||
|
||||
|
||||
# Public API exports
|
||||
__all__ = [
|
||||
# Core classes
|
||||
"ToolDefinition",
|
||||
"ToolResult",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
# Instances
|
||||
"registry",
|
||||
# Factory functions
|
||||
"tool",
|
||||
"register_tool",
|
||||
# Initialization
|
||||
"init_tools",
|
||||
]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""Built-in tools"""
|
||||
from .crawler import *
|
||||
from .data import *
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""Crawler related tools"""
|
||||
from ..factory import tool
|
||||
from ..services import SearchService, FetchService
|
||||
|
||||
|
||||
@tool(
|
||||
name="web_search",
|
||||
description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search keywords"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return, default 5",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def web_search(arguments: dict) -> dict:
|
||||
"""
|
||||
Web search tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"query": "search keywords",
|
||||
"max_results": 5
|
||||
}
|
||||
|
||||
Returns:
|
||||
{"results": [...]}
|
||||
"""
|
||||
query = arguments["query"]
|
||||
max_results = arguments.get("max_results", 5)
|
||||
|
||||
service = SearchService()
|
||||
results = service.search(query, max_results)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
|
||||
@tool(
|
||||
name="fetch_page",
|
||||
description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the webpage to fetch"
|
||||
},
|
||||
"extract_type": {
|
||||
"type": "string",
|
||||
"description": "Extraction type",
|
||||
"enum": ["text", "links", "structured"],
|
||||
"default": "text"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def fetch_page(arguments: dict) -> dict:
|
||||
"""
|
||||
Page fetch tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"url": "https://example.com",
|
||||
"extract_type": "text" | "links" | "structured"
|
||||
}
|
||||
|
||||
Returns:
|
||||
Page content
|
||||
"""
|
||||
url = arguments["url"]
|
||||
extract_type = arguments.get("extract_type", "text")
|
||||
|
||||
service = FetchService()
|
||||
result = service.fetch(url, extract_type)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@tool(
|
||||
name="crawl_batch",
|
||||
description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"urls": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of URLs to fetch"
|
||||
},
|
||||
"extract_type": {
|
||||
"type": "string",
|
||||
"enum": ["text", "links", "structured"],
|
||||
"default": "text"
|
||||
}
|
||||
},
|
||||
"required": ["urls"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def crawl_batch(arguments: dict) -> dict:
|
||||
"""
|
||||
Batch fetch tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"urls": ["url1", "url2", ...],
|
||||
"extract_type": "text"
|
||||
}
|
||||
|
||||
Returns:
|
||||
{"results": [...]}
|
||||
"""
|
||||
urls = arguments["urls"]
|
||||
extract_type = arguments.get("extract_type", "text")
|
||||
|
||||
if len(urls) > 10:
|
||||
return {"error": "Maximum 10 pages can be fetched at once"}
|
||||
|
||||
service = FetchService()
|
||||
results = service.fetch_batch(urls, extract_type)
|
||||
|
||||
return {"results": results, "total": len(results)}
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
"""Data processing related tools"""
|
||||
from ..factory import tool
|
||||
from ..services import CalculatorService
|
||||
|
||||
|
||||
@tool(
|
||||
name="calculator",
|
||||
description="Perform mathematical calculations. Supports basic arithmetic: addition, subtraction, multiplication, division, power, modulo, etc.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression, e.g.: (2 + 3) * 4, 2 ** 10, 100 / 7"
|
||||
}
|
||||
},
|
||||
"required": ["expression"]
|
||||
},
|
||||
category="data"
|
||||
)
|
||||
def calculator(arguments: dict) -> dict:
|
||||
"""
|
||||
Calculator tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"expression": "2 + 3 * 4"
|
||||
}
|
||||
|
||||
Returns:
|
||||
{"result": 14}
|
||||
"""
|
||||
expression = arguments["expression"]
|
||||
service = CalculatorService()
|
||||
return service.evaluate(expression)
|
||||
|
||||
|
||||
@tool(
|
||||
name="text_process",
|
||||
description="Process text content, supports counting, format conversion and other operations.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to process"
|
||||
},
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"description": "Operation type",
|
||||
"enum": ["count", "lines", "words", "upper", "lower", "reverse"]
|
||||
}
|
||||
},
|
||||
"required": ["text", "operation"]
|
||||
},
|
||||
category="data"
|
||||
)
|
||||
def text_process(arguments: dict) -> dict:
|
||||
"""
|
||||
Text processing tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"text": "text content",
|
||||
"operation": "count" | "lines" | "words" | ...
|
||||
}
|
||||
|
||||
Returns:
|
||||
Processing result
|
||||
"""
|
||||
text = arguments["text"]
|
||||
operation = arguments["operation"]
|
||||
|
||||
operations = {
|
||||
"count": lambda t: {"count": len(t)},
|
||||
"lines": lambda t: {"lines": len(t.splitlines())},
|
||||
"words": lambda t: {"words": len(t.split())},
|
||||
"upper": lambda t: {"result": t.upper()},
|
||||
"lower": lambda t: {"result": t.lower()},
|
||||
"reverse": lambda t: {"result": t[::-1]}
|
||||
}
|
||||
|
||||
if operation not in operations:
|
||||
return {"error": f"Unknown operation: {operation}"}
|
||||
|
||||
return operations[operation](text)
|
||||
|
||||
|
||||
@tool(
|
||||
name="json_process",
|
||||
description="Process JSON data, supports parsing, formatting, extraction and other operations.",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"json_string": {
|
||||
"type": "string",
|
||||
"description": "JSON string"
|
||||
},
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"description": "Operation type",
|
||||
"enum": ["parse", "format", "keys", "validate"]
|
||||
}
|
||||
},
|
||||
"required": ["json_string", "operation"]
|
||||
},
|
||||
category="data"
|
||||
)
|
||||
def json_process(arguments: dict) -> dict:
|
||||
"""
|
||||
JSON processing tool
|
||||
|
||||
Args:
|
||||
arguments: {
|
||||
"json_string": '{"key": "value"}',
|
||||
"operation": "parse" | "format" | "keys" | "validate"
|
||||
}
|
||||
|
||||
Returns:
|
||||
Processing result
|
||||
"""
|
||||
import json
|
||||
|
||||
json_string = arguments["json_string"]
|
||||
operation = arguments["operation"]
|
||||
|
||||
try:
|
||||
if operation == "validate":
|
||||
json.loads(json_string)
|
||||
return {"valid": True}
|
||||
|
||||
data = json.loads(json_string)
|
||||
|
||||
if operation == "parse":
|
||||
return {"data": data}
|
||||
elif operation == "format":
|
||||
return {"result": json.dumps(data, indent=2, ensure_ascii=False)}
|
||||
elif operation == "keys":
|
||||
if isinstance(data, dict):
|
||||
return {"keys": list(data.keys())}
|
||||
return {"error": "JSON root element is not an object"}
|
||||
else:
|
||||
return {"error": f"Unknown operation: {operation}"}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return {"error": f"JSON parse error: {str(e)}"}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Tool system core classes"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""Tool definition"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict # JSON Schema
|
||||
handler: Callable[[dict], Any]
|
||||
category: str = "general"
|
||||
|
||||
def to_openai_format(self) -> dict:
|
||||
"""Convert to OpenAI/GLM compatible format"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Tool execution result"""
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: Any) -> "ToolResult":
|
||||
return cls(success=True, data=data)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, error: str) -> "ToolResult":
|
||||
return cls(success=False, error=error)
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Tool registry (singleton)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._tools: Dict[str, ToolDefinition] = {}
|
||||
return cls._instance
|
||||
|
||||
def register(self, tool: ToolDefinition) -> None:
|
||||
"""Register a tool"""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||
"""Get tool definition by name"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_all(self) -> List[dict]:
|
||||
"""List all tools in OpenAI format"""
|
||||
return [t.to_openai_format() for t in self._tools.values()]
|
||||
|
||||
def list_by_category(self, category: str) -> List[dict]:
|
||||
"""List tools by category"""
|
||||
return [
|
||||
t.to_openai_format()
|
||||
for t in self._tools.values()
|
||||
if t.category == category
|
||||
]
|
||||
|
||||
def execute(self, name: str, arguments: dict) -> dict:
|
||||
"""Execute a tool"""
|
||||
tool = self.get(name)
|
||||
if not tool:
|
||||
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||
|
||||
try:
|
||||
result = tool.handler(arguments)
|
||||
if isinstance(result, ToolResult):
|
||||
return result.to_dict()
|
||||
return ToolResult.ok(result).to_dict()
|
||||
except Exception as e:
|
||||
return ToolResult.fail(str(e)).to_dict()
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
"""Remove a tool"""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if tool exists"""
|
||||
return name in self._tools
|
||||
|
||||
|
||||
# Global registry instance
|
||||
registry = ToolRegistry()
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
"""Tool executor"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Optional, Generator, Any
|
||||
from .core import ToolRegistry, registry
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Tool call executor"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: Optional[ToolRegistry] = None,
|
||||
api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
self.registry = registry or ToolRegistry()
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
tool_calls: List[dict],
|
||||
context: Optional[dict] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Process tool calls and return message list
|
||||
|
||||
Args:
|
||||
tool_calls: Tool call list returned by LLM
|
||||
context: Optional context info (user_id, etc.)
|
||||
|
||||
Returns:
|
||||
Tool response message list, can be appended to messages
|
||||
"""
|
||||
results = []
|
||||
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args_str = call["function"]["arguments"]
|
||||
call_id = call["id"]
|
||||
|
||||
try:
|
||||
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||
except json.JSONDecodeError:
|
||||
results.append(self._create_error_result(
|
||||
call_id, name, "Invalid JSON arguments"
|
||||
))
|
||||
continue
|
||||
|
||||
result = self.registry.execute(name, args)
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
||||
return results
|
||||
|
||||
def _create_tool_result(
|
||||
self,
|
||||
call_id: str,
|
||||
name: str,
|
||||
result: dict,
|
||||
execution_time: float = 0
|
||||
) -> dict:
|
||||
"""Create tool result message"""
|
||||
result["execution_time"] = execution_time
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": json.dumps(result, ensure_ascii=False, default=str)
|
||||
}
|
||||
|
||||
def _create_error_result(
|
||||
self,
|
||||
call_id: str,
|
||||
name: str,
|
||||
error: str
|
||||
) -> dict:
|
||||
"""Create error result message"""
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": json.dumps({
|
||||
"success": False,
|
||||
"error": error
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
|
||||
def build_request(
|
||||
self,
|
||||
messages: List[dict],
|
||||
model: str = "glm-5",
|
||||
tools: Optional[List[dict]] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
Build API request body
|
||||
|
||||
Args:
|
||||
messages: Message list
|
||||
model: Model name
|
||||
tools: Tool list (default: all tools in registry)
|
||||
**kwargs: Other parameters (temperature, max_tokens, etc.)
|
||||
|
||||
Returns:
|
||||
Request body dict
|
||||
"""
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools or self.registry.list_all(),
|
||||
"tool_choice": kwargs.get("tool_choice", "auto"),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
|
||||
}
|
||||
|
||||
def execute_with_retry(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0
|
||||
) -> dict:
|
||||
"""
|
||||
Execute tool with retry
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
arguments: Tool arguments
|
||||
max_retries: Max retry count
|
||||
retry_delay: Retry delay in seconds
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.registry.execute(name, arguments)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed after {max_retries} retries: {last_error}"
|
||||
}
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
"""Tool factory - decorator registration"""
|
||||
from typing import Callable
|
||||
from .core import ToolDefinition, registry
|
||||
|
||||
|
||||
def tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict,
|
||||
category: str = "general"
|
||||
) -> Callable:
|
||||
"""
|
||||
Tool registration decorator
|
||||
|
||||
Usage:
|
||||
@tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"type": "object", "properties": {...}},
|
||||
category="crawler"
|
||||
)
|
||||
def web_search(arguments: dict) -> dict:
|
||||
...
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
handler=func,
|
||||
category=category
|
||||
)
|
||||
registry.register(tool_def)
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_tool(
|
||||
name: str,
|
||||
handler: Callable,
|
||||
description: str,
|
||||
parameters: dict,
|
||||
category: str = "general"
|
||||
) -> None:
|
||||
"""
|
||||
Register a tool directly (without decorator)
|
||||
|
||||
Usage:
|
||||
register_tool(
|
||||
name="my_tool",
|
||||
handler=my_function,
|
||||
description="Description",
|
||||
parameters={...}
|
||||
)
|
||||
"""
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
handler=handler,
|
||||
category=category
|
||||
)
|
||||
registry.register(tool_def)
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
"""Tool helper services"""
|
||||
from typing import List, Dict, Optional, Any
|
||||
import re
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Search service"""
|
||||
|
||||
def __init__(self, engine: str = "duckduckgo"):
|
||||
self.engine = engine
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "cn-zh"
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Execute search
|
||||
|
||||
Args:
|
||||
query: Search keywords
|
||||
max_results: Max result count
|
||||
region: Region setting
|
||||
|
||||
Returns:
|
||||
Search result list
|
||||
"""
|
||||
if self.engine == "duckduckgo":
|
||||
return self._search_duckduckgo(query, max_results, region)
|
||||
else:
|
||||
raise ValueError(f"Unsupported search engine: {self.engine}")
|
||||
|
||||
def _search_duckduckgo(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int,
|
||||
region: str
|
||||
) -> List[dict]:
|
||||
"""DuckDuckGo search"""
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
return [{"error": "Please install duckduckgo-search: pip install duckduckgo-search"}]
|
||||
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(
|
||||
query,
|
||||
max_results=max_results,
|
||||
region=region
|
||||
))
|
||||
|
||||
return [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("href", ""),
|
||||
"snippet": r.get("body", "")
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
class FetchService:
|
||||
"""Page fetch service"""
|
||||
|
||||
def __init__(self, timeout: float = 30.0, user_agent: str = None):
|
||||
self.timeout = timeout
|
||||
self.user_agent = user_agent or (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
url: str,
|
||||
extract_type: str = "text"
|
||||
) -> dict:
|
||||
"""
|
||||
Fetch a single page
|
||||
|
||||
Args:
|
||||
url: Page URL
|
||||
extract_type: Extract type (text, links, structured)
|
||||
|
||||
Returns:
|
||||
Fetch result
|
||||
"""
|
||||
import httpx
|
||||
|
||||
try:
|
||||
resp = httpx.get(
|
||||
url,
|
||||
timeout=self.timeout,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": self.user_agent}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except Exception as e:
|
||||
return {"error": str(e), "url": url}
|
||||
|
||||
html = resp.text
|
||||
extractor = ContentExtractor(html)
|
||||
|
||||
if extract_type == "text":
|
||||
return {
|
||||
"url": url,
|
||||
"text": extractor.extract_text()
|
||||
}
|
||||
elif extract_type == "links":
|
||||
return {
|
||||
"url": url,
|
||||
"links": extractor.extract_links()
|
||||
}
|
||||
else:
|
||||
return extractor.extract_structured(url)
|
||||
|
||||
def fetch_batch(
|
||||
self,
|
||||
urls: List[str],
|
||||
extract_type: str = "text",
|
||||
max_concurrent: int = 5
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Batch fetch pages
|
||||
|
||||
Args:
|
||||
urls: URL list
|
||||
extract_type: Extract type
|
||||
max_concurrent: Max concurrent requests
|
||||
|
||||
Returns:
|
||||
Result list
|
||||
"""
|
||||
results = []
|
||||
for url in urls:
|
||||
results.append(self.fetch(url, extract_type))
|
||||
return results
|
||||
|
||||
|
||||
class ContentExtractor:
|
||||
"""Content extractor"""
|
||||
|
||||
def __init__(self, html: str):
|
||||
self.html = html
|
||||
self._soup = None
|
||||
|
||||
@property
|
||||
def soup(self):
|
||||
if self._soup is None:
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
self._soup = BeautifulSoup(self.html, "html.parser")
|
||||
except ImportError:
|
||||
raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4")
|
||||
return self._soup
|
||||
|
||||
def extract_text(self) -> str:
|
||||
"""Extract plain text"""
|
||||
# Remove script and style
|
||||
for tag in self.soup(["script", "style", "nav", "footer", "header"]):
|
||||
tag.decompose()
|
||||
|
||||
text = self.soup.get_text(separator="\n", strip=True)
|
||||
# Clean extra whitespace
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text
|
||||
|
||||
def extract_links(self) -> List[dict]:
|
||||
"""Extract links"""
|
||||
links = []
|
||||
for a in self.soup.find_all("a", href=True):
|
||||
text = a.get_text(strip=True)
|
||||
href = a["href"]
|
||||
if text and href and not href.startswith(("#", "javascript:")):
|
||||
links.append({"text": text, "href": href})
|
||||
return links[:50] # Limit count
|
||||
|
||||
def extract_structured(self, url: str = "") -> dict:
|
||||
"""Extract structured content"""
|
||||
soup = self.soup
|
||||
|
||||
# Extract title
|
||||
title = ""
|
||||
if soup.title:
|
||||
title = soup.title.string or ""
|
||||
|
||||
# Extract meta description
|
||||
description = ""
|
||||
meta_desc = soup.find("meta", attrs={"name": "description"})
|
||||
if meta_desc:
|
||||
description = meta_desc.get("content", "")
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
"title": title.strip(),
|
||||
"description": description.strip(),
|
||||
"text": self.extract_text()[:5000], # Limit length
|
||||
"links": self.extract_links()[:20]
|
||||
}
|
||||
|
||||
|
||||
class CalculatorService:
|
||||
"""Safe calculation service"""
|
||||
|
||||
ALLOWED_OPS = {
|
||||
"add", "sub", "mul", "truediv", "floordiv",
|
||||
"mod", "pow", "neg", "abs"
|
||||
}
|
||||
|
||||
def evaluate(self, expression: str) -> dict:
|
||||
"""
|
||||
Safely evaluate mathematical expression
|
||||
|
||||
Args:
|
||||
expression: Mathematical expression
|
||||
|
||||
Returns:
|
||||
Calculation result
|
||||
"""
|
||||
import ast
|
||||
import operator
|
||||
|
||||
ops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.USub: operator.neg,
|
||||
ast.UAdd: operator.pos,
|
||||
}
|
||||
|
||||
try:
|
||||
# Parse expression
|
||||
node = ast.parse(expression, mode="eval")
|
||||
|
||||
# Validate node types
|
||||
for child in ast.walk(node):
|
||||
if isinstance(child, ast.Call):
|
||||
return {"error": "Function calls not allowed"}
|
||||
if isinstance(child, ast.Name):
|
||||
return {"error": "Variable names not allowed"}
|
||||
|
||||
# Safe execution
|
||||
result = eval(
|
||||
compile(node, "<string>", "eval"),
|
||||
{"__builtins__": {}},
|
||||
{}
|
||||
)
|
||||
|
||||
return {"result": result}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Calculation error: {str(e)}"}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -4,21 +4,21 @@
|
|||
|
||||
### 会话管理
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| `POST` | `/api/conversations` | 创建会话 |
|
||||
| `GET` | `/api/conversations` | 获取会话列表 |
|
||||
| `GET` | `/api/conversations/:id` | 获取会话详情 |
|
||||
| `PATCH` | `/api/conversations/:id` | 更新会话 |
|
||||
| `DELETE` | `/api/conversations/:id` | 删除会话 |
|
||||
| 方法 | 路径 | 说明 |
|
||||
| -------- | ------------------------ | ------ |
|
||||
| `POST` | `/api/conversations` | 创建会话 |
|
||||
| `GET` | `/api/conversations` | 获取会话列表 |
|
||||
| `GET` | `/api/conversations/:id` | 获取会话详情 |
|
||||
| `PATCH` | `/api/conversations/:id` | 更新会话 |
|
||||
| `DELETE` | `/api/conversations/:id` | 删除会话 |
|
||||
|
||||
### 消息管理
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| `GET` | `/api/conversations/:id/messages` | 获取消息列表 |
|
||||
| `POST` | `/api/conversations/:id/messages` | 发送消息(对话补全,支持 `stream` 流式) |
|
||||
| `DELETE` | `/api/conversations/:id/messages/:message_id` | 删除消息 |
|
||||
| 方法 | 路径 | 说明 |
|
||||
| -------- | --------------------------------------------- | ------------------------- |
|
||||
| `GET` | `/api/conversations/:id/messages` | 获取消息列表 |
|
||||
| `POST` | `/api/conversations/:id/messages` | 发送消息(对话补全,支持 `stream` 流式) |
|
||||
| `DELETE` | `/api/conversations/:id/messages/:message_id` | 删除消息 |
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,475 @@
|
|||
# 工具调用系统设计
|
||||
|
||||
## 概述
|
||||
|
||||
本文档描述 NanoClaw 工具调用系统的设计,采用简化的工厂模式,减少不必要的类层次。
|
||||
|
||||
---
|
||||
|
||||
## 一、核心类图
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
direction TB
|
||||
|
||||
class ToolDefinition {
|
||||
<<dataclass>>
|
||||
+str name
|
||||
+str description
|
||||
+dict parameters
|
||||
+Callable handler
|
||||
+str category
|
||||
+dict to_openai_format()
|
||||
}
|
||||
|
||||
class ToolRegistry {
|
||||
-dict _tools
|
||||
+register(ToolDefinition tool) void
|
||||
+get(str name) ToolDefinition?
|
||||
+list_all() list~dict~
|
||||
+execute(str name, dict args) Any
|
||||
}
|
||||
|
||||
class ToolExecutor {
|
||||
-ToolRegistry registry
|
||||
+process_tool_calls(list tool_calls) list~dict~
|
||||
+build_request(list messages) dict
|
||||
}
|
||||
|
||||
class ToolResult {
|
||||
<<dataclass>>
|
||||
+bool success
|
||||
+Any data
|
||||
+str? error
|
||||
+dict to_dict()
|
||||
}
|
||||
|
||||
ToolRegistry "1" --> "*" ToolDefinition : manages
|
||||
ToolExecutor "1" --> "1" ToolRegistry : uses
|
||||
ToolDefinition ..> ToolResult : returns
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 二、工具定义工厂
|
||||
|
||||
使用工厂函数创建工具,避免复杂的类继承:
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
direction LR
|
||||
|
||||
class ToolFactory {
|
||||
<<module>>
|
||||
+tool(name, description, parameters)$ decorator
|
||||
+register(name, handler, description, parameters)$ void
|
||||
+create_crawler_tools()$ list~ToolDefinition~
|
||||
+create_data_tools()$ list~ToolDefinition~
|
||||
+create_file_tools()$ list~ToolDefinition~
|
||||
}
|
||||
|
||||
class ToolDefinition {
|
||||
+str name
|
||||
+str description
|
||||
+dict parameters
|
||||
+Callable handler
|
||||
}
|
||||
|
||||
ToolFactory ..> ToolDefinition : creates
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 三、核心类实现
|
||||
|
||||
### 3.1 ToolDefinition
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Any
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict # JSON Schema
|
||||
handler: Callable[[dict], Any]
|
||||
category: str = "general"
|
||||
|
||||
def to_openai_format(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3.2 ToolResult
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""工具执行结果"""
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: Any) -> "ToolResult":
|
||||
return cls(success=True, data=data)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, error: str) -> "ToolResult":
|
||||
return cls(success=False, error=error)
|
||||
```
|
||||
|
||||
### 3.3 ToolRegistry
|
||||
|
||||
```python
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._tools: Dict[str, ToolDefinition] = {}
|
||||
return cls._instance
|
||||
|
||||
def register(self, tool: ToolDefinition) -> None:
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_all(self) -> List[dict]:
|
||||
return [t.to_openai_format() for t in self._tools.values()]
|
||||
|
||||
def execute(self, name: str, arguments: dict) -> dict:
|
||||
tool = self.get(name)
|
||||
if not tool:
|
||||
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||
try:
|
||||
result = tool.handler(arguments)
|
||||
if isinstance(result, ToolResult):
|
||||
return result.to_dict()
|
||||
return ToolResult.ok(result).to_dict()
|
||||
except Exception as e:
|
||||
return ToolResult.fail(str(e)).to_dict()
|
||||
|
||||
|
||||
# 全局注册表
|
||||
registry = ToolRegistry()
|
||||
```
|
||||
|
||||
### 3.4 ToolExecutor
|
||||
|
||||
```python
|
||||
import json
|
||||
from typing import List, Dict
|
||||
|
||||
class ToolExecutor:
|
||||
"""工具执行器"""
|
||||
|
||||
def __init__(self, registry: ToolRegistry = None):
|
||||
self.registry = registry or ToolRegistry()
|
||||
|
||||
def process_tool_calls(self, tool_calls: List[dict]) -> List[dict]:
|
||||
"""处理工具调用,返回消息列表"""
|
||||
results = []
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
args = json.loads(call["function"]["arguments"])
|
||||
call_id = call["id"]
|
||||
|
||||
result = self.registry.execute(name, args)
|
||||
|
||||
results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": json.dumps(result, ensure_ascii=False)
|
||||
})
|
||||
return results
|
||||
|
||||
def build_request(self, messages: List[dict], **kwargs) -> dict:
|
||||
"""构建 API 请求"""
|
||||
return {
|
||||
"model": kwargs.get("model", "glm-5"),
|
||||
"messages": messages,
|
||||
"tools": self.registry.list_all(),
|
||||
"tool_choice": "auto"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 四、工具工厂模式
|
||||
|
||||
### 4.1 装饰器注册
|
||||
|
||||
```python
|
||||
# backend/tools/factory.py
|
||||
|
||||
from .core import ToolDefinition, registry
|
||||
|
||||
def tool(name: str, description: str, parameters: dict, category: str = "general"):
|
||||
"""工具注册装饰器"""
|
||||
def decorator(func):
|
||||
tool_def = ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
handler=func,
|
||||
category=category
|
||||
)
|
||||
registry.register(tool_def)
|
||||
return func
|
||||
return decorator
|
||||
```
|
||||
|
||||
### 4.2 使用示例
|
||||
|
||||
```python
|
||||
# backend/tools/builtin/crawler.py
|
||||
|
||||
from ..factory import tool
|
||||
|
||||
# 网页搜索工具
|
||||
@tool(
|
||||
name="web_search",
|
||||
description="搜索互联网获取信息",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"},
|
||||
"max_results": {"type": "integer", "default": 5}
|
||||
},
|
||||
"required": ["query"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def web_search(arguments: dict) -> dict:
|
||||
from ..services import SearchService
|
||||
query = arguments["query"]
|
||||
max_results = arguments.get("max_results", 5)
|
||||
service = SearchService()
|
||||
results = service.search(query, max_results)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
# 页面抓取工具
|
||||
@tool(
|
||||
name="fetch_page",
|
||||
description="抓取指定网页内容",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "网页URL"},
|
||||
"extract_type": {"type": "string", "enum": ["text", "links", "structured"]}
|
||||
},
|
||||
"required": ["url"]
|
||||
},
|
||||
category="crawler"
|
||||
)
|
||||
def fetch_page(arguments: dict) -> dict:
|
||||
from ..services import FetchService
|
||||
url = arguments["url"]
|
||||
extract_type = arguments.get("extract_type", "text")
|
||||
service = FetchService()
|
||||
result = service.fetch(url, extract_type)
|
||||
return result
|
||||
|
||||
|
||||
# 计算器工具
|
||||
@tool(
|
||||
name="calculator",
|
||||
description="执行数学计算",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
},
|
||||
category="data"
|
||||
)
|
||||
def calculator(arguments: dict) -> dict:
|
||||
import ast
|
||||
import operator
|
||||
expr = arguments["expression"]
|
||||
# 安全计算
|
||||
ops = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv
|
||||
}
|
||||
node = ast.parse(expr, mode='eval')
|
||||
result = eval(compile(node, '<string>', 'eval'), {"__builtins__": {}}, ops)
|
||||
return {"result": result}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 五、辅助服务类
|
||||
|
||||
工具依赖的服务保持独立,不与工具类耦合:
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
direction LR
|
||||
|
||||
class SearchService {
|
||||
-SearchEngine engine
|
||||
+search(str query, int limit) list~dict~
|
||||
}
|
||||
|
||||
class FetchService {
|
||||
+fetch(str url, str type) dict
|
||||
+fetch_batch(list urls) dict
|
||||
}
|
||||
|
||||
class ContentExtractor {
|
||||
+extract_text(html) str
|
||||
+extract_links(html) list
|
||||
+extract_structured(html) dict
|
||||
}
|
||||
|
||||
FetchService --> ContentExtractor : uses
|
||||
```
|
||||
|
||||
```python
|
||||
# backend/tools/services.py
|
||||
|
||||
class SearchService:
|
||||
"""搜索服务"""
|
||||
def __init__(self, engine=None):
|
||||
from duckduckgo_search import DDGS
|
||||
self.engine = engine or DDGS()
|
||||
|
||||
def search(self, query: str, max_results: int = 5) -> list:
|
||||
results = list(self.engine.text(query, max_results=max_results))
|
||||
return [
|
||||
{"title": r["title"], "url": r["href"], "snippet": r["body"]}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
class FetchService:
|
||||
"""页面抓取服务"""
|
||||
def __init__(self, timeout: float = 30.0):
|
||||
self.timeout = timeout
|
||||
|
||||
def fetch(self, url: str, extract_type: str = "text") -> dict:
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
resp = httpx.get(url, timeout=self.timeout, follow_redirects=True)
|
||||
soup = BeautifulSoup(resp.text, "html.parser")
|
||||
|
||||
extractor = ContentExtractor(soup)
|
||||
if extract_type == "text":
|
||||
return {"text": extractor.extract_text()}
|
||||
elif extract_type == "links":
|
||||
return {"links": extractor.extract_links()}
|
||||
else:
|
||||
return extractor.extract_structured()
|
||||
|
||||
|
||||
class ContentExtractor:
|
||||
"""内容提取器"""
|
||||
def __init__(self, soup):
|
||||
self.soup = soup
|
||||
|
||||
def extract_text(self) -> str:
|
||||
# 移除脚本和样式
|
||||
for tag in self.soup(["script", "style"]):
|
||||
tag.decompose()
|
||||
return self.soup.get_text(separator="\n", strip=True)
|
||||
|
||||
def extract_links(self) -> list:
|
||||
return [
|
||||
{"text": a.get_text(strip=True), "href": a.get("href")}
|
||||
for a in self.soup.find_all("a", href=True)
|
||||
]
|
||||
|
||||
def extract_structured(self) -> dict:
|
||||
return {
|
||||
"title": self.soup.title.string if self.soup.title else "",
|
||||
"text": self.extract_text(),
|
||||
"links": self.extract_links()[:20]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 六、工具初始化
|
||||
|
||||
```python
|
||||
# backend/tools/__init__.py
|
||||
|
||||
from .core import ToolDefinition, ToolResult, ToolRegistry, registry, ToolExecutor
|
||||
from .factory import tool
|
||||
|
||||
def init_tools():
|
||||
"""初始化所有内置工具"""
|
||||
# 导入即自动注册
|
||||
from .builtin import crawler, data, file_ops
|
||||
|
||||
# 使用时
|
||||
init_tools()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、工具清单
|
||||
|
||||
| 类别 | 工具名称 | 描述 | 依赖服务 |
|
||||
| ------- | --------------- | ---- | ------------- |
|
||||
| crawler | `web_search` | 网页搜索 | SearchService |
|
||||
| crawler | `fetch_page` | 单页抓取 | FetchService |
|
||||
| crawler | `crawl_batch` | 批量爬取 | FetchService |
|
||||
| data | `calculator` | 数学计算 | - |
|
||||
| data | `data_analysis` | 数据分析 | - |
|
||||
| file | `file_reader` | 文件读取 | - |
|
||||
| file | `file_writer` | 文件写入 | - |
|
||||
|
||||
---
|
||||
|
||||
## 八、与旧设计对比
|
||||
|
||||
| 方面 | 旧设计 | 新设计 |
|
||||
| ----- | ----------------- | --------- |
|
||||
| 类数量 | 30+ | ~10 |
|
||||
| 工具定义 | 继承 BaseTool | 装饰器 + 函数 |
|
||||
| 中间抽象层 | 5个(CrawlerTool 等) | 无 |
|
||||
| 扩展方式 | 创建子类 | 写函数 + 装饰器 |
|
||||
| 代码量 | 多 | 少 |
|
||||
|
||||
---
|
||||
|
||||
## 九、总结
|
||||
|
||||
简化后的设计:
|
||||
|
||||
1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult`
|
||||
2. **工厂模式**:使用 `@tool` 装饰器注册工具
|
||||
3. **服务分离**:工具依赖的服务独立,不与工具类耦合
|
||||
4. **易于扩展**:新增工具只需写一个函数并加装饰器
|
||||
Loading…
Reference in New Issue