feat: 初步完成工具调用设计
This commit is contained in:
parent
8639860fb9
commit
e77fd71aa7
|
|
@ -4,6 +4,7 @@ from flask import Flask
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Initialize db BEFORE importing models/routes that depend on it
|
||||||
db = SQLAlchemy()
|
db = SQLAlchemy()
|
||||||
CONFIG_PATH = Path(__file__).parent.parent / "config.yml"
|
CONFIG_PATH = Path(__file__).parent.parent / "config.yml"
|
||||||
|
|
||||||
|
|
@ -26,9 +27,13 @@ def create_app():
|
||||||
|
|
||||||
db.init_app(app)
|
db.init_app(app)
|
||||||
|
|
||||||
|
# Import after db is initialized
|
||||||
from .models import User, Conversation, Message, TokenUsage
|
from .models import User, Conversation, Message, TokenUsage
|
||||||
from .routes import register_routes
|
from .routes import register_routes
|
||||||
|
from .tools import init_tools
|
||||||
|
|
||||||
register_routes(app)
|
register_routes(app)
|
||||||
|
init_tools()
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from flask import request, jsonify, Response, Blueprint, current_app
|
||||||
from . import db
|
from . import db
|
||||||
from .models import Conversation, Message, User, TokenUsage
|
from .models import Conversation, Message, User, TokenUsage
|
||||||
from . import load_config
|
from . import load_config
|
||||||
|
from .tools import registry, ToolExecutor
|
||||||
|
|
||||||
bp = Blueprint("api", __name__)
|
bp = Blueprint("api", __name__)
|
||||||
|
|
||||||
|
|
@ -51,7 +52,7 @@ def to_dict(inst, **extra):
|
||||||
|
|
||||||
|
|
||||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
||||||
"""记录 token 使用量"""
|
"""Record token usage"""
|
||||||
from datetime import date
|
from datetime import date
|
||||||
today = date.today()
|
today = date.today()
|
||||||
usage = TokenUsage.query.filter_by(
|
usage = TokenUsage.query.filter_by(
|
||||||
|
|
@ -75,10 +76,13 @@ def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
|
||||||
|
|
||||||
|
|
||||||
def build_glm_messages(conv):
|
def build_glm_messages(conv):
|
||||||
|
"""Build messages list for GLM API from conversation"""
|
||||||
msgs = []
|
msgs = []
|
||||||
if conv.system_prompt:
|
if conv.system_prompt:
|
||||||
msgs.append({"role": "system", "content": conv.system_prompt})
|
msgs.append({"role": "system", "content": conv.system_prompt})
|
||||||
for m in conv.messages:
|
# Query messages directly to avoid detached instance warning
|
||||||
|
messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all()
|
||||||
|
for m in messages:
|
||||||
msgs.append({"role": m.role, "content": m.content})
|
msgs.append({"role": m.role, "content": m.content})
|
||||||
return msgs
|
return msgs
|
||||||
|
|
||||||
|
|
@ -87,15 +91,27 @@ def build_glm_messages(conv):
|
||||||
|
|
||||||
@bp.route("/api/models", methods=["GET"])
|
@bp.route("/api/models", methods=["GET"])
|
||||||
def list_models():
|
def list_models():
|
||||||
"""获取可用模型列表"""
|
"""Get available model list"""
|
||||||
return ok(MODELS)
|
return ok(MODELS)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Tools API --------------------------------------------
|
||||||
|
|
||||||
|
@bp.route("/api/tools", methods=["GET"])
|
||||||
|
def list_tools():
|
||||||
|
"""Get available tool list"""
|
||||||
|
tools = registry.list_all()
|
||||||
|
return ok({
|
||||||
|
"tools": tools,
|
||||||
|
"total": len(tools)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
# -- Token Usage Statistics --------------------------------
|
# -- Token Usage Statistics --------------------------------
|
||||||
|
|
||||||
@bp.route("/api/stats/tokens", methods=["GET"])
|
@bp.route("/api/stats/tokens", methods=["GET"])
|
||||||
def token_stats():
|
def token_stats():
|
||||||
"""获取 token 使用统计"""
|
"""Get token usage statistics"""
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
|
@ -105,7 +121,7 @@ def token_stats():
|
||||||
today = date.today()
|
today = date.today()
|
||||||
|
|
||||||
if period == "daily":
|
if period == "daily":
|
||||||
# 今日统计
|
# Today's statistics
|
||||||
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
|
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
|
||||||
result = {
|
result = {
|
||||||
"period": "daily",
|
"period": "daily",
|
||||||
|
|
@ -116,7 +132,7 @@ def token_stats():
|
||||||
"by_model": {s.model: {"prompt": s.prompt_tokens, "completion": s.completion_tokens, "total": s.total_tokens} for s in stats}
|
"by_model": {s.model: {"prompt": s.prompt_tokens, "completion": s.completion_tokens, "total": s.total_tokens} for s in stats}
|
||||||
}
|
}
|
||||||
elif period == "weekly":
|
elif period == "weekly":
|
||||||
# 本周统计 (最近7天)
|
# Weekly statistics (last 7 days)
|
||||||
start_date = today - timedelta(days=6)
|
start_date = today - timedelta(days=6)
|
||||||
stats = TokenUsage.query.filter(
|
stats = TokenUsage.query.filter(
|
||||||
TokenUsage.user_id == user.id,
|
TokenUsage.user_id == user.id,
|
||||||
|
|
@ -133,7 +149,7 @@ def token_stats():
|
||||||
daily_data[d]["completion"] += s.completion_tokens
|
daily_data[d]["completion"] += s.completion_tokens
|
||||||
daily_data[d]["total"] += s.total_tokens
|
daily_data[d]["total"] += s.total_tokens
|
||||||
|
|
||||||
# 填充没有数据的日期
|
# Fill missing dates
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
d = (today - timedelta(days=6-i)).isoformat()
|
d = (today - timedelta(days=6-i)).isoformat()
|
||||||
if d not in daily_data:
|
if d not in daily_data:
|
||||||
|
|
@ -149,7 +165,7 @@ def token_stats():
|
||||||
"daily": daily_data
|
"daily": daily_data
|
||||||
}
|
}
|
||||||
elif period == "monthly":
|
elif period == "monthly":
|
||||||
# 本月统计 (最近30天)
|
# Monthly statistics (last 30 days)
|
||||||
start_date = today - timedelta(days=29)
|
start_date = today - timedelta(days=29)
|
||||||
stats = TokenUsage.query.filter(
|
stats = TokenUsage.query.filter(
|
||||||
TokenUsage.user_id == user.id,
|
TokenUsage.user_id == user.id,
|
||||||
|
|
@ -166,7 +182,7 @@ def token_stats():
|
||||||
daily_data[d]["completion"] += s.completion_tokens
|
daily_data[d]["completion"] += s.completion_tokens
|
||||||
daily_data[d]["total"] += s.total_tokens
|
daily_data[d]["total"] += s.total_tokens
|
||||||
|
|
||||||
# 填充没有数据的日期
|
# Fill missing dates
|
||||||
for i in range(30):
|
for i in range(30):
|
||||||
d = (today - timedelta(days=29-i)).isoformat()
|
d = (today - timedelta(days=29-i)).isoformat()
|
||||||
if d not in daily_data:
|
if d not in daily_data:
|
||||||
|
|
@ -301,15 +317,19 @@ def delete_message(conv_id, msg_id):
|
||||||
|
|
||||||
# -- Chat Completion ----------------------------------
|
# -- Chat Completion ----------------------------------
|
||||||
|
|
||||||
def _call_glm(conv, stream=False):
|
def _call_glm(conv, stream=False, tools=None, messages=None):
|
||||||
|
"""Call GLM API"""
|
||||||
body = {
|
body = {
|
||||||
"model": conv.model,
|
"model": conv.model,
|
||||||
"messages": build_glm_messages(conv),
|
"messages": messages if messages is not None else build_glm_messages(conv),
|
||||||
"max_tokens": conv.max_tokens,
|
"max_tokens": conv.max_tokens,
|
||||||
"temperature": conv.temperature,
|
"temperature": conv.temperature,
|
||||||
}
|
}
|
||||||
if conv.thinking_enabled:
|
if conv.thinking_enabled:
|
||||||
body["thinking"] = {"type": "enabled"}
|
body["thinking"] = {"type": "enabled"}
|
||||||
|
if tools:
|
||||||
|
body["tools"] = tools
|
||||||
|
body["tool_choice"] = "auto"
|
||||||
if stream:
|
if stream:
|
||||||
body["stream"] = True
|
body["stream"] = True
|
||||||
return requests.post(
|
return requests.post(
|
||||||
|
|
@ -320,101 +340,192 @@ def _call_glm(conv, stream=False):
|
||||||
|
|
||||||
|
|
||||||
def _sync_response(conv):
|
def _sync_response(conv):
|
||||||
try:
|
"""Sync response with tool call support"""
|
||||||
resp = _call_glm(conv)
|
executor = ToolExecutor(registry=registry)
|
||||||
resp.raise_for_status()
|
tools = registry.list_all()
|
||||||
result = resp.json()
|
messages = build_glm_messages(conv)
|
||||||
except Exception as e:
|
max_iterations = 5 # Max tool call iterations
|
||||||
return err(500, f"upstream error: {e}")
|
|
||||||
|
|
||||||
choice = result["choices"][0]
|
|
||||||
usage = result.get("usage", {})
|
|
||||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
|
||||||
completion_tokens = usage.get("completion_tokens", 0)
|
|
||||||
|
|
||||||
msg = Message(
|
|
||||||
id=str(uuid.uuid4()), conversation_id=conv.id, role="assistant",
|
|
||||||
content=choice["message"]["content"],
|
|
||||||
token_count=completion_tokens,
|
|
||||||
thinking_content=choice["message"].get("reasoning_content", ""),
|
|
||||||
)
|
|
||||||
db.session.add(msg)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# 记录 token 使用
|
|
||||||
user = get_or_create_default_user()
|
|
||||||
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
return ok({
|
|
||||||
"message": to_dict(msg, thinking_content=msg.thinking_content or None),
|
|
||||||
"usage": {"prompt_tokens": prompt_tokens,
|
|
||||||
"completion_tokens": completion_tokens,
|
|
||||||
"total_tokens": usage.get("total_tokens", 0)},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def _stream_response(conv):
|
|
||||||
conv_id = conv.id
|
|
||||||
conv_model = conv.model
|
|
||||||
app = current_app._get_current_object()
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
full_content = ""
|
|
||||||
full_thinking = ""
|
|
||||||
token_count = 0
|
|
||||||
prompt_tokens = 0
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
for _ in range(max_iterations):
|
||||||
try:
|
try:
|
||||||
with app.app_context():
|
resp = _call_glm(conv, tools=tools if tools else None, messages=messages)
|
||||||
active_conv = db.session.get(Conversation, conv_id)
|
resp.raise_for_status()
|
||||||
resp = _call_glm(active_conv, stream=True)
|
result = resp.json()
|
||||||
resp.raise_for_status()
|
|
||||||
|
|
||||||
for line in resp.iter_lines():
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
line = line.decode("utf-8")
|
|
||||||
if not line.startswith("data: "):
|
|
||||||
continue
|
|
||||||
data_str = line[6:]
|
|
||||||
if data_str == "[DONE]":
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
chunk = json.loads(data_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
delta = chunk["choices"][0].get("delta", {})
|
|
||||||
reasoning = delta.get("reasoning_content", "")
|
|
||||||
text = delta.get("content", "")
|
|
||||||
if reasoning:
|
|
||||||
full_thinking += reasoning
|
|
||||||
yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"
|
|
||||||
if text:
|
|
||||||
full_content += text
|
|
||||||
yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"
|
|
||||||
usage = chunk.get("usage", {})
|
|
||||||
if usage:
|
|
||||||
token_count = usage.get("completion_tokens", 0)
|
|
||||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
|
return err(500, f"upstream error: {e}")
|
||||||
return
|
|
||||||
|
choice = result["choices"][0]
|
||||||
|
message = choice["message"]
|
||||||
|
|
||||||
|
# If no tool calls, return final result
|
||||||
|
if not message.get("tool_calls"):
|
||||||
|
usage = result.get("usage", {})
|
||||||
|
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
completion_tokens = usage.get("completion_tokens", 0)
|
||||||
|
|
||||||
# 流式结束后最后写入数据库
|
|
||||||
with app.app_context():
|
|
||||||
msg = Message(
|
msg = Message(
|
||||||
id=msg_id, conversation_id=conv_id, role="assistant",
|
id=str(uuid.uuid4()), conversation_id=conv.id, role="assistant",
|
||||||
content=full_content, token_count=token_count, thinking_content=full_thinking,
|
content=message.get("content", ""),
|
||||||
|
token_count=completion_tokens,
|
||||||
|
thinking_content=message.get("reasoning_content", ""),
|
||||||
)
|
)
|
||||||
db.session.add(msg)
|
db.session.add(msg)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# 记录 token 使用
|
|
||||||
user = get_or_create_default_user()
|
user = get_or_create_default_user()
|
||||||
record_token_usage(user.id, conv_model, prompt_tokens, token_count)
|
record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)
|
||||||
|
|
||||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
|
return ok({
|
||||||
|
"message": to_dict(msg, thinking_content=msg.thinking_content or None),
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": usage.get("total_tokens", 0)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process tool calls
|
||||||
|
tool_calls = message["tool_calls"]
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
# Execute tools and add results
|
||||||
|
tool_results = executor.process_tool_calls(tool_calls)
|
||||||
|
messages.extend(tool_results)
|
||||||
|
|
||||||
|
# Save tool call records to database
|
||||||
|
for i, call in enumerate(tool_calls):
|
||||||
|
tool_msg = Message(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
conversation_id=conv.id,
|
||||||
|
role="tool",
|
||||||
|
content=tool_results[i]["content"]
|
||||||
|
)
|
||||||
|
db.session.add(tool_msg)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return err(500, "exceeded maximum tool call iterations")
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_response(conv):
|
||||||
|
"""Stream response with tool call support"""
|
||||||
|
conv_id = conv.id
|
||||||
|
conv_model = conv.model
|
||||||
|
app = current_app._get_current_object()
|
||||||
|
executor = ToolExecutor(registry=registry)
|
||||||
|
tools = registry.list_all()
|
||||||
|
# Build messages BEFORE entering generator (in request context)
|
||||||
|
initial_messages = build_glm_messages(conv)
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
messages = list(initial_messages) # Copy to avoid mutation
|
||||||
|
max_iterations = 5
|
||||||
|
|
||||||
|
for iteration in range(max_iterations):
|
||||||
|
full_content = ""
|
||||||
|
full_thinking = ""
|
||||||
|
token_count = 0
|
||||||
|
prompt_tokens = 0
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
tool_calls_list = []
|
||||||
|
current_tool_call = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with app.app_context():
|
||||||
|
active_conv = db.session.get(Conversation, conv_id)
|
||||||
|
resp = _call_glm(active_conv, stream=True, tools=tools if tools else None, messages=messages)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
for line in resp.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
line = line.decode("utf-8")
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = chunk["choices"][0].get("delta", {})
|
||||||
|
|
||||||
|
# Process thinking chain
|
||||||
|
reasoning = delta.get("reasoning_content", "")
|
||||||
|
if reasoning:
|
||||||
|
full_thinking += reasoning
|
||||||
|
yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# Process text content
|
||||||
|
text = delta.get("content", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# Process tool calls
|
||||||
|
tool_calls_delta = delta.get("tool_calls", [])
|
||||||
|
for tc in tool_calls_delta:
|
||||||
|
idx = tc.get("index", 0)
|
||||||
|
if idx >= len(tool_calls_list):
|
||||||
|
tool_calls_list.append({
|
||||||
|
"id": tc.get("id", ""),
|
||||||
|
"type": tc.get("type", "function"),
|
||||||
|
"function": {"name": "", "arguments": ""}
|
||||||
|
})
|
||||||
|
if tc.get("id"):
|
||||||
|
tool_calls_list[idx]["id"] = tc["id"]
|
||||||
|
if tc.get("function"):
|
||||||
|
if tc["function"].get("name"):
|
||||||
|
tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
|
||||||
|
if tc["function"].get("arguments"):
|
||||||
|
tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
|
||||||
|
|
||||||
|
usage = chunk.get("usage", {})
|
||||||
|
if usage:
|
||||||
|
token_count = usage.get("completion_tokens", 0)
|
||||||
|
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# If tool calls exist, execute and continue loop
|
||||||
|
if tool_calls_list:
|
||||||
|
# Send tool call info
|
||||||
|
yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# Execute tools
|
||||||
|
tool_results = executor.process_tool_calls(tool_calls_list)
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content or None,
|
||||||
|
"tool_calls": tool_calls_list
|
||||||
|
})
|
||||||
|
messages.extend(tool_results)
|
||||||
|
|
||||||
|
# Send tool results
|
||||||
|
for tr in tool_results:
|
||||||
|
yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
# No tool calls, finish
|
||||||
|
with app.app_context():
|
||||||
|
msg = Message(
|
||||||
|
id=msg_id, conversation_id=conv_id, role="assistant",
|
||||||
|
content=full_content, token_count=token_count, thinking_content=full_thinking,
|
||||||
|
)
|
||||||
|
db.session.add(msg)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
user = get_or_create_default_user()
|
||||||
|
record_token_usage(user.id, conv_model, prompt_tokens, token_count)
|
||||||
|
|
||||||
|
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
return Response(generate(), mimetype="text/event-stream",
|
return Response(generate(), mimetype="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
"""
|
||||||
|
NanoClaw Tool System
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from backend.tools import registry, ToolExecutor, tool
|
||||||
|
from backend.tools import init_tools
|
||||||
|
|
||||||
|
# Initialize built-in tools
|
||||||
|
init_tools()
|
||||||
|
|
||||||
|
# List all tools
|
||||||
|
tools = registry.list_all()
|
||||||
|
|
||||||
|
# Execute a tool
|
||||||
|
result = registry.execute("web_search", {"query": "Python"})
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||||
|
from .factory import tool, register_tool
|
||||||
|
from .executor import ToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
def init_tools() -> None:
|
||||||
|
"""
|
||||||
|
Initialize all built-in tools
|
||||||
|
|
||||||
|
Importing builtin module automatically registers all decorator-defined tools
|
||||||
|
"""
|
||||||
|
from .builtin import crawler, data # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
# Public API exports
|
||||||
|
__all__ = [
|
||||||
|
# Core classes
|
||||||
|
"ToolDefinition",
|
||||||
|
"ToolResult",
|
||||||
|
"ToolRegistry",
|
||||||
|
"ToolExecutor",
|
||||||
|
# Instances
|
||||||
|
"registry",
|
||||||
|
# Factory functions
|
||||||
|
"tool",
|
||||||
|
"register_tool",
|
||||||
|
# Initialization
|
||||||
|
"init_tools",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Built-in tools"""
|
||||||
|
from .crawler import *
|
||||||
|
from .data import *
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
"""Crawler related tools"""
|
||||||
|
from ..factory import tool
|
||||||
|
from ..services import SearchService, FetchService
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search keywords"
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of results to return, default 5",
|
||||||
|
"default": 5
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_search(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Web search tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"query": "search keywords",
|
||||||
|
"max_results": 5
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"results": [...]}
|
||||||
|
"""
|
||||||
|
query = arguments["query"]
|
||||||
|
max_results = arguments.get("max_results", 5)
|
||||||
|
|
||||||
|
service = SearchService()
|
||||||
|
results = service.search(query, max_results)
|
||||||
|
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="fetch_page",
|
||||||
|
description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL of the webpage to fetch"
|
||||||
|
},
|
||||||
|
"extract_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Extraction type",
|
||||||
|
"enum": ["text", "links", "structured"],
|
||||||
|
"default": "text"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def fetch_page(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Page fetch tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"url": "https://example.com",
|
||||||
|
"extract_type": "text" | "links" | "structured"
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Page content
|
||||||
|
"""
|
||||||
|
url = arguments["url"]
|
||||||
|
extract_type = arguments.get("extract_type", "text")
|
||||||
|
|
||||||
|
service = FetchService()
|
||||||
|
result = service.fetch(url, extract_type)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="crawl_batch",
|
||||||
|
description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"urls": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "List of URLs to fetch"
|
||||||
|
},
|
||||||
|
"extract_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["text", "links", "structured"],
|
||||||
|
"default": "text"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["urls"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def crawl_batch(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Batch fetch tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"urls": ["url1", "url2", ...],
|
||||||
|
"extract_type": "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"results": [...]}
|
||||||
|
"""
|
||||||
|
urls = arguments["urls"]
|
||||||
|
extract_type = arguments.get("extract_type", "text")
|
||||||
|
|
||||||
|
if len(urls) > 10:
|
||||||
|
return {"error": "Maximum 10 pages can be fetched at once"}
|
||||||
|
|
||||||
|
service = FetchService()
|
||||||
|
results = service.fetch_batch(urls, extract_type)
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results)}
|
||||||
|
|
@ -0,0 +1,146 @@
|
||||||
|
"""Data processing related tools"""
|
||||||
|
from ..factory import tool
|
||||||
|
from ..services import CalculatorService
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="calculator",
|
||||||
|
description="Perform mathematical calculations. Supports basic arithmetic: addition, subtraction, multiplication, division, power, modulo, etc.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Mathematical expression, e.g.: (2 + 3) * 4, 2 ** 10, 100 / 7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def calculator(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Calculator tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"expression": "2 + 3 * 4"
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"result": 14}
|
||||||
|
"""
|
||||||
|
expression = arguments["expression"]
|
||||||
|
service = CalculatorService()
|
||||||
|
return service.evaluate(expression)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="text_process",
|
||||||
|
description="Process text content, supports counting, format conversion and other operations.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to process"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation type",
|
||||||
|
"enum": ["count", "lines", "words", "upper", "lower", "reverse"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def text_process(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Text processing tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"text": "text content",
|
||||||
|
"operation": "count" | "lines" | "words" | ...
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processing result
|
||||||
|
"""
|
||||||
|
text = arguments["text"]
|
||||||
|
operation = arguments["operation"]
|
||||||
|
|
||||||
|
operations = {
|
||||||
|
"count": lambda t: {"count": len(t)},
|
||||||
|
"lines": lambda t: {"lines": len(t.splitlines())},
|
||||||
|
"words": lambda t: {"words": len(t.split())},
|
||||||
|
"upper": lambda t: {"result": t.upper()},
|
||||||
|
"lower": lambda t: {"result": t.lower()},
|
||||||
|
"reverse": lambda t: {"result": t[::-1]}
|
||||||
|
}
|
||||||
|
|
||||||
|
if operation not in operations:
|
||||||
|
return {"error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
return operations[operation](text)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="json_process",
|
||||||
|
description="Process JSON data, supports parsing, formatting, extraction and other operations.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"json_string": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON string"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation type",
|
||||||
|
"enum": ["parse", "format", "keys", "validate"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["json_string", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def json_process(arguments: dict) -> dict:
|
||||||
|
"""
|
||||||
|
JSON processing tool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: {
|
||||||
|
"json_string": '{"key": "value"}',
|
||||||
|
"operation": "parse" | "format" | "keys" | "validate"
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processing result
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
json_string = arguments["json_string"]
|
||||||
|
operation = arguments["operation"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if operation == "validate":
|
||||||
|
json.loads(json_string)
|
||||||
|
return {"valid": True}
|
||||||
|
|
||||||
|
data = json.loads(json_string)
|
||||||
|
|
||||||
|
if operation == "parse":
|
||||||
|
return {"data": data}
|
||||||
|
elif operation == "format":
|
||||||
|
return {"result": json.dumps(data, indent=2, ensure_ascii=False)}
|
||||||
|
elif operation == "keys":
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {"keys": list(data.keys())}
|
||||||
|
return {"error": "JSON root element is not an object"}
|
||||||
|
else:
|
||||||
|
return {"error": f"Unknown operation: {operation}"}
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return {"error": f"JSON parse error: {str(e)}"}
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""Tool system core classes"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolDefinition:
|
||||||
|
"""Tool definition"""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: dict # JSON Schema
|
||||||
|
handler: Callable[[dict], Any]
|
||||||
|
category: str = "general"
|
||||||
|
|
||||||
|
def to_openai_format(self) -> dict:
|
||||||
|
"""Convert to OpenAI/GLM compatible format"""
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": self.parameters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""Tool execution result"""
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"data": self.data,
|
||||||
|
"error": self.error
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ok(cls, data: Any) -> "ToolResult":
|
||||||
|
return cls(success=True, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fail(cls, error: str) -> "ToolResult":
|
||||||
|
return cls(success=False, error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""Tool registry (singleton)"""
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._tools: Dict[str, ToolDefinition] = {}
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def register(self, tool: ToolDefinition) -> None:
|
||||||
|
"""Register a tool"""
|
||||||
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||||
|
"""Get tool definition by name"""
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def list_all(self) -> List[dict]:
|
||||||
|
"""List all tools in OpenAI format"""
|
||||||
|
return [t.to_openai_format() for t in self._tools.values()]
|
||||||
|
|
||||||
|
def list_by_category(self, category: str) -> List[dict]:
|
||||||
|
"""List tools by category"""
|
||||||
|
return [
|
||||||
|
t.to_openai_format()
|
||||||
|
for t in self._tools.values()
|
||||||
|
if t.category == category
|
||||||
|
]
|
||||||
|
|
||||||
|
def execute(self, name: str, arguments: dict) -> dict:
|
||||||
|
"""Execute a tool"""
|
||||||
|
tool = self.get(name)
|
||||||
|
if not tool:
|
||||||
|
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = tool.handler(arguments)
|
||||||
|
if isinstance(result, ToolResult):
|
||||||
|
return result.to_dict()
|
||||||
|
return ToolResult.ok(result).to_dict()
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult.fail(str(e)).to_dict()
|
||||||
|
|
||||||
|
def remove(self, name: str) -> bool:
|
||||||
|
"""Remove a tool"""
|
||||||
|
if name in self._tools:
|
||||||
|
del self._tools[name]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def has(self, name: str) -> bool:
|
||||||
|
"""Check if tool exists"""
|
||||||
|
return name in self._tools
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
"""Tool executor"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Optional, Generator, Any
|
||||||
|
from .core import ToolRegistry, registry
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
"""Tool call executor"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
registry: Optional[ToolRegistry] = None,
|
||||||
|
api_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
):
|
||||||
|
self.registry = registry or ToolRegistry()
|
||||||
|
self.api_url = api_url
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
def process_tool_calls(
|
||||||
|
self,
|
||||||
|
tool_calls: List[dict],
|
||||||
|
context: Optional[dict] = None
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Process tool calls and return message list
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: Tool call list returned by LLM
|
||||||
|
context: Optional context info (user_id, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool response message list, can be appended to messages
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
name = call["function"]["name"]
|
||||||
|
args_str = call["function"]["arguments"]
|
||||||
|
call_id = call["id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
results.append(self._create_error_result(
|
||||||
|
call_id, name, "Invalid JSON arguments"
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
result = self.registry.execute(name, args)
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _create_tool_result(
|
||||||
|
self,
|
||||||
|
call_id: str,
|
||||||
|
name: str,
|
||||||
|
result: dict,
|
||||||
|
execution_time: float = 0
|
||||||
|
) -> dict:
|
||||||
|
"""Create tool result message"""
|
||||||
|
result["execution_time"] = execution_time
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps(result, ensure_ascii=False, default=str)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_error_result(
|
||||||
|
self,
|
||||||
|
call_id: str,
|
||||||
|
name: str,
|
||||||
|
error: str
|
||||||
|
) -> dict:
|
||||||
|
"""Create error result message"""
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"error": error
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
}
|
||||||
|
|
||||||
|
def build_request(
|
||||||
|
self,
|
||||||
|
messages: List[dict],
|
||||||
|
model: str = "glm-5",
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Build API request body
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Message list
|
||||||
|
model: Model name
|
||||||
|
tools: Tool list (default: all tools in registry)
|
||||||
|
**kwargs: Other parameters (temperature, max_tokens, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Request body dict
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools or self.registry.list_all(),
|
||||||
|
"tool_choice": kwargs.get("tool_choice", "auto"),
|
||||||
|
**{k: v for k, v in kwargs.items() if k not in ["tool_choice"]}
|
||||||
|
}
|
||||||
|
|
||||||
|
def execute_with_retry(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
arguments: dict,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Execute tool with retry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Tool name
|
||||||
|
arguments: Tool arguments
|
||||||
|
max_retries: Max retry count
|
||||||
|
retry_delay: Retry delay in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Execution result
|
||||||
|
"""
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return self.registry.execute(name, arguments)
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed after {max_retries} retries: {last_error}"
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Tool factory - decorator registration"""
|
||||||
|
from typing import Callable
|
||||||
|
from .core import ToolDefinition, registry
|
||||||
|
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: dict,
|
||||||
|
category: str = "general"
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
Tool registration decorator
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@tool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the web",
|
||||||
|
parameters={"type": "object", "properties": {...}},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_search(arguments: dict) -> dict:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
handler=func,
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
registry.register(tool_def)
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_tool(
|
||||||
|
name: str,
|
||||||
|
handler: Callable,
|
||||||
|
description: str,
|
||||||
|
parameters: dict,
|
||||||
|
category: str = "general"
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a tool directly (without decorator)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
register_tool(
|
||||||
|
name="my_tool",
|
||||||
|
handler=my_function,
|
||||||
|
description="Description",
|
||||||
|
parameters={...}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
handler=handler,
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
registry.register(tool_def)
|
||||||
|
|
@ -0,0 +1,257 @@
|
||||||
|
"""Tool helper services"""
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class SearchService:
|
||||||
|
"""Search service"""
|
||||||
|
|
||||||
|
def __init__(self, engine: str = "duckduckgo"):
|
||||||
|
self.engine = engine
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_results: int = 5,
|
||||||
|
region: str = "cn-zh"
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Execute search
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search keywords
|
||||||
|
max_results: Max result count
|
||||||
|
region: Region setting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Search result list
|
||||||
|
"""
|
||||||
|
if self.engine == "duckduckgo":
|
||||||
|
return self._search_duckduckgo(query, max_results, region)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported search engine: {self.engine}")
|
||||||
|
|
||||||
|
def _search_duckduckgo(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_results: int,
|
||||||
|
region: str
|
||||||
|
) -> List[dict]:
|
||||||
|
"""DuckDuckGo search"""
|
||||||
|
try:
|
||||||
|
from duckduckgo_search import DDGS
|
||||||
|
except ImportError:
|
||||||
|
return [{"error": "Please install duckduckgo-search: pip install duckduckgo-search"}]
|
||||||
|
|
||||||
|
with DDGS() as ddgs:
|
||||||
|
results = list(ddgs.text(
|
||||||
|
query,
|
||||||
|
max_results=max_results,
|
||||||
|
region=region
|
||||||
|
))
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"title": r.get("title", ""),
|
||||||
|
"url": r.get("href", ""),
|
||||||
|
"snippet": r.get("body", "")
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FetchService:
|
||||||
|
"""Page fetch service"""
|
||||||
|
|
||||||
|
def __init__(self, timeout: float = 30.0, user_agent: str = None):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.user_agent = user_agent or (
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/120.0.0.0 Safari/537.36"
|
||||||
|
)
|
||||||
|
|
||||||
|
def fetch(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
extract_type: str = "text"
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Fetch a single page
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Page URL
|
||||||
|
extract_type: Extract type (text, links, structured)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fetch result
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = httpx.get(
|
||||||
|
url,
|
||||||
|
timeout=self.timeout,
|
||||||
|
follow_redirects=True,
|
||||||
|
headers={"User-Agent": self.user_agent}
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e), "url": url}
|
||||||
|
|
||||||
|
html = resp.text
|
||||||
|
extractor = ContentExtractor(html)
|
||||||
|
|
||||||
|
if extract_type == "text":
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"text": extractor.extract_text()
|
||||||
|
}
|
||||||
|
elif extract_type == "links":
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"links": extractor.extract_links()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return extractor.extract_structured(url)
|
||||||
|
|
||||||
|
def fetch_batch(
|
||||||
|
self,
|
||||||
|
urls: List[str],
|
||||||
|
extract_type: str = "text",
|
||||||
|
max_concurrent: int = 5
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Batch fetch pages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urls: URL list
|
||||||
|
extract_type: Extract type
|
||||||
|
max_concurrent: Max concurrent requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result list
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for url in urls:
|
||||||
|
results.append(self.fetch(url, extract_type))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class ContentExtractor:
|
||||||
|
"""Content extractor"""
|
||||||
|
|
||||||
|
def __init__(self, html: str):
|
||||||
|
self.html = html
|
||||||
|
self._soup = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def soup(self):
|
||||||
|
if self._soup is None:
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
self._soup = BeautifulSoup(self.html, "html.parser")
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4")
|
||||||
|
return self._soup
|
||||||
|
|
||||||
|
def extract_text(self) -> str:
|
||||||
|
"""Extract plain text"""
|
||||||
|
# Remove script and style
|
||||||
|
for tag in self.soup(["script", "style", "nav", "footer", "header"]):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
text = self.soup.get_text(separator="\n", strip=True)
|
||||||
|
# Clean extra whitespace
|
||||||
|
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def extract_links(self) -> List[dict]:
|
||||||
|
"""Extract links"""
|
||||||
|
links = []
|
||||||
|
for a in self.soup.find_all("a", href=True):
|
||||||
|
text = a.get_text(strip=True)
|
||||||
|
href = a["href"]
|
||||||
|
if text and href and not href.startswith(("#", "javascript:")):
|
||||||
|
links.append({"text": text, "href": href})
|
||||||
|
return links[:50] # Limit count
|
||||||
|
|
||||||
|
def extract_structured(self, url: str = "") -> dict:
|
||||||
|
"""Extract structured content"""
|
||||||
|
soup = self.soup
|
||||||
|
|
||||||
|
# Extract title
|
||||||
|
title = ""
|
||||||
|
if soup.title:
|
||||||
|
title = soup.title.string or ""
|
||||||
|
|
||||||
|
# Extract meta description
|
||||||
|
description = ""
|
||||||
|
meta_desc = soup.find("meta", attrs={"name": "description"})
|
||||||
|
if meta_desc:
|
||||||
|
description = meta_desc.get("content", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"title": title.strip(),
|
||||||
|
"description": description.strip(),
|
||||||
|
"text": self.extract_text()[:5000], # Limit length
|
||||||
|
"links": self.extract_links()[:20]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CalculatorService:
|
||||||
|
"""Safe calculation service"""
|
||||||
|
|
||||||
|
ALLOWED_OPS = {
|
||||||
|
"add", "sub", "mul", "truediv", "floordiv",
|
||||||
|
"mod", "pow", "neg", "abs"
|
||||||
|
}
|
||||||
|
|
||||||
|
def evaluate(self, expression: str) -> dict:
|
||||||
|
"""
|
||||||
|
Safely evaluate mathematical expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: Mathematical expression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Calculation result
|
||||||
|
"""
|
||||||
|
import ast
|
||||||
|
import operator
|
||||||
|
|
||||||
|
ops = {
|
||||||
|
ast.Add: operator.add,
|
||||||
|
ast.Sub: operator.sub,
|
||||||
|
ast.Mult: operator.mul,
|
||||||
|
ast.Div: operator.truediv,
|
||||||
|
ast.FloorDiv: operator.floordiv,
|
||||||
|
ast.Mod: operator.mod,
|
||||||
|
ast.Pow: operator.pow,
|
||||||
|
ast.USub: operator.neg,
|
||||||
|
ast.UAdd: operator.pos,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse expression
|
||||||
|
node = ast.parse(expression, mode="eval")
|
||||||
|
|
||||||
|
# Validate node types
|
||||||
|
for child in ast.walk(node):
|
||||||
|
if isinstance(child, ast.Call):
|
||||||
|
return {"error": "Function calls not allowed"}
|
||||||
|
if isinstance(child, ast.Name):
|
||||||
|
return {"error": "Variable names not allowed"}
|
||||||
|
|
||||||
|
# Safe execution
|
||||||
|
result = eval(
|
||||||
|
compile(node, "<string>", "eval"),
|
||||||
|
{"__builtins__": {}},
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"result": result}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Calculation error: {str(e)}"}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -4,21 +4,21 @@
|
||||||
|
|
||||||
### 会话管理
|
### 会话管理
|
||||||
|
|
||||||
| 方法 | 路径 | 说明 |
|
| 方法 | 路径 | 说明 |
|
||||||
|------|------|------|
|
| -------- | ------------------------ | ------ |
|
||||||
| `POST` | `/api/conversations` | 创建会话 |
|
| `POST` | `/api/conversations` | 创建会话 |
|
||||||
| `GET` | `/api/conversations` | 获取会话列表 |
|
| `GET` | `/api/conversations` | 获取会话列表 |
|
||||||
| `GET` | `/api/conversations/:id` | 获取会话详情 |
|
| `GET` | `/api/conversations/:id` | 获取会话详情 |
|
||||||
| `PATCH` | `/api/conversations/:id` | 更新会话 |
|
| `PATCH` | `/api/conversations/:id` | 更新会话 |
|
||||||
| `DELETE` | `/api/conversations/:id` | 删除会话 |
|
| `DELETE` | `/api/conversations/:id` | 删除会话 |
|
||||||
|
|
||||||
### 消息管理
|
### 消息管理
|
||||||
|
|
||||||
| 方法 | 路径 | 说明 |
|
| 方法 | 路径 | 说明 |
|
||||||
|------|------|------|
|
| -------- | --------------------------------------------- | ------------------------- |
|
||||||
| `GET` | `/api/conversations/:id/messages` | 获取消息列表 |
|
| `GET` | `/api/conversations/:id/messages` | 获取消息列表 |
|
||||||
| `POST` | `/api/conversations/:id/messages` | 发送消息(对话补全,支持 `stream` 流式) |
|
| `POST` | `/api/conversations/:id/messages` | 发送消息(对话补全,支持 `stream` 流式) |
|
||||||
| `DELETE` | `/api/conversations/:id/messages/:message_id` | 删除消息 |
|
| `DELETE` | `/api/conversations/:id/messages/:message_id` | 删除消息 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,475 @@
|
||||||
|
# 工具调用系统设计
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本文档描述 NanoClaw 工具调用系统的设计,采用简化的工厂模式,减少不必要的类层次。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 一、核心类图
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
classDiagram
|
||||||
|
direction TB
|
||||||
|
|
||||||
|
class ToolDefinition {
|
||||||
|
<<dataclass>>
|
||||||
|
+str name
|
||||||
|
+str description
|
||||||
|
+dict parameters
|
||||||
|
+Callable handler
|
||||||
|
+str category
|
||||||
|
+dict to_openai_format()
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolRegistry {
|
||||||
|
-dict _tools
|
||||||
|
+register(ToolDefinition tool) void
|
||||||
|
+get(str name) ToolDefinition?
|
||||||
|
+list_all() list~dict~
|
||||||
|
+execute(str name, dict args) Any
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolExecutor {
|
||||||
|
-ToolRegistry registry
|
||||||
|
+process_tool_calls(list tool_calls) list~dict~
|
||||||
|
+build_request(list messages) dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolResult {
|
||||||
|
<<dataclass>>
|
||||||
|
+bool success
|
||||||
|
+Any data
|
||||||
|
+str? error
|
||||||
|
+dict to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
ToolRegistry "1" --> "*" ToolDefinition : manages
|
||||||
|
ToolExecutor "1" --> "1" ToolRegistry : uses
|
||||||
|
ToolDefinition ..> ToolResult : returns
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 二、工具定义工厂
|
||||||
|
|
||||||
|
使用工厂函数创建工具,避免复杂的类继承:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
classDiagram
|
||||||
|
direction LR
|
||||||
|
|
||||||
|
class ToolFactory {
|
||||||
|
<<module>>
|
||||||
|
+tool(name, description, parameters)$ decorator
|
||||||
|
+register(name, handler, description, parameters)$ void
|
||||||
|
+create_crawler_tools()$ list~ToolDefinition~
|
||||||
|
+create_data_tools()$ list~ToolDefinition~
|
||||||
|
+create_file_tools()$ list~ToolDefinition~
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolDefinition {
|
||||||
|
+str name
|
||||||
|
+str description
|
||||||
|
+dict parameters
|
||||||
|
+Callable handler
|
||||||
|
}
|
||||||
|
|
||||||
|
ToolFactory ..> ToolDefinition : creates
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 三、核心类实现
|
||||||
|
|
||||||
|
### 3.1 ToolDefinition
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolDefinition:
|
||||||
|
"""工具定义"""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: dict # JSON Schema
|
||||||
|
handler: Callable[[dict], Any]
|
||||||
|
category: str = "general"
|
||||||
|
|
||||||
|
def to_openai_format(self) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": self.parameters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 ToolResult
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""工具执行结果"""
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"data": self.data,
|
||||||
|
"error": self.error
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ok(cls, data: Any) -> "ToolResult":
|
||||||
|
return cls(success=True, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fail(cls, error: str) -> "ToolResult":
|
||||||
|
return cls(success=False, error=error)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 ToolRegistry
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册表(单例)"""
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._tools: Dict[str, ToolDefinition] = {}
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def register(self, tool: ToolDefinition) -> None:
|
||||||
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def list_all(self) -> List[dict]:
|
||||||
|
return [t.to_openai_format() for t in self._tools.values()]
|
||||||
|
|
||||||
|
def execute(self, name: str, arguments: dict) -> dict:
|
||||||
|
tool = self.get(name)
|
||||||
|
if not tool:
|
||||||
|
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||||
|
try:
|
||||||
|
result = tool.handler(arguments)
|
||||||
|
if isinstance(result, ToolResult):
|
||||||
|
return result.to_dict()
|
||||||
|
return ToolResult.ok(result).to_dict()
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult.fail(str(e)).to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
# 全局注册表
|
||||||
|
registry = ToolRegistry()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.4 ToolExecutor
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
"""工具执行器"""
|
||||||
|
|
||||||
|
def __init__(self, registry: ToolRegistry = None):
|
||||||
|
self.registry = registry or ToolRegistry()
|
||||||
|
|
||||||
|
def process_tool_calls(self, tool_calls: List[dict]) -> List[dict]:
|
||||||
|
"""处理工具调用,返回消息列表"""
|
||||||
|
results = []
|
||||||
|
for call in tool_calls:
|
||||||
|
name = call["function"]["name"]
|
||||||
|
args = json.loads(call["function"]["arguments"])
|
||||||
|
call_id = call["id"]
|
||||||
|
|
||||||
|
result = self.registry.execute(name, args)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps(result, ensure_ascii=False)
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
|
def build_request(self, messages: List[dict], **kwargs) -> dict:
|
||||||
|
"""构建 API 请求"""
|
||||||
|
return {
|
||||||
|
"model": kwargs.get("model", "glm-5"),
|
||||||
|
"messages": messages,
|
||||||
|
"tools": self.registry.list_all(),
|
||||||
|
"tool_choice": "auto"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 四、工具工厂模式
|
||||||
|
|
||||||
|
### 4.1 装饰器注册
|
||||||
|
|
||||||
|
```python
|
||||||
|
# backend/tools/factory.py
|
||||||
|
|
||||||
|
from .core import ToolDefinition, registry
|
||||||
|
|
||||||
|
def tool(name: str, description: str, parameters: dict, category: str = "general"):
|
||||||
|
"""工具注册装饰器"""
|
||||||
|
def decorator(func):
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
handler=func,
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
registry.register(tool_def)
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 使用示例
|
||||||
|
|
||||||
|
```python
|
||||||
|
# backend/tools/builtin/crawler.py
|
||||||
|
|
||||||
|
from ..factory import tool
|
||||||
|
|
||||||
|
# 网页搜索工具
|
||||||
|
@tool(
|
||||||
|
name="web_search",
|
||||||
|
description="搜索互联网获取信息",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "搜索关键词"},
|
||||||
|
"max_results": {"type": "integer", "default": 5}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_search(arguments: dict) -> dict:
|
||||||
|
from ..services import SearchService
|
||||||
|
query = arguments["query"]
|
||||||
|
max_results = arguments.get("max_results", 5)
|
||||||
|
service = SearchService()
|
||||||
|
results = service.search(query, max_results)
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
# 页面抓取工具
|
||||||
|
@tool(
|
||||||
|
name="fetch_page",
|
||||||
|
description="抓取指定网页内容",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {"type": "string", "description": "网页URL"},
|
||||||
|
"extract_type": {"type": "string", "enum": ["text", "links", "structured"]}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def fetch_page(arguments: dict) -> dict:
|
||||||
|
from ..services import FetchService
|
||||||
|
url = arguments["url"]
|
||||||
|
extract_type = arguments.get("extract_type", "text")
|
||||||
|
service = FetchService()
|
||||||
|
result = service.fetch(url, extract_type)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# 计算器工具
|
||||||
|
@tool(
|
||||||
|
name="calculator",
|
||||||
|
description="执行数学计算",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {"type": "string", "description": "数学表达式"}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def calculator(arguments: dict) -> dict:
|
||||||
|
import ast
|
||||||
|
import operator
|
||||||
|
expr = arguments["expression"]
|
||||||
|
# 安全计算
|
||||||
|
ops = {
|
||||||
|
ast.Add: operator.add,
|
||||||
|
ast.Sub: operator.sub,
|
||||||
|
ast.Mult: operator.mul,
|
||||||
|
ast.Div: operator.truediv
|
||||||
|
}
|
||||||
|
node = ast.parse(expr, mode='eval')
|
||||||
|
result = eval(compile(node, '<string>', 'eval'), {"__builtins__": {}}, ops)
|
||||||
|
return {"result": result}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 五、辅助服务类
|
||||||
|
|
||||||
|
工具依赖的服务保持独立,不与工具类耦合:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
classDiagram
|
||||||
|
direction LR
|
||||||
|
|
||||||
|
class SearchService {
|
||||||
|
-SearchEngine engine
|
||||||
|
+search(str query, int limit) list~dict~
|
||||||
|
}
|
||||||
|
|
||||||
|
class FetchService {
|
||||||
|
+fetch(str url, str type) dict
|
||||||
|
+fetch_batch(list urls) dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class ContentExtractor {
|
||||||
|
+extract_text(html) str
|
||||||
|
+extract_links(html) list
|
||||||
|
+extract_structured(html) dict
|
||||||
|
}
|
||||||
|
|
||||||
|
FetchService --> ContentExtractor : uses
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# backend/tools/services.py
|
||||||
|
|
||||||
|
class SearchService:
|
||||||
|
"""搜索服务"""
|
||||||
|
def __init__(self, engine=None):
|
||||||
|
from duckduckgo_search import DDGS
|
||||||
|
self.engine = engine or DDGS()
|
||||||
|
|
||||||
|
def search(self, query: str, max_results: int = 5) -> list:
|
||||||
|
results = list(self.engine.text(query, max_results=max_results))
|
||||||
|
return [
|
||||||
|
{"title": r["title"], "url": r["href"], "snippet": r["body"]}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FetchService:
|
||||||
|
"""页面抓取服务"""
|
||||||
|
def __init__(self, timeout: float = 30.0):
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def fetch(self, url: str, extract_type: str = "text") -> dict:
|
||||||
|
import httpx
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
resp = httpx.get(url, timeout=self.timeout, follow_redirects=True)
|
||||||
|
soup = BeautifulSoup(resp.text, "html.parser")
|
||||||
|
|
||||||
|
extractor = ContentExtractor(soup)
|
||||||
|
if extract_type == "text":
|
||||||
|
return {"text": extractor.extract_text()}
|
||||||
|
elif extract_type == "links":
|
||||||
|
return {"links": extractor.extract_links()}
|
||||||
|
else:
|
||||||
|
return extractor.extract_structured()
|
||||||
|
|
||||||
|
|
||||||
|
class ContentExtractor:
|
||||||
|
"""内容提取器"""
|
||||||
|
def __init__(self, soup):
|
||||||
|
self.soup = soup
|
||||||
|
|
||||||
|
def extract_text(self) -> str:
|
||||||
|
# 移除脚本和样式
|
||||||
|
for tag in self.soup(["script", "style"]):
|
||||||
|
tag.decompose()
|
||||||
|
return self.soup.get_text(separator="\n", strip=True)
|
||||||
|
|
||||||
|
def extract_links(self) -> list:
|
||||||
|
return [
|
||||||
|
{"text": a.get_text(strip=True), "href": a.get("href")}
|
||||||
|
for a in self.soup.find_all("a", href=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
def extract_structured(self) -> dict:
|
||||||
|
return {
|
||||||
|
"title": self.soup.title.string if self.soup.title else "",
|
||||||
|
"text": self.extract_text(),
|
||||||
|
"links": self.extract_links()[:20]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 六、工具初始化
|
||||||
|
|
||||||
|
```python
|
||||||
|
# backend/tools/__init__.py
|
||||||
|
|
||||||
|
from .core import ToolDefinition, ToolResult, ToolRegistry, registry, ToolExecutor
|
||||||
|
from .factory import tool
|
||||||
|
|
||||||
|
def init_tools():
|
||||||
|
"""初始化所有内置工具"""
|
||||||
|
# 导入即自动注册
|
||||||
|
from .builtin import crawler, data, file_ops
|
||||||
|
|
||||||
|
# 使用时
|
||||||
|
init_tools()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 七、工具清单
|
||||||
|
|
||||||
|
| 类别 | 工具名称 | 描述 | 依赖服务 |
|
||||||
|
| ------- | --------------- | ---- | ------------- |
|
||||||
|
| crawler | `web_search` | 网页搜索 | SearchService |
|
||||||
|
| crawler | `fetch_page` | 单页抓取 | FetchService |
|
||||||
|
| crawler | `crawl_batch` | 批量爬取 | FetchService |
|
||||||
|
| data | `calculator` | 数学计算 | - |
|
||||||
|
| data | `data_analysis` | 数据分析 | - |
|
||||||
|
| file | `file_reader` | 文件读取 | - |
|
||||||
|
| file | `file_writer` | 文件写入 | - |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 八、与旧设计对比
|
||||||
|
|
||||||
|
| 方面 | 旧设计 | 新设计 |
|
||||||
|
| ----- | ----------------- | --------- |
|
||||||
|
| 类数量 | 30+ | ~10 |
|
||||||
|
| 工具定义 | 继承 BaseTool | 装饰器 + 函数 |
|
||||||
|
| 中间抽象层 | 5个(CrawlerTool 等) | 无 |
|
||||||
|
| 扩展方式 | 创建子类 | 写函数 + 装饰器 |
|
||||||
|
| 代码量 | 多 | 少 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 九、总结
|
||||||
|
|
||||||
|
简化后的设计:
|
||||||
|
|
||||||
|
1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult`
|
||||||
|
2. **工厂模式**:使用 `@tool` 装饰器注册工具
|
||||||
|
3. **服务分离**:工具依赖的服务独立,不与工具类耦合
|
||||||
|
4. **易于扩展**:新增工具只需写一个函数并加装饰器
|
||||||
Loading…
Reference in New Issue