refactor: 优化后端结构
This commit is contained in:
parent
dc70a4a1f2
commit
362ab15338
|
|
@ -32,9 +32,9 @@ def create_app():
|
||||||
db.init_app(app)
|
db.init_app(app)
|
||||||
|
|
||||||
# Import after db is initialized
|
# Import after db is initialized
|
||||||
from .models import User, Conversation, Message, TokenUsage
|
from backend.models import User, Conversation, Message, TokenUsage
|
||||||
from .routes import register_routes
|
from backend.routes import register_routes
|
||||||
from .tools import init_tools
|
from backend.tools import init_tools
|
||||||
|
|
||||||
register_routes(app)
|
register_routes(app)
|
||||||
init_tools()
|
init_tools()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Configuration management"""
|
||||||
|
from backend import load_config
|
||||||
|
|
||||||
|
_cfg = load_config()
|
||||||
|
|
||||||
|
API_URL = _cfg.get("api_url")
|
||||||
|
API_KEY = _cfg["api_key"]
|
||||||
|
MODELS = _cfg.get("models", [])
|
||||||
|
DEFAULT_MODEL = _cfg.get("default_model", "glm-5")
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||||
from . import db
|
from backend import db
|
||||||
|
|
||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
|
|
|
||||||
|
|
@ -1,588 +0,0 @@
|
||||||
import uuid
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
|
||||||
from flask import request, jsonify, Response, Blueprint, current_app
|
|
||||||
from . import db
|
|
||||||
from .models import Conversation, Message, User, TokenUsage
|
|
||||||
from . import load_config
|
|
||||||
from .tools import registry, ToolExecutor
|
|
||||||
|
|
||||||
bp = Blueprint("api", __name__)
|
|
||||||
|
|
||||||
cfg = load_config()
|
|
||||||
API_URL = cfg.get("api_url")
|
|
||||||
API_KEY = cfg["api_key"]
|
|
||||||
MODELS = cfg.get("models", [])
|
|
||||||
DEFAULT_MODEL = cfg.get("default_model", "glm-5")
|
|
||||||
|
|
||||||
|
|
||||||
# -- Helpers ----------------------------------------------
|
|
||||||
|
|
||||||
def get_or_create_default_user():
    """Return the shared "default" user, creating it on first use."""
    existing = User.query.filter_by(username="default").first()
    if existing is not None:
        return existing
    created = User(username="default", password="")
    db.session.add(created)
    db.session.commit()
    return created
|
||||||
|
|
||||||
|
|
||||||
def ok(data=None, message=None):
    """Build a success JSON response: code 0 plus optional data/message."""
    payload = {"code": 0}
    for key, value in (("data", data), ("message", message)):
        if value is not None:
            payload[key] = value
    return jsonify(payload)
|
|
||||||
|
|
||||||
|
|
||||||
def err(code, message):
    """Build an error JSON response; HTTP status mirrors the body code."""
    body = {"code": code, "message": message}
    return jsonify(body), code
|
|
||||||
|
|
||||||
|
|
||||||
def to_dict(inst, **extra):
    """Serialize a SQLAlchemy model instance to a plain dict.

    Columns are read via the model's ``__table__``; datetime-like values in
    ``created_at``/``updated_at`` are rendered as ``%Y-%m-%dT%H:%M:%SZ``
    strings, a JSON-encoded ``tool_calls`` column is decoded, and ``None``
    values are dropped.  ``extra`` keyword arguments are merged in last, so
    they override column values on key collision.
    """
    d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns}
    for k in ("created_at", "updated_at"):
        # Duck-typed check: anything with strftime is treated as datetime-like.
        if k in d and hasattr(d[k], "strftime"):
            d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")

    # Parse tool_calls JSON if present.  FIX: was a bare `except:` which also
    # swallowed KeyboardInterrupt/SystemExit; only decode failures are
    # expected here (invalid JSON -> ValueError, non-str value -> TypeError),
    # in which case the raw stored value is kept.
    if "tool_calls" in d and d["tool_calls"]:
        try:
            d["tool_calls"] = json.loads(d["tool_calls"])
        except (ValueError, TypeError):
            pass

    # Filter out None values for cleaner API response
    d = {k: v for k, v in d.items() if v is not None}

    d.update(extra)
    return d
|
|
||||||
|
|
||||||
|
|
||||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
    """Accumulate today's token counters for (user, model), upserting the row."""
    from datetime import date

    today = date.today()
    row = TokenUsage.query.filter_by(
        user_id=user_id, date=today, model=model
    ).first()
    if row is None:
        # First usage of this model today: create a fresh counter row.
        row = TokenUsage(
            user_id=user_id,
            date=today,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        db.session.add(row)
    else:
        row.prompt_tokens += prompt_tokens
        row.completion_tokens += completion_tokens
        row.total_tokens += prompt_tokens + completion_tokens
    db.session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def build_glm_messages(conv):
    """Assemble the GLM chat payload: optional system prompt, then history."""
    payload = []
    if conv.system_prompt:
        payload.append({"role": "system", "content": conv.system_prompt})
    # Query messages directly instead of conv.messages to avoid touching a
    # possibly detached relationship.
    history = (
        Message.query.filter_by(conversation_id=conv.id)
        .order_by(Message.created_at.asc())
        .all()
    )
    payload.extend({"role": m.role, "content": m.content} for m in history)
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
# -- Models API -------------------------------------------
|
|
||||||
|
|
||||||
@bp.route("/api/models", methods=["GET"])
|
|
||||||
def list_models():
|
|
||||||
"""Get available model list"""
|
|
||||||
return ok(MODELS)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Tools API --------------------------------------------
|
|
||||||
|
|
||||||
@bp.route("/api/tools", methods=["GET"])
|
|
||||||
def list_tools():
|
|
||||||
"""Get available tool list"""
|
|
||||||
tools = registry.list_all()
|
|
||||||
return ok({
|
|
||||||
"tools": tools,
|
|
||||||
"total": len(tools)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# -- Token Usage Statistics --------------------------------
|
|
||||||
|
|
||||||
@bp.route("/api/stats/tokens", methods=["GET"])
|
|
||||||
def token_stats():
|
|
||||||
"""Get token usage statistics"""
|
|
||||||
from sqlalchemy import func
|
|
||||||
from datetime import date, timedelta
|
|
||||||
|
|
||||||
user = get_or_create_default_user()
|
|
||||||
period = request.args.get("period", "daily") # daily, weekly, monthly
|
|
||||||
|
|
||||||
today = date.today()
|
|
||||||
|
|
||||||
if period == "daily":
|
|
||||||
# Today's statistics
|
|
||||||
stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
|
|
||||||
result = {
|
|
||||||
"period": "daily",
|
|
||||||
"date": today.isoformat(),
|
|
||||||
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
|
||||||
"completion_tokens": sum(s.completion_tokens for s in stats),
|
|
||||||
"total_tokens": sum(s.total_tokens for s in stats),
|
|
||||||
"by_model": {s.model: {"prompt": s.prompt_tokens, "completion": s.completion_tokens, "total": s.total_tokens} for s in stats}
|
|
||||||
}
|
|
||||||
elif period == "weekly":
|
|
||||||
# Weekly statistics (last 7 days)
|
|
||||||
start_date = today - timedelta(days=6)
|
|
||||||
stats = TokenUsage.query.filter(
|
|
||||||
TokenUsage.user_id == user.id,
|
|
||||||
TokenUsage.date >= start_date,
|
|
||||||
TokenUsage.date <= today
|
|
||||||
).all()
|
|
||||||
|
|
||||||
daily_data = {}
|
|
||||||
for s in stats:
|
|
||||||
d = s.date.isoformat()
|
|
||||||
if d not in daily_data:
|
|
||||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
|
||||||
daily_data[d]["prompt"] += s.prompt_tokens
|
|
||||||
daily_data[d]["completion"] += s.completion_tokens
|
|
||||||
daily_data[d]["total"] += s.total_tokens
|
|
||||||
|
|
||||||
# Fill missing dates
|
|
||||||
for i in range(7):
|
|
||||||
d = (today - timedelta(days=6-i)).isoformat()
|
|
||||||
if d not in daily_data:
|
|
||||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"period": "weekly",
|
|
||||||
"start_date": start_date.isoformat(),
|
|
||||||
"end_date": today.isoformat(),
|
|
||||||
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
|
||||||
"completion_tokens": sum(s.completion_tokens for s in stats),
|
|
||||||
"total_tokens": sum(s.total_tokens for s in stats),
|
|
||||||
"daily": daily_data
|
|
||||||
}
|
|
||||||
elif period == "monthly":
|
|
||||||
# Monthly statistics (last 30 days)
|
|
||||||
start_date = today - timedelta(days=29)
|
|
||||||
stats = TokenUsage.query.filter(
|
|
||||||
TokenUsage.user_id == user.id,
|
|
||||||
TokenUsage.date >= start_date,
|
|
||||||
TokenUsage.date <= today
|
|
||||||
).all()
|
|
||||||
|
|
||||||
daily_data = {}
|
|
||||||
for s in stats:
|
|
||||||
d = s.date.isoformat()
|
|
||||||
if d not in daily_data:
|
|
||||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
|
||||||
daily_data[d]["prompt"] += s.prompt_tokens
|
|
||||||
daily_data[d]["completion"] += s.completion_tokens
|
|
||||||
daily_data[d]["total"] += s.total_tokens
|
|
||||||
|
|
||||||
# Fill missing dates
|
|
||||||
for i in range(30):
|
|
||||||
d = (today - timedelta(days=29-i)).isoformat()
|
|
||||||
if d not in daily_data:
|
|
||||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"period": "monthly",
|
|
||||||
"start_date": start_date.isoformat(),
|
|
||||||
"end_date": today.isoformat(),
|
|
||||||
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
|
||||||
"completion_tokens": sum(s.completion_tokens for s in stats),
|
|
||||||
"total_tokens": sum(s.total_tokens for s in stats),
|
|
||||||
"daily": daily_data
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return err(400, "invalid period")
|
|
||||||
|
|
||||||
return ok(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Conversation CRUD ------------------------------------
|
|
||||||
|
|
||||||
@bp.route("/api/conversations", methods=["GET", "POST"])
|
|
||||||
def conversation_list():
|
|
||||||
if request.method == "POST":
|
|
||||||
d = request.json or {}
|
|
||||||
user = get_or_create_default_user()
|
|
||||||
conv = Conversation(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user.id,
|
|
||||||
title=d.get("title", ""),
|
|
||||||
model=d.get("model", DEFAULT_MODEL),
|
|
||||||
system_prompt=d.get("system_prompt", ""),
|
|
||||||
temperature=d.get("temperature", 1.0),
|
|
||||||
max_tokens=d.get("max_tokens", 65536),
|
|
||||||
thinking_enabled=d.get("thinking_enabled", False),
|
|
||||||
)
|
|
||||||
db.session.add(conv)
|
|
||||||
db.session.commit()
|
|
||||||
return ok(to_dict(conv))
|
|
||||||
|
|
||||||
cursor = request.args.get("cursor")
|
|
||||||
limit = min(int(request.args.get("limit", 20)), 100)
|
|
||||||
user = get_or_create_default_user()
|
|
||||||
q = Conversation.query.filter_by(user_id=user.id)
|
|
||||||
if cursor:
|
|
||||||
q = q.filter(Conversation.updated_at < (
|
|
||||||
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
|
||||||
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
|
|
||||||
|
|
||||||
items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
|
|
||||||
return ok({
|
|
||||||
"items": items,
|
|
||||||
"next_cursor": items[-1]["id"] if len(rows) > limit else None,
|
|
||||||
"has_more": len(rows) > limit,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/api/conversations/<conv_id>", methods=["GET", "PATCH", "DELETE"])
|
|
||||||
def conversation_detail(conv_id):
|
|
||||||
conv = db.session.get(Conversation, conv_id)
|
|
||||||
if not conv:
|
|
||||||
return err(404, "conversation not found")
|
|
||||||
|
|
||||||
if request.method == "GET":
|
|
||||||
return ok(to_dict(conv))
|
|
||||||
|
|
||||||
if request.method == "DELETE":
|
|
||||||
db.session.delete(conv)
|
|
||||||
db.session.commit()
|
|
||||||
return ok(message="deleted")
|
|
||||||
|
|
||||||
d = request.json or {}
|
|
||||||
for k in ("title", "model", "system_prompt", "temperature", "max_tokens", "thinking_enabled"):
|
|
||||||
if k in d:
|
|
||||||
setattr(conv, k, d[k])
|
|
||||||
db.session.commit()
|
|
||||||
return ok(to_dict(conv))
|
|
||||||
|
|
||||||
|
|
||||||
# -- Messages ---------------------------------------------
|
|
||||||
|
|
||||||
@bp.route("/api/conversations/<conv_id>/messages", methods=["GET", "POST"])
|
|
||||||
def message_list(conv_id):
|
|
||||||
conv = db.session.get(Conversation, conv_id)
|
|
||||||
if not conv:
|
|
||||||
return err(404, "conversation not found")
|
|
||||||
|
|
||||||
if request.method == "GET":
|
|
||||||
cursor = request.args.get("cursor")
|
|
||||||
limit = min(int(request.args.get("limit", 50)), 100)
|
|
||||||
q = Message.query.filter_by(conversation_id=conv_id)
|
|
||||||
if cursor:
|
|
||||||
q = q.filter(Message.created_at < (
|
|
||||||
db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
|
||||||
rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()
|
|
||||||
|
|
||||||
items = [to_dict(r) for r in rows[:limit]]
|
|
||||||
return ok({
|
|
||||||
"items": items,
|
|
||||||
"next_cursor": items[-1]["id"] if len(rows) > limit else None,
|
|
||||||
"has_more": len(rows) > limit,
|
|
||||||
})
|
|
||||||
|
|
||||||
d = request.json or {}
|
|
||||||
content = (d.get("content") or "").strip()
|
|
||||||
if not content:
|
|
||||||
return err(400, "content is required")
|
|
||||||
|
|
||||||
user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content)
|
|
||||||
db.session.add(user_msg)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
tools_enabled = d.get("tools_enabled", True)
|
|
||||||
|
|
||||||
if d.get("stream", False):
|
|
||||||
return _stream_response(conv, tools_enabled)
|
|
||||||
|
|
||||||
return _sync_response(conv, tools_enabled)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/api/conversations/<conv_id>/messages/<msg_id>", methods=["DELETE"])
def delete_message(conv_id, msg_id):
    """Delete one message after verifying it belongs to the conversation."""
    if db.session.get(Conversation, conv_id) is None:
        return err(404, "conversation not found")
    msg = db.session.get(Message, msg_id)
    if msg is None or msg.conversation_id != conv_id:
        return err(404, "message not found")
    db.session.delete(msg)
    db.session.commit()
    return ok(message="deleted")
|
|
||||||
|
|
||||||
|
|
||||||
# -- Chat Completion ----------------------------------
|
|
||||||
|
|
||||||
def _call_glm(conv, stream=False, tools=None, messages=None):
    """POST a chat-completion request to the GLM endpoint.

    When ``messages`` is None the conversation history is rebuilt from the
    database.  Returns the raw ``requests`` response (streaming when asked).
    """
    if messages is None:
        messages = build_glm_messages(conv)
    body = {
        "model": conv.model,
        "messages": messages,
        "max_tokens": conv.max_tokens,
        "temperature": conv.temperature,
    }
    if conv.thinking_enabled:
        body["thinking"] = {"type": "enabled"}
    if tools:
        body["tools"] = tools
        body["tool_choice"] = "auto"
    if stream:
        body["stream"] = True

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    return requests.post(API_URL, headers=headers, json=body,
                         stream=stream, timeout=120)
|
|
||||||
|
|
||||||
|
|
||||||
def _sync_response(conv, tools_enabled=True):
    """Sync response with tool call support.

    Repeatedly calls GLM, executing any requested tool calls and feeding the
    results back into the message list, until the model answers without tool
    calls or the iteration cap is reached.  The final assistant message is
    persisted together with all merged tool calls, token usage is recorded,
    and the message is returned as JSON.
    """
    executor = ToolExecutor(registry=registry)
    tools = registry.list_all() if tools_enabled else None
    messages = build_glm_messages(conv)
    max_iterations = 5  # Max tool call iterations

    # Collect all tool calls and results across iterations.
    all_tool_calls = []
    all_tool_results = []

    for _ in range(max_iterations):
        try:
            resp = _call_glm(conv, tools=tools, messages=messages)
            resp.raise_for_status()
            result = resp.json()
        except Exception as e:
            # Any upstream/network/parse failure maps to a 500 for the client.
            return err(500, f"upstream error: {e}")

        choice = result["choices"][0]
        message = choice["message"]

        # If no tool calls, return final result
        if not message.get("tool_calls"):
            usage = result.get("usage", {})
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)

            # Merge tool results into tool_calls, pairing by position
            # (results were appended in the same order as the calls).
            merged_tool_calls = []
            for i, tc in enumerate(all_tool_calls):
                merged_tc = dict(tc)
                if i < len(all_tool_results):
                    merged_tc["result"] = all_tool_results[i]["content"]
                merged_tool_calls.append(merged_tc)

            # Save assistant message with all tool calls (including results)
            msg = Message(
                id=str(uuid.uuid4()),
                conversation_id=conv.id,
                role="assistant",
                content=message.get("content", ""),
                token_count=completion_tokens,
                thinking_content=message.get("reasoning_content", ""),
                tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
            )
            db.session.add(msg)
            db.session.commit()

            user = get_or_create_default_user()
            record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)

            return ok({
                "message": to_dict(msg, thinking_content=msg.thinking_content or None),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": usage.get("total_tokens", 0)
                },
            })

        # Process tool calls: keep them for the final merge and echo the
        # assistant turn back into the conversation before the tool results.
        tool_calls = message["tool_calls"]
        all_tool_calls.extend(tool_calls)
        messages.append(message)

        # Execute tools and add results
        tool_results = executor.process_tool_calls(tool_calls)
        all_tool_results.extend(tool_results)
        messages.extend(tool_results)

    return err(500, "exceeded maximum tool call iterations")
|
|
||||||
|
|
||||||
|
|
||||||
def _stream_response(conv, tools_enabled=True):
    """Stream response with tool call support.

    Returns an SSE ``Response``.  The generator loops GLM calls, emitting
    ``thinking``/``message`` deltas as they arrive, executing any tool calls
    between iterations (``tool_calls``/``tool_result`` events), and finally
    persists the assistant message plus token usage before the ``done`` event.
    """
    # Capture primitives before entering the generator: `conv` is bound to
    # the request-scoped session, which is gone once streaming begins.
    conv_id = conv.id
    conv_model = conv.model
    app = current_app._get_current_object()
    executor = ToolExecutor(registry=registry)
    tools = registry.list_all() if tools_enabled else None
    # Build messages BEFORE entering generator (in request context)
    initial_messages = build_glm_messages(conv)

    def generate():
        messages = list(initial_messages)  # Copy to avoid mutation
        max_iterations = 5

        # Collect all tool calls and results across iterations.
        all_tool_calls = []
        all_tool_results = []
        total_content = ""
        total_thinking = ""
        total_tokens = 0
        total_prompt_tokens = 0

        for iteration in range(max_iterations):
            full_content = ""
            full_thinking = ""
            token_count = 0
            prompt_tokens = 0
            msg_id = str(uuid.uuid4())
            tool_calls_list = []
            current_tool_call = None  # NOTE(review): appears unused — confirm before removing

            try:
                with app.app_context():
                    active_conv = db.session.get(Conversation, conv_id)
                    resp = _call_glm(active_conv, stream=True, tools=tools, messages=messages)
                    resp.raise_for_status()

                    for line in resp.iter_lines():
                        if not line:
                            continue
                        line = line.decode("utf-8")
                        # Only SSE data lines matter; skip comments/other fields.
                        if not line.startswith("data: "):
                            continue
                        data_str = line[6:]
                        if data_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue

                        delta = chunk["choices"][0].get("delta", {})

                        # Process thinking chain
                        reasoning = delta.get("reasoning_content", "")
                        if reasoning:
                            full_thinking += reasoning
                            yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"

                        # Process text content
                        text = delta.get("content", "")
                        if text:
                            full_content += text
                            yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"

                        # Process tool calls: deltas arrive incrementally,
                        # keyed by index; arguments stream in fragments.
                        tool_calls_delta = delta.get("tool_calls", [])
                        for tc in tool_calls_delta:
                            idx = tc.get("index", 0)
                            if idx >= len(tool_calls_list):
                                tool_calls_list.append({
                                    "id": tc.get("id", ""),
                                    "type": tc.get("type", "function"),
                                    "function": {"name": "", "arguments": ""}
                                })
                            if tc.get("id"):
                                tool_calls_list[idx]["id"] = tc["id"]
                            if tc.get("function"):
                                if tc["function"].get("name"):
                                    tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
                                if tc["function"].get("arguments"):
                                    # Concatenate streamed argument fragments.
                                    tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]

                        # Usage is typically reported on the final chunk.
                        usage = chunk.get("usage", {})
                        if usage:
                            token_count = usage.get("completion_tokens", 0)
                            prompt_tokens = usage.get("prompt_tokens", 0)

            except Exception as e:
                yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
                return

            # If tool calls exist, execute and continue loop
            if tool_calls_list:
                # Collect tool calls
                all_tool_calls.extend(tool_calls_list)

                # Send tool call info
                yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"

                # Execute tools
                tool_results = executor.process_tool_calls(tool_calls_list)
                messages.append({
                    "role": "assistant",
                    "content": full_content or None,
                    "tool_calls": tool_calls_list
                })
                messages.extend(tool_results)

                # Collect tool results
                all_tool_results.extend(tool_results)

                # Send tool results
                for tr in tool_results:
                    yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"

                continue

            # No tool calls, finish - save everything
            total_content = full_content
            total_thinking = full_thinking
            total_tokens = token_count
            total_prompt_tokens = prompt_tokens

            # Merge tool results into tool_calls, pairing by position.
            merged_tool_calls = []
            for i, tc in enumerate(all_tool_calls):
                merged_tc = dict(tc)
                if i < len(all_tool_results):
                    merged_tc["result"] = all_tool_results[i]["content"]
                merged_tool_calls.append(merged_tc)

            with app.app_context():
                # Save assistant message with all tool calls (including results)
                msg = Message(
                    id=msg_id,
                    conversation_id=conv_id,
                    role="assistant",
                    content=total_content,
                    token_count=total_tokens,
                    thinking_content=total_thinking,
                    tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                )
                db.session.add(msg)
                db.session.commit()

                user = get_or_create_default_user()
                record_token_usage(user.id, conv_model, total_prompt_tokens, total_tokens)

            yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_tokens})}\n\n"
            return

        yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"

    # X-Accel-Buffering disables nginx proxy buffering so events flush live.
    return Response(generate(), mimetype="text/event-stream",
                    headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
|
||||||
|
|
||||||
|
|
||||||
def register_routes(app):
    """Attach the single API blueprint to the Flask application."""
    app.register_blueprint(bp)
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""Route registration"""
|
||||||
|
from flask import Flask
|
||||||
|
from backend.routes.conversations import bp as conversations_bp
|
||||||
|
from backend.routes.messages import bp as messages_bp, init_chat_service
|
||||||
|
from backend.routes.models import bp as models_bp
|
||||||
|
from backend.routes.tools import bp as tools_bp
|
||||||
|
from backend.routes.stats import bp as stats_bp
|
||||||
|
from backend.services.glm_client import GLMClient
|
||||||
|
from backend.config import API_URL, API_KEY
|
||||||
|
|
||||||
|
|
||||||
|
def register_routes(app: Flask):
    """Register all route blueprints and wire up the chat service."""
    # The messages blueprint needs a ChatService backed by a GLM client.
    client = GLMClient(API_URL, API_KEY)
    init_chat_service(client)

    # Register blueprints
    for blueprint in (conversations_bp, messages_bp, models_bp,
                      tools_bp, stats_bp):
        app.register_blueprint(blueprint)
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""Conversation API routes"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from flask import Blueprint, request
|
||||||
|
from backend import db
|
||||||
|
from backend.models import Conversation
|
||||||
|
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user
|
||||||
|
from backend.config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
bp = Blueprint("conversations", __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/conversations", methods=["GET", "POST"])
|
||||||
|
def conversation_list():
|
||||||
|
"""List or create conversations"""
|
||||||
|
if request.method == "POST":
|
||||||
|
d = request.json or {}
|
||||||
|
user = get_or_create_default_user()
|
||||||
|
conv = Conversation(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user.id,
|
||||||
|
title=d.get("title", ""),
|
||||||
|
model=d.get("model", DEFAULT_MODEL),
|
||||||
|
system_prompt=d.get("system_prompt", ""),
|
||||||
|
temperature=d.get("temperature", 1.0),
|
||||||
|
max_tokens=d.get("max_tokens", 65536),
|
||||||
|
thinking_enabled=d.get("thinking_enabled", False),
|
||||||
|
)
|
||||||
|
db.session.add(conv)
|
||||||
|
db.session.commit()
|
||||||
|
return ok(to_dict(conv))
|
||||||
|
|
||||||
|
# GET - list conversations
|
||||||
|
cursor = request.args.get("cursor")
|
||||||
|
limit = min(int(request.args.get("limit", 20)), 100)
|
||||||
|
user = get_or_create_default_user()
|
||||||
|
q = Conversation.query.filter_by(user_id=user.id)
|
||||||
|
if cursor:
|
||||||
|
q = q.filter(Conversation.updated_at < (
|
||||||
|
db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow))
|
||||||
|
rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()
|
||||||
|
|
||||||
|
items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
|
||||||
|
return ok({
|
||||||
|
"items": items,
|
||||||
|
"next_cursor": items[-1]["id"] if len(rows) > limit else None,
|
||||||
|
"has_more": len(rows) > limit,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/conversations/<conv_id>", methods=["GET", "PATCH", "DELETE"])
def conversation_detail(conv_id):
    """Get, update or delete a conversation."""
    conv = db.session.get(Conversation, conv_id)
    if conv is None:
        return err(404, "conversation not found")

    if request.method == "GET":
        return ok(to_dict(conv))

    if request.method == "DELETE":
        db.session.delete(conv)
        db.session.commit()
        return ok(message="deleted")

    # PATCH - update conversation; only whitelisted fields are applied.
    payload = request.json or {}
    for field in ("title", "model", "system_prompt", "temperature",
                  "max_tokens", "thinking_enabled"):
        if field in payload:
            setattr(conv, field, payload[field])
    db.session.commit()
    return ok(to_dict(conv))
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""Message API routes"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from flask import Blueprint, request
|
||||||
|
from backend import db
|
||||||
|
from backend.models import Conversation, Message
|
||||||
|
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user
|
||||||
|
from backend.services.chat import ChatService
|
||||||
|
|
||||||
|
|
||||||
|
bp = Blueprint("messages", __name__)
|
||||||
|
|
||||||
|
# ChatService will be injected during registration
|
||||||
|
_chat_service = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_chat_service(glm_client):
    """Create the module-level ChatService used by the message routes."""
    global _chat_service
    _chat_service = ChatService(glm_client)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/conversations/<conv_id>/messages", methods=["GET", "POST"])
def message_list(conv_id):
    """List messages (GET, keyset-paginated) or post one and get an AI reply."""
    conv = db.session.get(Conversation, conv_id)
    if not conv:
        return err(404, "conversation not found")

    if request.method == "GET":
        cursor = request.args.get("cursor")
        limit = min(int(request.args.get("limit", 50)), 100)
        q = Message.query.filter_by(conversation_id=conv_id)
        if cursor:
            # BUG FIX: was `or datetime.utcnow` — the function object
            # (always truthy, not a datetime) instead of its result; call
            # it so an unknown cursor falls back to "now".
            anchor = db.session.query(Message.created_at).filter_by(id=cursor).scalar()
            q = q.filter(Message.created_at < (anchor or datetime.utcnow()))
        # One extra row detects whether another page exists.
        rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()

        items = [to_dict(r) for r in rows[:limit]]
        return ok({
            "items": items,
            "next_cursor": items[-1]["id"] if len(rows) > limit else None,
            "has_more": len(rows) > limit,
        })

    # POST - create message and get AI response
    d = request.json or {}
    content = (d.get("content") or "").strip()
    if not content:
        return err(400, "content is required")

    user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content)
    db.session.add(user_msg)
    db.session.commit()

    tools_enabled = d.get("tools_enabled", True)

    if d.get("stream", False):
        return _chat_service.stream_response(conv, tools_enabled)

    return _chat_service.sync_response(conv, tools_enabled)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/conversations/<conv_id>/messages/<msg_id>", methods=["DELETE"])
def delete_message(conv_id, msg_id):
    """Delete a message, checking it belongs to the given conversation."""
    if db.session.get(Conversation, conv_id) is None:
        return err(404, "conversation not found")
    msg = db.session.get(Message, msg_id)
    if msg is None or msg.conversation_id != conv_id:
        return err(404, "message not found")
    db.session.delete(msg)
    db.session.commit()
    return ok(message="deleted")
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""Model list API routes"""
|
||||||
|
from flask import Blueprint
|
||||||
|
from backend.utils.helpers import ok
|
||||||
|
from backend.config import MODELS
|
||||||
|
|
||||||
|
bp = Blueprint("models", __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/models", methods=["GET"])
def list_models():
    """Return the configured model list."""
    return ok(MODELS)
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""Token statistics API routes"""
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from flask import Blueprint, request
|
||||||
|
from sqlalchemy import func
|
||||||
|
from backend.models import TokenUsage
|
||||||
|
from backend.utils.helpers import ok, err, get_or_create_default_user
|
||||||
|
|
||||||
|
bp = Blueprint("stats", __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/stats/tokens", methods=["GET"])
def token_stats():
    """Get token usage statistics for the default user.

    Query params:
        period: "daily" (today only), "weekly" (last 7 days) or
                "monthly" (last 30 days). Defaults to "daily".

    Returns:
        ok() envelope with aggregated prompt/completion/total token
        counts; daily adds a per-model breakdown, weekly/monthly add a
        per-day breakdown. 400 on an unknown period.
    """
    user = get_or_create_default_user()
    period = request.args.get("period", "daily")

    today = date.today()

    if period == "daily":
        stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
        result = {
            "period": "daily",
            "date": today.isoformat(),
            "prompt_tokens": sum(s.prompt_tokens for s in stats),
            "completion_tokens": sum(s.completion_tokens for s in stats),
            "total_tokens": sum(s.total_tokens for s in stats),
            # NOTE: assumes at most one TokenUsage row per (model, date);
            # a later row for the same model would overwrite the earlier one.
            "by_model": {
                s.model: {
                    "prompt": s.prompt_tokens,
                    "completion": s.completion_tokens,
                    "total": s.total_tokens
                } for s in stats
            }
        }
    elif period in ("weekly", "monthly"):
        # Weekly and monthly differ only in window size; share one code path
        # instead of duplicating the query and aggregation.
        days = 7 if period == "weekly" else 30
        start_date = today - timedelta(days=days - 1)
        stats = TokenUsage.query.filter(
            TokenUsage.user_id == user.id,
            TokenUsage.date >= start_date,
            TokenUsage.date <= today
        ).all()

        result = _build_period_result(stats, period, start_date, today, days)
    else:
        return err(400, "invalid period")

    return ok(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_period_result(stats, period, start_date, end_date, days):
|
||||||
|
"""Build result for period-based statistics"""
|
||||||
|
daily_data = {}
|
||||||
|
for s in stats:
|
||||||
|
d = s.date.isoformat()
|
||||||
|
if d not in daily_data:
|
||||||
|
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
||||||
|
daily_data[d]["prompt"] += s.prompt_tokens
|
||||||
|
daily_data[d]["completion"] += s.completion_tokens
|
||||||
|
daily_data[d]["total"] += s.total_tokens
|
||||||
|
|
||||||
|
# Fill missing dates
|
||||||
|
for i in range(days):
|
||||||
|
d = (end_date - timedelta(days=days - 1 - i)).isoformat()
|
||||||
|
if d not in daily_data:
|
||||||
|
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"period": period,
|
||||||
|
"start_date": start_date.isoformat(),
|
||||||
|
"end_date": end_date.isoformat(),
|
||||||
|
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
||||||
|
"completion_tokens": sum(s.completion_tokens for s in stats),
|
||||||
|
"total_tokens": sum(s.total_tokens for s in stats),
|
||||||
|
"daily": daily_data
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Tool list API routes"""
|
||||||
|
from flask import Blueprint
|
||||||
|
from backend.tools import registry
|
||||||
|
from backend.utils.helpers import ok
|
||||||
|
|
||||||
|
bp = Blueprint("tools", __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/api/tools", methods=["GET"])
def list_tools():
    """Get available tool list.

    Returns the registered tool definitions plus their count.
    """
    available = registry.list_all()
    return ok({"tools": available, "total": len(available)})
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
"""Backend services"""
|
||||||
|
from backend.services.glm_client import GLMClient
|
||||||
|
from backend.services.chat import ChatService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GLMClient",
|
||||||
|
"ChatService",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""Chat completion service"""
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from flask import current_app, Response
|
||||||
|
from backend import db
|
||||||
|
from backend.models import Conversation, Message
|
||||||
|
from backend.tools import registry, ToolExecutor
|
||||||
|
from backend.utils.helpers import (
|
||||||
|
get_or_create_default_user,
|
||||||
|
record_token_usage,
|
||||||
|
build_glm_messages,
|
||||||
|
ok,
|
||||||
|
err,
|
||||||
|
to_dict,
|
||||||
|
)
|
||||||
|
from backend.services.glm_client import GLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class ChatService:
    """Chat completion service with tool support.

    Drives the conversation loop against the GLM API: sends the
    conversation history, executes any tool calls the model requests via
    a ToolExecutor, feeds the results back, and persists the final
    assistant Message.

    NOTE(review): self.executor (and its cache/call history) is shared by
    all requests served through this instance; clear_history() is called
    at the start of each request, which assumes requests are not handled
    concurrently by the same ChatService — confirm.
    """

    # Upper bound on model->tool->model round trips per request; prevents
    # an endless tool-call loop.
    MAX_ITERATIONS = 5

    def __init__(self, glm_client: GLMClient):
        """Store the upstream client and build a tool executor bound to
        the global tool registry."""
        self.glm_client = glm_client
        self.executor = ToolExecutor(registry=registry)

    def sync_response(self, conv: Conversation, tools_enabled: bool = True):
        """Sync response with tool call support.

        Args:
            conv: Conversation whose history and sampling settings drive
                the request.
            tools_enabled: When False, no tool definitions are sent and
                the model cannot request tool calls.

        Returns:
            ok() payload with the persisted assistant message and token
            usage, or err() on upstream failure / iteration overflow.
        """
        tools = registry.list_all() if tools_enabled else None
        messages = build_glm_messages(conv)

        # Clear tool call history for new request
        self.executor.clear_history()

        all_tool_calls = []
        all_tool_results = []

        for _ in range(self.MAX_ITERATIONS):
            try:
                resp = self.glm_client.call(
                    model=conv.model,
                    messages=messages,
                    max_tokens=conv.max_tokens,
                    temperature=conv.temperature,
                    thinking_enabled=conv.thinking_enabled,
                    tools=tools,
                )
                resp.raise_for_status()
                result = resp.json()
            except Exception as e:
                return err(500, f"upstream error: {e}")

            choice = result["choices"][0]
            message = choice["message"]

            # No tool calls - return final result
            if not message.get("tool_calls"):
                usage = result.get("usage", {})
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)

                # Attach each tool's result to its call so the stored
                # message carries the full tool trace.
                merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)

                msg = Message(
                    id=str(uuid.uuid4()),
                    conversation_id=conv.id,
                    role="assistant",
                    content=message.get("content", ""),
                    token_count=completion_tokens,
                    thinking_content=message.get("reasoning_content", ""),
                    tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                )
                db.session.add(msg)
                db.session.commit()

                user = get_or_create_default_user()
                record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)

                return ok({
                    "message": to_dict(msg, thinking_content=msg.thinking_content or None),
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": usage.get("total_tokens", 0)
                    },
                })

            # Process tool calls: echo the assistant turn, execute the
            # requested tools, and append their results for the next round.
            tool_calls = message["tool_calls"]
            all_tool_calls.extend(tool_calls)
            messages.append(message)

            tool_results = self.executor.process_tool_calls(tool_calls)
            all_tool_results.extend(tool_results)
            messages.extend(tool_results)

        return err(500, "exceeded maximum tool call iterations")

    def stream_response(self, conv: Conversation, tools_enabled: bool = True):
        """Stream response with tool call support.

        Returns a Flask SSE Response whose generator emits events:
        thinking / message / tool_calls / tool_result / done / error.

        Plain values (ids, app handle) are captured before the generator
        runs because the generator executes outside the request context.
        """
        conv_id = conv.id
        conv_model = conv.model
        app = current_app._get_current_object()
        tools = registry.list_all() if tools_enabled else None
        initial_messages = build_glm_messages(conv)

        # Clear tool call history for new request
        self.executor.clear_history()

        def generate():
            messages = list(initial_messages)
            all_tool_calls = []
            all_tool_results = []

            for iteration in range(self.MAX_ITERATIONS):
                full_content = ""
                full_thinking = ""
                token_count = 0
                prompt_tokens = 0
                # Fresh id each iteration; only the final (no-tool-call)
                # iteration's id is persisted and reported in "done".
                msg_id = str(uuid.uuid4())
                tool_calls_list = []

                try:
                    with app.app_context():
                        # Re-read the conversation so mid-stream settings
                        # changes take effect on the next iteration.
                        active_conv = db.session.get(Conversation, conv_id)
                        resp = self.glm_client.call(
                            model=active_conv.model,
                            messages=messages,
                            max_tokens=active_conv.max_tokens,
                            temperature=active_conv.temperature,
                            thinking_enabled=active_conv.thinking_enabled,
                            tools=tools,
                            stream=True,
                        )
                        resp.raise_for_status()

                        # Parse the SSE stream line by line ("data: ..." frames).
                        for line in resp.iter_lines():
                            if not line:
                                continue
                            line = line.decode("utf-8")
                            if not line.startswith("data: "):
                                continue
                            data_str = line[6:]
                            if data_str == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data_str)
                            except json.JSONDecodeError:
                                # Skip malformed frames rather than abort the stream.
                                continue

                            delta = chunk["choices"][0].get("delta", {})

                            # Process thinking
                            reasoning = delta.get("reasoning_content", "")
                            if reasoning:
                                full_thinking += reasoning
                                yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"

                            # Process text
                            text = delta.get("content", "")
                            if text:
                                full_content += text
                                yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"

                            # Process tool calls (accumulated across deltas)
                            tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)

                            usage = chunk.get("usage", {})
                            if usage:
                                token_count = usage.get("completion_tokens", 0)
                                prompt_tokens = usage.get("prompt_tokens", 0)

                except Exception as e:
                    yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
                    return

                # Tool calls exist - execute and continue
                if tool_calls_list:
                    all_tool_calls.extend(tool_calls_list)
                    yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"

                    tool_results = self.executor.process_tool_calls(tool_calls_list)
                    messages.append({
                        "role": "assistant",
                        "content": full_content or None,
                        "tool_calls": tool_calls_list
                    })
                    messages.extend(tool_results)
                    all_tool_results.extend(tool_results)

                    for tr in tool_results:
                        yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"
                    continue

                # No tool calls - finish: persist the assistant message
                # with the accumulated tool trace, record usage, signal done.
                merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)

                with app.app_context():
                    msg = Message(
                        id=msg_id,
                        conversation_id=conv_id,
                        role="assistant",
                        content=full_content,
                        token_count=token_count,
                        thinking_content=full_thinking,
                        tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                    )
                    db.session.add(msg)
                    db.session.commit()

                    user = get_or_create_default_user()
                    record_token_usage(user.id, conv_model, prompt_tokens, token_count)

                yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
                return

            yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"

        return Response(
            generate(),
            mimetype="text/event-stream",
            # X-Accel-Buffering: no disables proxy buffering so SSE
            # frames reach the client immediately.
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
        )

    def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
        """Process tool calls from streaming delta.

        Streaming tool calls arrive in fragments keyed by "index": the
        first fragment carries id/type/name, later ones append argument
        text. Mutates and returns tool_calls_list.
        """
        tool_calls_delta = delta.get("tool_calls", [])
        for tc in tool_calls_delta:
            idx = tc.get("index", 0)
            if idx >= len(tool_calls_list):
                # First fragment for this index: create the slot.
                # NOTE(review): assumes indices arrive in order (idx grows
                # by at most 1 per new call) — confirm upstream guarantees.
                tool_calls_list.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {"name": "", "arguments": ""}
                })
            if tc.get("id"):
                tool_calls_list[idx]["id"] = tc["id"]
            if tc.get("function"):
                if tc["function"].get("name"):
                    tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
                if tc["function"].get("arguments"):
                    # Argument JSON is streamed in pieces; concatenate.
                    tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
        return tool_calls_list

    def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list:
        """Merge tool results into tool calls.

        Pairs calls and results positionally; a call without a matching
        result is kept without a "result" key.
        """
        merged = []
        for i, tc in enumerate(tool_calls):
            merged_tc = dict(tc)
            if i < len(tool_results):
                merged_tc["result"] = tool_results[i]["content"]
            merged.append(merged_tc)
        return merged
|
@ -0,0 +1,48 @@
|
||||||
|
"""GLM API client"""
|
||||||
|
import requests
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
class GLMClient:
    """GLM API client for chat completions.

    Thin wrapper around requests.post: builds the request body from the
    given sampling options and returns the raw Response (callers handle
    raise_for_status / json / streaming themselves).
    """

    def __init__(self, api_url: str, api_key: str):
        self.api_url = api_url
        self.api_key = api_key

    def call(
        self,
        model: str,
        messages: List[dict],
        max_tokens: int = 65536,
        temperature: float = 1.0,
        thinking_enabled: bool = False,
        tools: Optional[List[dict]] = None,
        stream: bool = False,
        timeout: int = 120,
    ):
        """Call GLM API and return the raw requests.Response.

        Optional features (thinking, tools, streaming) are only added to
        the payload when requested; tools imply tool_choice="auto".
        """
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if thinking_enabled:
            payload["thinking"] = {"type": "enabled"}
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"
        if stream:
            payload["stream"] = True

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        return requests.post(
            self.api_url,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=timeout,
        )
|
@ -15,9 +15,9 @@ Usage:
|
||||||
result = registry.execute("web_search", {"query": "Python"})
|
result = registry.execute("web_search", {"query": "Python"})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .core import ToolDefinition, ToolResult, ToolRegistry, registry
|
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||||
from .factory import tool, register_tool
|
from backend.tools.factory import tool, register_tool
|
||||||
from .executor import ToolExecutor
|
from backend.tools.executor import ToolExecutor
|
||||||
|
|
||||||
|
|
||||||
def init_tools() -> None:
|
def init_tools() -> None:
|
||||||
|
|
@ -26,7 +26,7 @@ def init_tools() -> None:
|
||||||
|
|
||||||
Importing builtin module automatically registers all decorator-defined tools
|
Importing builtin module automatically registers all decorator-defined tools
|
||||||
"""
|
"""
|
||||||
from .builtin import crawler, data, weather # noqa: F401
|
from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
# Public API exports
|
# Public API exports
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
"""Built-in tools"""
|
"""Built-in tools"""
|
||||||
from .crawler import *
|
from backend.tools.builtin.crawler import *
|
||||||
from .data import *
|
from backend.tools.builtin.data import *
|
||||||
|
from backend.tools.builtin.file_ops import *
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Crawler related tools"""
|
"""Crawler related tools"""
|
||||||
from ..factory import tool
|
from backend.tools.factory import tool
|
||||||
from ..services import SearchService, FetchService
|
from backend.tools.services import SearchService, FetchService
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Data processing related tools"""
|
"""Data processing related tools"""
|
||||||
from ..factory import tool
|
from backend.tools.factory import tool
|
||||||
from ..services import CalculatorService
|
from backend.tools.services import CalculatorService
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,360 @@
|
||||||
|
"""File operation tools"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from backend.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
|
# Base directory for file operations (sandbox)
|
||||||
|
# Set to None to allow any path, or set a specific directory for security
|
||||||
|
BASE_DIR = Path(__file__).parent.parent.parent.parent # project root
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(path: str) -> Path:
|
||||||
|
"""Resolve path and ensure it's within allowed directory"""
|
||||||
|
p = Path(path)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = BASE_DIR / p
|
||||||
|
p = p.resolve()
|
||||||
|
|
||||||
|
# Security check: ensure path is within BASE_DIR
|
||||||
|
if BASE_DIR:
|
||||||
|
try:
|
||||||
|
p.relative_to(BASE_DIR.resolve())
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Path '{path}' is outside allowed directory")
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_read",
    description="Read content from a file. Use when you need to read file content.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to read (relative to project root or absolute)"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_read(arguments: dict) -> dict:
    """Read a text file inside the sandbox.

    Args:
        arguments: {"path": "file.txt", "encoding": "utf-8"}

    Returns:
        {"content": ..., "size": ..., "path": ...} on success, or
        {"error": ...} when the path is missing/invalid or reading fails.
    """
    try:
        target = _resolve_path(arguments["path"])
        codec = arguments.get("encoding", "utf-8")

        if not target.exists():
            return {"error": f"File not found: {target}"}
        if not target.is_file():
            return {"error": f"Path is not a file: {target}"}

        text = target.read_text(encoding=codec)
        return {"content": text, "size": len(text), "path": str(target)}
    except Exception as exc:
        return {"error": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_write",
    description="Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Use when you need to create or update a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to write (relative to project root or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            },
            "mode": {
                "type": "string",
                "description": "Write mode: 'write' (overwrite) or 'append'",
                "enum": ["write", "append"],
                "default": "write"
            }
        },
        "required": ["path", "content"]
    },
    category="file"
)
def file_write(arguments: dict) -> dict:
    """Write or append text to a file inside the sandbox.

    Args:
        arguments: {"path": ..., "content": ..., "encoding": "utf-8",
                    "mode": "write" | "append"}

    Returns:
        {"success": True, "size": ..., "path": ..., "mode": ...} on
        success, or {"error": ...} on failure. Parent directories are
        created as needed.
    """
    try:
        target = _resolve_path(arguments["path"])
        data = arguments["content"]
        codec = arguments.get("encoding", "utf-8")
        write_mode = arguments.get("mode", "write")

        # Create parent directories if needed
        target.parent.mkdir(parents=True, exist_ok=True)

        if write_mode == "append":
            with open(target, "a", encoding=codec) as fh:
                fh.write(data)
        else:
            target.write_text(data, encoding=codec)

        return {
            "success": True,
            "size": len(data),
            "path": str(target),
            "mode": write_mode
        }
    except Exception as exc:
        return {"error": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_delete",
    description="Delete a file. Use when you need to remove a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to delete (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_delete(arguments: dict) -> dict:
    """Delete a regular file inside the sandbox.

    Args:
        arguments: {"path": "file.txt"}

    Returns:
        {"success": True, "path": ...} on success; {"error": ...} when
        the path is missing, not a file, or deletion fails. Directories
        are refused.
    """
    try:
        target = _resolve_path(arguments["path"])

        if not target.exists():
            return {"error": f"File not found: {target}"}
        if not target.is_file():
            return {"error": f"Path is not a file: {target}"}

        target.unlink()
        return {"success": True, "path": str(target)}
    except Exception as exc:
        return {"error": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_list",
    description="List files and directories in a directory. Use when you need to see what files exist.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to list (relative to project root or absolute)",
                "default": "."
            },
            "pattern": {
                "type": "string",
                "description": "Glob pattern to filter files, e.g. '*.py'",
                "default": "*"
            }
        },
        "required": []
    },
    category="file"
)
def file_list(arguments: dict) -> dict:
    """List a directory inside the sandbox, optionally glob-filtered.

    Args:
        arguments: {"path": ".", "pattern": "*"}

    Returns:
        {"path", "files", "directories", "total_files", "total_dirs"} on
        success, or {"error": ...} when the path is missing or not a
        directory. File/dir paths are reported relative to BASE_DIR when
        a sandbox root is configured.
    """
    try:
        target = _resolve_path(arguments.get("path", "."))
        glob_pat = arguments.get("pattern", "*")

        if not target.exists():
            return {"error": f"Directory not found: {target}"}
        if not target.is_dir():
            return {"error": f"Path is not a directory: {target}"}

        files, directories = [], []
        for entry in target.glob(glob_pat):
            rel = str(entry.relative_to(BASE_DIR)) if BASE_DIR else str(entry)
            if entry.is_file():
                files.append({
                    "name": entry.name,
                    "size": entry.stat().st_size,
                    "path": rel
                })
            elif entry.is_dir():
                directories.append({"name": entry.name, "path": rel})

        return {
            "path": str(target),
            "files": files,
            "directories": directories,
            "total_files": len(files),
            "total_dirs": len(directories)
        }
    except Exception as exc:
        return {"error": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_exists",
    description="Check if a file or directory exists. Use when you need to verify file existence.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to check (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_exists(arguments: dict) -> dict:
    """Check whether a sandbox path exists and classify it.

    Args:
        arguments: {"path": "file.txt"}

    Returns:
        {"exists": False, "path": ...} when absent; otherwise
        {"exists": True, "type": "file" | "directory" | "other", ...}
        with "size" included for regular files. {"error": ...} when
        resolution fails (e.g. path escapes the sandbox).
    """
    try:
        target = _resolve_path(arguments["path"])

        if not target.exists():
            return {"exists": False, "path": str(target)}

        if target.is_file():
            return {
                "exists": True,
                "type": "file",
                "path": str(target),
                "size": target.stat().st_size
            }
        if target.is_dir():
            return {"exists": True, "type": "directory", "path": str(target)}
        # Sockets, FIFOs, device nodes etc.
        return {"exists": True, "type": "other", "path": str(target)}
    except Exception as exc:
        return {"error": str(exc)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    name="file_mkdir",
    description="Create a directory. Creates parent directories if needed. Use when you need to create a folder.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to create (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_mkdir(arguments: dict) -> dict:
    """Create a directory inside the sandbox (parents included).

    Args:
        arguments: {"path": "new/folder"}

    Returns:
        {"success": True, "path": ..., "created": ...} where "created"
        is True only when the directory did not exist beforehand;
        {"error": ...} on failure.
    """
    try:
        path = _resolve_path(arguments["path"])
        # BUGFIX: capture existence BEFORE mkdir. The old expression
        # `not path.exists() or path.is_dir()` was evaluated after mkdir
        # and therefore always True, so "created" carried no information.
        existed = path.exists()
        path.mkdir(parents=True, exist_ok=True)
        return {
            "success": True,
            "path": str(path),
            "created": not existed
        }
    except Exception as e:
        return {"error": str(e)}
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""Weather related tools"""
|
"""Weather related tools"""
|
||||||
from ..factory import tool
|
from backend.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,62 @@
|
||||||
"""Tool executor"""
|
"""Tool executor with caching and deduplication"""
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Optional, Generator, Any
|
import hashlib
|
||||||
from .core import ToolRegistry, registry
|
from typing import List, Dict, Optional, Any
|
||||||
|
from backend.tools.core import ToolRegistry, registry
|
||||||
|
|
||||||
|
|
||||||
class ToolExecutor:
|
class ToolExecutor:
|
||||||
"""Tool call executor"""
|
"""Tool call executor with caching and deduplication"""
|
||||||
|
|
||||||
    def __init__(
        self,
        registry: Optional[ToolRegistry] = None,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        enable_cache: bool = True,
        cache_ttl: int = 300,  # 5 minutes
    ):
        """Initialize the executor.

        Args:
            registry: Tool registry used to execute calls; a fresh empty
                ToolRegistry is created when omitted.
            api_url: Optional upstream API URL (stored on the instance;
                not used by the visible caching logic).
            api_key: Optional upstream API key (stored on the instance;
                not used by the visible caching logic).
            enable_cache: Whether tool results are cached/reused.
            cache_ttl: Cache entry lifetime in seconds.
        """
        self.registry = registry or ToolRegistry()
        self.api_url = api_url
        self.api_key = api_key
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self._cache: Dict[str, tuple] = {}  # key -> (result, timestamp)
        self._call_history: List[dict] = []  # Track calls in current session
||||||
|
def _make_cache_key(self, name: str, args: dict) -> str:
|
||||||
|
"""Generate cache key from tool name and arguments"""
|
||||||
|
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
||||||
|
return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()
|
||||||
|
|
||||||
|
def _get_cached(self, key: str) -> Optional[dict]:
|
||||||
|
"""Get cached result if valid"""
|
||||||
|
if not self.enable_cache:
|
||||||
|
return None
|
||||||
|
if key in self._cache:
|
||||||
|
result, timestamp = self._cache[key]
|
||||||
|
if time.time() - timestamp < self.cache_ttl:
|
||||||
|
return result
|
||||||
|
del self._cache[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _set_cache(self, key: str, result: dict) -> None:
|
||||||
|
"""Cache a result"""
|
||||||
|
if self.enable_cache:
|
||||||
|
self._cache[key] = (result, time.time())
|
||||||
|
|
||||||
|
def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]:
|
||||||
|
"""Check if same tool+args was called before in this session"""
|
||||||
|
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
||||||
|
for record in self._call_history:
|
||||||
|
if record["name"] == name and record["args_str"] == args_str:
|
||||||
|
return record["result"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_history(self) -> None:
|
||||||
|
"""Clear call history (call this at start of new conversation turn)"""
|
||||||
|
self._call_history.clear()
|
||||||
|
|
||||||
def process_tool_calls(
|
def process_tool_calls(
|
||||||
self,
|
self,
|
||||||
|
|
@ -34,6 +74,7 @@ class ToolExecutor:
|
||||||
Tool response message list, can be appended to messages
|
Tool response message list, can be appended to messages
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
|
seen_calls = set() # Track calls within this batch
|
||||||
|
|
||||||
for call in tool_calls:
|
for call in tool_calls:
|
||||||
name = call["function"]["name"]
|
name = call["function"]["name"]
|
||||||
|
|
@ -48,7 +89,45 @@ class ToolExecutor:
|
||||||
))
|
))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check for duplicate within same batch
|
||||||
|
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
|
||||||
|
if call_key in seen_calls:
|
||||||
|
# Skip duplicate, but still return a result
|
||||||
|
results.append(self._create_tool_result(
|
||||||
|
call_id, name,
|
||||||
|
{"success": True, "data": None, "cached": True, "duplicate": True}
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
seen_calls.add(call_key)
|
||||||
|
|
||||||
|
# Check history for previous call in this session
|
||||||
|
history_result = self._check_duplicate_in_history(name, args)
|
||||||
|
if history_result is not None:
|
||||||
|
result = {**history_result, "cached": True}
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_key = self._make_cache_key(name, args)
|
||||||
|
cached_result = self._get_cached(cache_key)
|
||||||
|
if cached_result is not None:
|
||||||
|
result = {**cached_result, "cached": True}
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Execute tool
|
||||||
result = self.registry.execute(name, args)
|
result = self.registry.execute(name, args)
|
||||||
|
|
||||||
|
# Cache the result
|
||||||
|
self._set_cache(cache_key, result)
|
||||||
|
|
||||||
|
# Add to history
|
||||||
|
self._call_history.append({
|
||||||
|
"name": name,
|
||||||
|
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
|
||||||
|
"result": result
|
||||||
|
})
|
||||||
|
|
||||||
results.append(self._create_tool_result(call_id, name, result))
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Tool factory - decorator registration"""
|
"""Tool factory - decorator registration"""
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from .core import ToolDefinition, registry
|
from backend.tools.core import ToolDefinition, registry
|
||||||
|
|
||||||
|
|
||||||
def tool(
|
def tool(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""Backend utilities"""
|
||||||
|
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_glm_messages
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ok",
|
||||||
|
"err",
|
||||||
|
"to_dict",
|
||||||
|
"get_or_create_default_user",
|
||||||
|
"record_token_usage",
|
||||||
|
"build_glm_messages",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
"""Common helper functions"""
|
||||||
|
import json
|
||||||
|
from datetime import datetime, date
|
||||||
|
from backend import db
|
||||||
|
from backend.models import Conversation, Message, User, TokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_default_user():
    """Return the singleton "default" user, creating and committing it on first use."""
    existing = User.query.filter_by(username="default").first()
    if existing:
        return existing
    created = User(username="default", password="")
    db.session.add(created)
    db.session.commit()
    return created
|
||||||
|
|
||||||
|
|
||||||
|
def ok(data=None, message=None):
    """Build a JSON success response with code 0, omitting unset fields."""
    # NOTE(review): local import — presumably to dodge a circular import at
    # module load time; confirm before hoisting to the top of the file.
    from flask import jsonify

    payload = {"code": 0}
    for field, value in (("data", data), ("message", message)):
        if value is not None:
            payload[field] = value
    return jsonify(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def err(code, message):
    """Build a JSON error response; the HTTP status mirrors the error code."""
    from flask import jsonify

    body = {"code": code, "message": message}
    return jsonify(body), code
|
||||||
|
|
||||||
|
|
||||||
|
def to_dict(inst, **extra):
    """Convert a model instance into a JSON-serializable dict.

    - ``created_at``/``updated_at`` datetimes are rendered as
      ``%Y-%m-%dT%H:%M:%SZ`` strings (NOTE(review): the trailing "Z" implies
      UTC — confirm stored timestamps actually are UTC).
    - A stored ``tool_calls`` JSON string is parsed back into a structure.
    - Columns whose value is None are dropped for a cleaner API payload.
    - ``**extra`` key/value pairs are merged in last and override columns.
    """
    d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns}
    for k in ("created_at", "updated_at"):
        if k in d and hasattr(d[k], "strftime"):
            d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")

    # Parse tool_calls JSON if present. Catch only decode/type errors —
    # the previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    # On bad data the raw string is kept rather than failing serialization.
    if "tool_calls" in d and d["tool_calls"]:
        try:
            d["tool_calls"] = json.loads(d["tool_calls"])
        except (json.JSONDecodeError, TypeError):
            pass

    # Filter out None values for cleaner API response
    d = {k: v for k, v in d.items() if v is not None}

    d.update(extra)
    return d
|
||||||
|
|
||||||
|
|
||||||
|
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
    """Accumulate token counts into today's per-user/per-model usage row.

    Increments an existing row for (user_id, today, model) or creates one,
    then commits.
    """
    total = prompt_tokens + completion_tokens
    today = date.today()
    row = TokenUsage.query.filter_by(
        user_id=user_id, date=today, model=model
    ).first()
    if row is None:
        row = TokenUsage(
            user_id=user_id,
            date=today,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total,
        )
        db.session.add(row)
    else:
        row.prompt_tokens += prompt_tokens
        row.completion_tokens += completion_tokens
        row.total_tokens += total
    db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def build_glm_messages(conv):
    """Assemble the GLM chat payload for *conv*: an optional system prompt
    first, then the message history in chronological order."""
    payload = []
    if conv.system_prompt:
        payload.append({"role": "system", "content": conv.system_prompt})
    # Query messages directly to avoid a detached-instance warning on the
    # conversation's relationship attribute.
    history = (
        Message.query.filter_by(conversation_id=conv.id)
        .order_by(Message.created_at.asc())
        .all()
    )
    payload.extend({"role": m.role, "content": m.content} for m in history)
    return payload
|
||||||
|
|
@ -450,6 +450,12 @@ init_tools()
|
||||||
| data | `text_process` | 文本处理 | - |
| data | `json_process` | JSON处理 | - |
| weather | `get_weather` | 天气查询 | - (模拟数据) |
| file | `file_read` | 读取文件 | - |
| file | `file_write` | 写入文件 | - |
| file | `file_delete` | 删除文件 | - |
| file | `file_list` | 列出目录 | - |
| file | `file_exists` | 检查存在 | - |
| file | `file_mkdir` | 创建目录 | - |

---
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue