refactor: optimize backend structure
This commit is contained in:
parent
dc70a4a1f2
commit
362ab15338
|
|
@ -32,9 +32,9 @@ def create_app():
|
|||
db.init_app(app)
|
||||
|
||||
# Import after db is initialized
|
||||
from .models import User, Conversation, Message, TokenUsage
|
||||
from .routes import register_routes
|
||||
from .tools import init_tools
|
||||
from backend.models import User, Conversation, Message, TokenUsage
|
||||
from backend.routes import register_routes
|
||||
from backend.tools import init_tools
|
||||
|
||||
register_routes(app)
|
||||
init_tools()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
"""Configuration management"""
|
||||
from backend import load_config
|
||||
|
||||
# Load the configuration once at import time; the values below are
# module-level constants read by the route and service modules.
_cfg = load_config()

API_URL = _cfg.get("api_url")
# Required key — a missing api_key raises KeyError at import, failing fast.
API_KEY = _cfg["api_key"]
MODELS = _cfg.get("models", [])
DEFAULT_MODEL = _cfg.get("default_model", "glm-5")
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from . import db
|
||||
from backend import db
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
|
|
|
|||
|
|
@ -1,588 +0,0 @@
|
|||
import uuid
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from flask import request, jsonify, Response, Blueprint, current_app
|
||||
from . import db
|
||||
from .models import Conversation, Message, User, TokenUsage
|
||||
from . import load_config
|
||||
from .tools import registry, ToolExecutor
|
||||
|
||||
bp = Blueprint("api", __name__)

# Configuration is loaded once at import time; these module-level constants
# are used by every handler below.
cfg = load_config()
API_URL = cfg.get("api_url")
# Required key — a missing api_key raises KeyError at import, failing fast.
API_KEY = cfg["api_key"]
MODELS = cfg.get("models", [])
DEFAULT_MODEL = cfg.get("default_model", "glm-5")
|
||||
|
||||
|
||||
# -- Helpers ----------------------------------------------
|
||||
|
||||
def get_or_create_default_user():
    """Return the singleton "default" user, creating it on first use."""
    existing = User.query.filter_by(username="default").first()
    if existing is not None:
        return existing
    # First request ever: provision the placeholder account.
    created = User(username="default", password="")
    db.session.add(created)
    db.session.commit()
    return created
|
||||
|
||||
|
||||
def ok(data=None, message=None):
    """Build a success JSON response (code 0) with optional data/message."""
    payload = {"code": 0}
    for key, value in (("data", data), ("message", message)):
        if value is not None:
            payload[key] = value
    return jsonify(payload)
|
||||
|
||||
|
||||
def err(code, message):
    """Build an error JSON response; the HTTP status mirrors the code."""
    body = {"code": code, "message": message}
    return jsonify(body), code
|
||||
|
||||
|
||||
def to_dict(inst, **extra):
    """Serialize a SQLAlchemy model instance into a plain dict.

    - created_at/updated_at are formatted as ISO-8601 "Z" strings.
    - The JSON-encoded tool_calls column is decoded when present.
    - None-valued columns are dropped for a cleaner API response.
    - Any **extra key/value pairs are merged into the result last.
    """
    d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns}
    for k in ("created_at", "updated_at"):
        if k in d and hasattr(d[k], "strftime"):
            d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")

    # Parse tool_calls JSON if present. FIX: the original bare `except:`
    # swallowed every exception (even KeyboardInterrupt); only decode
    # failures should be tolerated, leaving the raw string in place.
    if "tool_calls" in d and d["tool_calls"]:
        try:
            d["tool_calls"] = json.loads(d["tool_calls"])
        except (json.JSONDecodeError, TypeError):
            pass

    # Filter out None values for cleaner API response
    d = {k: v for k, v in d.items() if v is not None}

    d.update(extra)
    return d
|
||||
|
||||
|
||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
    """Accumulate token usage for (user, model) into today's TokenUsage row.

    Creates the row on first use of the day, otherwise increments the
    prompt/completion/total counters in place, then commits.
    """
    from datetime import date
    today = date.today()
    # One row per (user, day, model): look it up to decide insert vs update.
    usage = TokenUsage.query.filter_by(
        user_id=user_id, date=today, model=model
    ).first()
    if usage:
        usage.prompt_tokens += prompt_tokens
        usage.completion_tokens += completion_tokens
        usage.total_tokens += prompt_tokens + completion_tokens
    else:
        usage = TokenUsage(
            user_id=user_id,
            date=today,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        db.session.add(usage)
    db.session.commit()
|
||||
|
||||
|
||||
def build_glm_messages(conv):
    """Assemble the GLM chat payload: optional system prompt + stored history."""
    # Query messages directly (not via the relationship) to avoid a
    # detached-instance warning.
    history = (
        Message.query.filter_by(conversation_id=conv.id)
        .order_by(Message.created_at.asc())
        .all()
    )
    payload = []
    if conv.system_prompt:
        payload.append({"role": "system", "content": conv.system_prompt})
    payload.extend({"role": m.role, "content": m.content} for m in history)
    return payload
|
||||
|
||||
|
||||
# -- Models API -------------------------------------------
|
||||
|
||||
@bp.route("/api/models", methods=["GET"])
def list_models():
    """Get available model list (the MODELS constant from config)."""
    return ok(MODELS)
|
||||
|
||||
|
||||
# -- Tools API --------------------------------------------
|
||||
|
||||
@bp.route("/api/tools", methods=["GET"])
def list_tools():
    """Return the registered tool definitions plus a total count."""
    available = registry.list_all()
    return ok({"tools": available, "total": len(available)})
|
||||
|
||||
|
||||
# -- Token Usage Statistics --------------------------------
|
||||
|
||||
@bp.route("/api/stats/tokens", methods=["GET"])
def token_stats():
    """Get token usage statistics for the default user.

    Query args:
        period: "daily" (today only), "weekly" (last 7 days) or
                "monthly" (last 30 days); defaults to "daily".
                Any other value yields a 400 response.
    """
    from datetime import date, timedelta

    user = get_or_create_default_user()
    period = request.args.get("period", "daily")  # daily, weekly, monthly

    today = date.today()

    if period == "daily":
        # Today's statistics, broken down per model.
        stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
        result = {
            "period": "daily",
            "date": today.isoformat(),
            "prompt_tokens": sum(s.prompt_tokens for s in stats),
            "completion_tokens": sum(s.completion_tokens for s in stats),
            "total_tokens": sum(s.total_tokens for s in stats),
            "by_model": {
                s.model: {
                    "prompt": s.prompt_tokens,
                    "completion": s.completion_tokens,
                    "total": s.total_tokens,
                }
                for s in stats
            },
        }
    elif period in ("weekly", "monthly"):
        # Weekly and monthly differ only in window size; share the query
        # and the per-day aggregation instead of duplicating ~50 lines.
        days = 7 if period == "weekly" else 30
        start_date = today - timedelta(days=days - 1)
        stats = TokenUsage.query.filter(
            TokenUsage.user_id == user.id,
            TokenUsage.date >= start_date,
            TokenUsage.date <= today,
        ).all()
        result = _period_stats(period, stats, start_date, today, days)
    else:
        return err(400, "invalid period")

    return ok(result)


def _period_stats(period, stats, start_date, end_date, days):
    """Aggregate per-day token totals over a window, zero-filling empty days."""
    from datetime import timedelta

    daily_data = {}
    for s in stats:
        d = s.date.isoformat()
        bucket = daily_data.setdefault(d, {"prompt": 0, "completion": 0, "total": 0})
        bucket["prompt"] += s.prompt_tokens
        bucket["completion"] += s.completion_tokens
        bucket["total"] += s.total_tokens

    # Fill missing dates so clients always receive a dense series.
    for i in range(days):
        d = (end_date - timedelta(days=days - 1 - i)).isoformat()
        daily_data.setdefault(d, {"prompt": 0, "completion": 0, "total": 0})

    return {
        "period": period,
        "start_date": start_date.isoformat(),
        "end_date": end_date.isoformat(),
        "prompt_tokens": sum(s.prompt_tokens for s in stats),
        "completion_tokens": sum(s.completion_tokens for s in stats),
        "total_tokens": sum(s.total_tokens for s in stats),
        "daily": daily_data,
    }
|
||||
|
||||
|
||||
# -- Conversation CRUD ------------------------------------
|
||||
|
||||
@bp.route("/api/conversations", methods=["GET", "POST"])
def conversation_list():
    """List (GET, cursor-paginated) or create (POST) conversations."""
    if request.method == "POST":
        d = request.json or {}
        user = get_or_create_default_user()
        conv = Conversation(
            id=str(uuid.uuid4()),
            user_id=user.id,
            title=d.get("title", ""),
            model=d.get("model", DEFAULT_MODEL),
            system_prompt=d.get("system_prompt", ""),
            temperature=d.get("temperature", 1.0),
            max_tokens=d.get("max_tokens", 65536),
            thinking_enabled=d.get("thinking_enabled", False),
        )
        db.session.add(conv)
        db.session.commit()
        return ok(to_dict(conv))

    # GET - cursor pagination ordered by updated_at descending.
    cursor = request.args.get("cursor")
    limit = min(int(request.args.get("limit", 20)), 100)
    user = get_or_create_default_user()
    q = Conversation.query.filter_by(user_id=user.id)
    if cursor:
        # BUG FIX: `datetime.utcnow` was missing its call parentheses, so the
        # function object itself was compared against updated_at when the
        # cursor row no longer exists. Call it so a stale cursor simply
        # degrades to "now" and returns the newest page.
        cursor_ts = db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar()
        q = q.filter(Conversation.updated_at < (cursor_ts or datetime.utcnow()))
    # Fetch one extra row to detect whether another page exists.
    rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()

    items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
    return ok({
        "items": items,
        "next_cursor": items[-1]["id"] if len(rows) > limit else None,
        "has_more": len(rows) > limit,
    })
|
||||
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>", methods=["GET", "PATCH", "DELETE"])
def conversation_detail(conv_id):
    """Fetch, update (PATCH) or delete a single conversation."""
    conv = db.session.get(Conversation, conv_id)
    if conv is None:
        return err(404, "conversation not found")

    method = request.method
    if method == "DELETE":
        db.session.delete(conv)
        db.session.commit()
        return ok(message="deleted")

    if method == "PATCH":
        payload = request.json or {}
        # Only whitelisted fields may be patched.
        editable = ("title", "model", "system_prompt", "temperature",
                    "max_tokens", "thinking_enabled")
        for field in editable:
            if field in payload:
                setattr(conv, field, payload[field])
        db.session.commit()
        return ok(to_dict(conv))

    # GET
    return ok(to_dict(conv))
|
||||
|
||||
|
||||
# -- Messages ---------------------------------------------
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>/messages", methods=["GET", "POST"])
def message_list(conv_id):
    """List messages (GET, cursor-paginated) or post a user message and get
    an assistant reply (POST; streamed SSE when {"stream": true})."""
    conv = db.session.get(Conversation, conv_id)
    if not conv:
        return err(404, "conversation not found")

    if request.method == "GET":
        cursor = request.args.get("cursor")
        limit = min(int(request.args.get("limit", 50)), 100)
        q = Message.query.filter_by(conversation_id=conv_id)
        if cursor:
            # BUG FIX: `datetime.utcnow` was missing its call parentheses, so
            # the function object itself was compared against created_at when
            # the cursor row no longer exists. Call it so a stale cursor
            # degrades to "now".
            cursor_ts = db.session.query(Message.created_at).filter_by(id=cursor).scalar()
            q = q.filter(Message.created_at < (cursor_ts or datetime.utcnow()))
        # Fetch one extra row to detect whether more pages exist.
        rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()

        items = [to_dict(r) for r in rows[:limit]]
        return ok({
            "items": items,
            "next_cursor": items[-1]["id"] if len(rows) > limit else None,
            "has_more": len(rows) > limit,
        })

    # POST - persist the user message, then generate the assistant reply.
    d = request.json or {}
    content = (d.get("content") or "").strip()
    if not content:
        return err(400, "content is required")

    user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content)
    db.session.add(user_msg)
    db.session.commit()

    tools_enabled = d.get("tools_enabled", True)

    if d.get("stream", False):
        return _stream_response(conv, tools_enabled)

    return _sync_response(conv, tools_enabled)
|
||||
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>/messages/<msg_id>", methods=["DELETE"])
def delete_message(conv_id, msg_id):
    """Delete one message after verifying it belongs to the conversation."""
    if db.session.get(Conversation, conv_id) is None:
        return err(404, "conversation not found")
    msg = db.session.get(Message, msg_id)
    if msg is None or msg.conversation_id != conv_id:
        return err(404, "message not found")
    db.session.delete(msg)
    db.session.commit()
    return ok(message="deleted")
|
||||
|
||||
|
||||
# -- Chat Completion ----------------------------------
|
||||
|
||||
def _call_glm(conv, stream=False, tools=None, messages=None):
    """POST a chat-completion request to the GLM API; returns the raw response."""
    if messages is None:
        messages = build_glm_messages(conv)
    body = {
        "model": conv.model,
        "messages": messages,
        "max_tokens": conv.max_tokens,
        "temperature": conv.temperature,
    }
    if conv.thinking_enabled:
        body["thinking"] = {"type": "enabled"}
    if tools:
        body["tools"] = tools
        body["tool_choice"] = "auto"
    if stream:
        body["stream"] = True
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    return requests.post(API_URL, headers=headers, json=body,
                         stream=stream, timeout=120)
|
||||
|
||||
|
||||
def _sync_response(conv, tools_enabled=True):
    """Non-streaming chat completion with an agentic tool-call loop.

    Repeatedly calls the GLM API; whenever the model asks for tool calls,
    executes them and feeds the results back, up to max_iterations rounds.
    On a final (tool-free) answer, persists the assistant Message (with any
    merged tool calls/results), records token usage, and returns the JSON
    body. Returns a 500 response on upstream errors or iteration overflow.
    """
    executor = ToolExecutor(registry=registry)
    tools = registry.list_all() if tools_enabled else None
    messages = build_glm_messages(conv)
    max_iterations = 5  # Max tool call iterations

    # Collect all tool calls and results
    all_tool_calls = []
    all_tool_results = []

    for _ in range(max_iterations):
        try:
            resp = _call_glm(conv, tools=tools, messages=messages)
            resp.raise_for_status()
            result = resp.json()
        except Exception as e:
            # Any network/HTTP/JSON failure surfaces as a 500 to the client.
            return err(500, f"upstream error: {e}")

        choice = result["choices"][0]
        message = choice["message"]

        # If no tool calls, return final result
        if not message.get("tool_calls"):
            usage = result.get("usage", {})
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)

            # Merge tool results into tool_calls
            # (pairs call i with result i; assumes one result per call in order)
            merged_tool_calls = []
            for i, tc in enumerate(all_tool_calls):
                merged_tc = dict(tc)
                if i < len(all_tool_results):
                    merged_tc["result"] = all_tool_results[i]["content"]
                merged_tool_calls.append(merged_tc)

            # Save assistant message with all tool calls (including results)
            msg = Message(
                id=str(uuid.uuid4()),
                conversation_id=conv.id,
                role="assistant",
                content=message.get("content", ""),
                token_count=completion_tokens,
                thinking_content=message.get("reasoning_content", ""),
                tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
            )
            db.session.add(msg)
            db.session.commit()

            user = get_or_create_default_user()
            record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)

            return ok({
                "message": to_dict(msg, thinking_content=msg.thinking_content or None),
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": usage.get("total_tokens", 0)
                },
            })

        # Process tool calls
        tool_calls = message["tool_calls"]
        all_tool_calls.extend(tool_calls)
        messages.append(message)

        # Execute tools and add results
        tool_results = executor.process_tool_calls(tool_calls)
        all_tool_results.extend(tool_results)
        messages.extend(tool_results)

    return err(500, "exceeded maximum tool call iterations")
|
||||
|
||||
|
||||
def _stream_response(conv, tools_enabled=True):
    """Streaming (SSE) chat completion with an agentic tool-call loop.

    Emits server-sent events: ``thinking`` (reasoning deltas), ``message``
    (content deltas), ``tool_calls``, ``tool_result``, then ``done`` with the
    saved message id, or ``error``. Tool-call rounds repeat up to
    max_iterations. Plain values (ids, the app object) are captured before
    the generator starts because it runs outside the request context.
    """
    conv_id = conv.id
    conv_model = conv.model
    # The generator outlives the request; grab the real app for app_context().
    app = current_app._get_current_object()
    executor = ToolExecutor(registry=registry)
    tools = registry.list_all() if tools_enabled else None
    # Build messages BEFORE entering generator (in request context)
    initial_messages = build_glm_messages(conv)

    def generate():
        messages = list(initial_messages)  # Copy to avoid mutation
        max_iterations = 5

        # Collect all tool calls and results
        all_tool_calls = []
        all_tool_results = []
        total_content = ""
        total_thinking = ""
        total_tokens = 0
        total_prompt_tokens = 0

        for iteration in range(max_iterations):
            # Per-round accumulators; reset each time the model is called.
            full_content = ""
            full_thinking = ""
            token_count = 0
            prompt_tokens = 0
            msg_id = str(uuid.uuid4())
            tool_calls_list = []
            current_tool_call = None

            try:
                with app.app_context():
                    active_conv = db.session.get(Conversation, conv_id)
                    resp = _call_glm(active_conv, stream=True, tools=tools, messages=messages)
                    resp.raise_for_status()

                    # Parse the SSE stream line by line ("data: {...}").
                    for line in resp.iter_lines():
                        if not line:
                            continue
                        line = line.decode("utf-8")
                        if not line.startswith("data: "):
                            continue
                        data_str = line[6:]
                        if data_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue

                        delta = chunk["choices"][0].get("delta", {})

                        # Process thinking chain
                        reasoning = delta.get("reasoning_content", "")
                        if reasoning:
                            full_thinking += reasoning
                            yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"

                        # Process text content
                        text = delta.get("content", "")
                        if text:
                            full_content += text
                            yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"

                        # Process tool calls: deltas arrive fragmented by
                        # index; ids/names overwrite, arguments concatenate.
                        tool_calls_delta = delta.get("tool_calls", [])
                        for tc in tool_calls_delta:
                            idx = tc.get("index", 0)
                            if idx >= len(tool_calls_list):
                                tool_calls_list.append({
                                    "id": tc.get("id", ""),
                                    "type": tc.get("type", "function"),
                                    "function": {"name": "", "arguments": ""}
                                })
                            if tc.get("id"):
                                tool_calls_list[idx]["id"] = tc["id"]
                            if tc.get("function"):
                                if tc["function"].get("name"):
                                    tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
                                if tc["function"].get("arguments"):
                                    tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]

                        usage = chunk.get("usage", {})
                        if usage:
                            token_count = usage.get("completion_tokens", 0)
                            prompt_tokens = usage.get("prompt_tokens", 0)

            except Exception as e:
                # Surface the failure as an SSE error event and stop streaming.
                yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
                return

            # If tool calls exist, execute and continue loop
            if tool_calls_list:
                # Collect tool calls
                all_tool_calls.extend(tool_calls_list)

                # Send tool call info
                yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"

                # Execute tools
                tool_results = executor.process_tool_calls(tool_calls_list)
                messages.append({
                    "role": "assistant",
                    "content": full_content or None,
                    "tool_calls": tool_calls_list
                })
                messages.extend(tool_results)

                # Collect tool results
                all_tool_results.extend(tool_results)

                # Send tool results
                for tr in tool_results:
                    yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"

                continue

            # No tool calls, finish - save everything
            total_content = full_content
            total_thinking = full_thinking
            total_tokens = token_count
            total_prompt_tokens = prompt_tokens

            # Merge tool results into tool_calls
            # (pairs call i with result i; assumes one result per call in order)
            merged_tool_calls = []
            for i, tc in enumerate(all_tool_calls):
                merged_tc = dict(tc)
                if i < len(all_tool_results):
                    merged_tc["result"] = all_tool_results[i]["content"]
                merged_tool_calls.append(merged_tc)

            with app.app_context():
                # Save assistant message with all tool calls (including results)
                msg = Message(
                    id=msg_id,
                    conversation_id=conv_id,
                    role="assistant",
                    content=total_content,
                    token_count=total_tokens,
                    thinking_content=total_thinking,
                    tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                )
                db.session.add(msg)
                db.session.commit()

                user = get_or_create_default_user()
                record_token_usage(user.id, conv_model, total_prompt_tokens, total_tokens)

            yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_tokens})}\n\n"
            return

        yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"

    # X-Accel-Buffering: no — stop nginx from buffering the event stream.
    return Response(generate(), mimetype="text/event-stream",
                    headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
def register_routes(app):
    """Attach the single api blueprint to the Flask app."""
    app.register_blueprint(bp)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Route registration"""
|
||||
from flask import Flask
|
||||
from backend.routes.conversations import bp as conversations_bp
|
||||
from backend.routes.messages import bp as messages_bp, init_chat_service
|
||||
from backend.routes.models import bp as models_bp
|
||||
from backend.routes.tools import bp as tools_bp
|
||||
from backend.routes.stats import bp as stats_bp
|
||||
from backend.services.glm_client import GLMClient
|
||||
from backend.config import API_URL, API_KEY
|
||||
|
||||
|
||||
def register_routes(app: Flask):
    """Register all route blueprints and wire up the chat service."""
    # The chat service needs a configured GLM client before any request runs.
    glm_client = GLMClient(API_URL, API_KEY)
    init_chat_service(glm_client)

    # Register every API blueprint.
    for blueprint in (conversations_bp, messages_bp, models_bp, tools_bp, stats_bp):
        app.register_blueprint(blueprint)
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
"""Conversation API routes"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, request
|
||||
from backend import db
|
||||
from backend.models import Conversation
|
||||
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user
|
||||
from backend.config import DEFAULT_MODEL
|
||||
|
||||
bp = Blueprint("conversations", __name__)
|
||||
|
||||
|
||||
@bp.route("/api/conversations", methods=["GET", "POST"])
def conversation_list():
    """List (GET, cursor-paginated) or create (POST) conversations."""
    if request.method == "POST":
        d = request.json or {}
        user = get_or_create_default_user()
        conv = Conversation(
            id=str(uuid.uuid4()),
            user_id=user.id,
            title=d.get("title", ""),
            model=d.get("model", DEFAULT_MODEL),
            system_prompt=d.get("system_prompt", ""),
            temperature=d.get("temperature", 1.0),
            max_tokens=d.get("max_tokens", 65536),
            thinking_enabled=d.get("thinking_enabled", False),
        )
        db.session.add(conv)
        db.session.commit()
        return ok(to_dict(conv))

    # GET - list conversations, newest first, cursor-paginated.
    cursor = request.args.get("cursor")
    limit = min(int(request.args.get("limit", 20)), 100)
    user = get_or_create_default_user()
    q = Conversation.query.filter_by(user_id=user.id)
    if cursor:
        # BUG FIX: `datetime.utcnow` was missing its call parentheses, so the
        # function object itself was compared against updated_at when the
        # cursor row no longer exists. Call it so a stale cursor degrades to
        # "now" and returns the newest page.
        cursor_ts = db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar()
        q = q.filter(Conversation.updated_at < (cursor_ts or datetime.utcnow()))
    # Fetch one extra row to detect whether another page exists.
    rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all()

    items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]]
    return ok({
        "items": items,
        "next_cursor": items[-1]["id"] if len(rows) > limit else None,
        "has_more": len(rows) > limit,
    })
|
||||
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>", methods=["GET", "PATCH", "DELETE"])
def conversation_detail(conv_id):
    """Get, update or delete a conversation"""
    conv = db.session.get(Conversation, conv_id)
    if conv is None:
        return err(404, "conversation not found")

    method = request.method
    if method == "DELETE":
        db.session.delete(conv)
        db.session.commit()
        return ok(message="deleted")

    if method == "PATCH":
        payload = request.json or {}
        # Only whitelisted fields may be patched.
        editable = ("title", "model", "system_prompt", "temperature",
                    "max_tokens", "thinking_enabled")
        for field in editable:
            if field in payload:
                setattr(conv, field, payload[field])
        db.session.commit()
        return ok(to_dict(conv))

    # GET
    return ok(to_dict(conv))
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""Message API routes"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, request
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message
|
||||
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user
|
||||
from backend.services.chat import ChatService
|
||||
|
||||
|
||||
bp = Blueprint("messages", __name__)
|
||||
|
||||
# ChatService will be injected during registration
|
||||
_chat_service = None
|
||||
|
||||
|
||||
def init_chat_service(glm_client):
    """Initialize the module-level ChatService with the given GLM client.

    Called once during route registration; the handlers below read
    _chat_service, so this must run before any message POST is served.
    """
    global _chat_service
    _chat_service = ChatService(glm_client)
|
||||
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>/messages", methods=["GET", "POST"])
def message_list(conv_id):
    """List messages (GET, cursor-paginated) or post a user message and get
    an assistant reply via the chat service (POST; SSE when {"stream": true})."""
    conv = db.session.get(Conversation, conv_id)
    if not conv:
        return err(404, "conversation not found")

    if request.method == "GET":
        cursor = request.args.get("cursor")
        limit = min(int(request.args.get("limit", 50)), 100)
        q = Message.query.filter_by(conversation_id=conv_id)
        if cursor:
            # BUG FIX: `datetime.utcnow` was missing its call parentheses, so
            # the function object itself was compared against created_at when
            # the cursor row no longer exists. Call it so a stale cursor
            # degrades to "now".
            cursor_ts = db.session.query(Message.created_at).filter_by(id=cursor).scalar()
            q = q.filter(Message.created_at < (cursor_ts or datetime.utcnow()))
        # Fetch one extra row to detect whether more pages exist.
        rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all()

        items = [to_dict(r) for r in rows[:limit]]
        return ok({
            "items": items,
            "next_cursor": items[-1]["id"] if len(rows) > limit else None,
            "has_more": len(rows) > limit,
        })

    # POST - create message and get AI response
    d = request.json or {}
    content = (d.get("content") or "").strip()
    if not content:
        return err(400, "content is required")

    user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content)
    db.session.add(user_msg)
    db.session.commit()

    tools_enabled = d.get("tools_enabled", True)

    if d.get("stream", False):
        return _chat_service.stream_response(conv, tools_enabled)

    return _chat_service.sync_response(conv, tools_enabled)
|
||||
|
||||
|
||||
@bp.route("/api/conversations/<conv_id>/messages/<msg_id>", methods=["DELETE"])
def delete_message(conv_id, msg_id):
    """Delete a message"""
    if db.session.get(Conversation, conv_id) is None:
        return err(404, "conversation not found")
    msg = db.session.get(Message, msg_id)
    if msg is None or msg.conversation_id != conv_id:
        return err(404, "message not found")
    db.session.delete(msg)
    db.session.commit()
    return ok(message="deleted")
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Model list API routes"""
|
||||
from flask import Blueprint
|
||||
from backend.utils.helpers import ok
|
||||
from backend.config import MODELS
|
||||
|
||||
bp = Blueprint("models", __name__)
|
||||
|
||||
|
||||
@bp.route("/api/models", methods=["GET"])
def list_models():
    """Get available model list (the MODELS constant from config)."""
    return ok(MODELS)
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
"""Token statistics API routes"""
|
||||
from datetime import date, timedelta
|
||||
from flask import Blueprint, request
|
||||
from sqlalchemy import func
|
||||
from backend.models import TokenUsage
|
||||
from backend.utils.helpers import ok, err, get_or_create_default_user
|
||||
|
||||
bp = Blueprint("stats", __name__)
|
||||
|
||||
|
||||
@bp.route("/api/stats/tokens", methods=["GET"])
def token_stats():
    """Get token usage statistics for the default user.

    Query args:
        period: "daily" (today only), "weekly" (last 7 days) or
                "monthly" (last 30 days); defaults to "daily".
                Any other value yields a 400 response.
    """
    user = get_or_create_default_user()
    period = request.args.get("period", "daily")

    today = date.today()

    if period == "daily":
        # Today's statistics, broken down per model.
        stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
        result = {
            "period": "daily",
            "date": today.isoformat(),
            "prompt_tokens": sum(s.prompt_tokens for s in stats),
            "completion_tokens": sum(s.completion_tokens for s in stats),
            "total_tokens": sum(s.total_tokens for s in stats),
            "by_model": {
                s.model: {
                    "prompt": s.prompt_tokens,
                    "completion": s.completion_tokens,
                    "total": s.total_tokens
                } for s in stats
            }
        }
    elif period in ("weekly", "monthly"):
        # Weekly and monthly differ only in window size; share the query
        # instead of duplicating it per branch.
        days = 7 if period == "weekly" else 30
        start_date = today - timedelta(days=days - 1)
        stats = TokenUsage.query.filter(
            TokenUsage.user_id == user.id,
            TokenUsage.date >= start_date,
            TokenUsage.date <= today
        ).all()

        result = _build_period_result(stats, period, start_date, today, days)
    else:
        return err(400, "invalid period")

    return ok(result)
|
||||
|
||||
|
||||
def _build_period_result(stats, period, start_date, end_date, days):
|
||||
"""Build result for period-based statistics"""
|
||||
daily_data = {}
|
||||
for s in stats:
|
||||
d = s.date.isoformat()
|
||||
if d not in daily_data:
|
||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
||||
daily_data[d]["prompt"] += s.prompt_tokens
|
||||
daily_data[d]["completion"] += s.completion_tokens
|
||||
daily_data[d]["total"] += s.total_tokens
|
||||
|
||||
# Fill missing dates
|
||||
for i in range(days):
|
||||
d = (end_date - timedelta(days=days - 1 - i)).isoformat()
|
||||
if d not in daily_data:
|
||||
daily_data[d] = {"prompt": 0, "completion": 0, "total": 0}
|
||||
|
||||
return {
|
||||
"period": period,
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"prompt_tokens": sum(s.prompt_tokens for s in stats),
|
||||
"completion_tokens": sum(s.completion_tokens for s in stats),
|
||||
"total_tokens": sum(s.total_tokens for s in stats),
|
||||
"daily": daily_data
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
"""Tool list API routes"""
|
||||
from flask import Blueprint
|
||||
from backend.tools import registry
|
||||
from backend.utils.helpers import ok
|
||||
|
||||
bp = Blueprint("tools", __name__)
|
||||
|
||||
|
||||
@bp.route("/api/tools", methods=["GET"])
def list_tools():
    """Get available tool list"""
    available = registry.list_all()
    return ok({"tools": available, "total": len(available)})
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""Backend services"""
|
||||
from backend.services.glm_client import GLMClient
|
||||
from backend.services.chat import ChatService
|
||||
|
||||
__all__ = [
|
||||
"GLMClient",
|
||||
"ChatService",
|
||||
]
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
"""Chat completion service"""
|
||||
import json
|
||||
import uuid
|
||||
from flask import current_app, Response
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message
|
||||
from backend.tools import registry, ToolExecutor
|
||||
from backend.utils.helpers import (
|
||||
get_or_create_default_user,
|
||||
record_token_usage,
|
||||
build_glm_messages,
|
||||
ok,
|
||||
err,
|
||||
to_dict,
|
||||
)
|
||||
from backend.services.glm_client import GLMClient
|
||||
|
||||
|
||||
class ChatService:
    """Chat completion service with tool support.

    Drives the model/tool loop: calls the GLM API, executes any tool calls
    the model returns, feeds the results back, and repeats until the model
    answers without tool calls (or MAX_ITERATIONS is reached). Persists the
    final assistant message and records token usage.
    """

    # Upper bound on model->tool->model round trips per request, to avoid
    # an endless tool-calling loop.
    MAX_ITERATIONS = 5

    def __init__(self, glm_client: GLMClient):
        # glm_client: transport to the upstream GLM API
        self.glm_client = glm_client
        # executor deduplicates/caches tool calls across one request
        self.executor = ToolExecutor(registry=registry)

    def sync_response(self, conv: Conversation, tools_enabled: bool = True):
        """Sync response with tool call support.

        Returns the `ok(...)` JSON response with the saved assistant
        message and token usage, or `err(...)` on upstream failure /
        iteration overflow.
        """
        tools = registry.list_all() if tools_enabled else None
        messages = build_glm_messages(conv)

        # Clear tool call history for new request
        self.executor.clear_history()

        # Accumulated across iterations so the stored message carries the
        # full tool-call transcript.
        all_tool_calls = []
        all_tool_results = []

        for _ in range(self.MAX_ITERATIONS):
            try:
                resp = self.glm_client.call(
                    model=conv.model,
                    messages=messages,
                    max_tokens=conv.max_tokens,
                    temperature=conv.temperature,
                    thinking_enabled=conv.thinking_enabled,
                    tools=tools,
                )
                resp.raise_for_status()
                result = resp.json()
            except Exception as e:
                return err(500, f"upstream error: {e}")

            choice = result["choices"][0]
            message = choice["message"]

            # No tool calls - return final result
            if not message.get("tool_calls"):
                usage = result.get("usage", {})
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)

                # Attach each tool result to its originating call before
                # persisting, so the UI can render call + result together.
                merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)

                msg = Message(
                    id=str(uuid.uuid4()),
                    conversation_id=conv.id,
                    role="assistant",
                    content=message.get("content", ""),
                    token_count=completion_tokens,
                    thinking_content=message.get("reasoning_content", ""),
                    tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                )
                db.session.add(msg)
                db.session.commit()

                user = get_or_create_default_user()
                record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)

                return ok({
                    "message": to_dict(msg, thinking_content=msg.thinking_content or None),
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": usage.get("total_tokens", 0)
                    },
                })

            # Process tool calls
            tool_calls = message["tool_calls"]
            all_tool_calls.extend(tool_calls)
            # The assistant's tool-call message must precede the tool
            # results in the conversation sent back upstream.
            messages.append(message)

            tool_results = self.executor.process_tool_calls(tool_calls)
            all_tool_results.extend(tool_results)
            messages.extend(tool_results)

        return err(500, "exceeded maximum tool call iterations")

    def stream_response(self, conv: Conversation, tools_enabled: bool = True):
        """Stream response with tool call support.

        Returns a Flask SSE Response; the generator emits `thinking`,
        `message`, `tool_calls`, `tool_result`, `done` and `error` events.
        """
        # Capture plain values up front: the generator runs after this
        # request's session is gone, so `conv` would be detached.
        conv_id = conv.id
        conv_model = conv.model
        # Real app object (not the proxy) so the generator can push its own
        # application contexts later.
        app = current_app._get_current_object()
        tools = registry.list_all() if tools_enabled else None
        initial_messages = build_glm_messages(conv)

        # Clear tool call history for new request
        self.executor.clear_history()

        def generate():
            # Copy so iterations can append without mutating the closure's
            # original list.
            messages = list(initial_messages)
            all_tool_calls = []
            all_tool_results = []

            for iteration in range(self.MAX_ITERATIONS):
                full_content = ""
                full_thinking = ""
                token_count = 0
                prompt_tokens = 0
                msg_id = str(uuid.uuid4())
                tool_calls_list = []

                try:
                    with app.app_context():
                        # Re-fetch the conversation inside this context in
                        # case settings changed mid-stream.
                        active_conv = db.session.get(Conversation, conv_id)
                        resp = self.glm_client.call(
                            model=active_conv.model,
                            messages=messages,
                            max_tokens=active_conv.max_tokens,
                            temperature=active_conv.temperature,
                            thinking_enabled=active_conv.thinking_enabled,
                            tools=tools,
                            stream=True,
                        )
                        resp.raise_for_status()

                        # Parse the SSE stream from upstream line by line.
                        for line in resp.iter_lines():
                            if not line:
                                continue
                            line = line.decode("utf-8")
                            if not line.startswith("data: "):
                                continue
                            data_str = line[6:]
                            if data_str == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data_str)
                            except json.JSONDecodeError:
                                # Skip malformed keep-alive/partial chunks
                                continue

                            delta = chunk["choices"][0].get("delta", {})

                            # Process thinking
                            reasoning = delta.get("reasoning_content", "")
                            if reasoning:
                                full_thinking += reasoning
                                yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"

                            # Process text
                            text = delta.get("content", "")
                            if text:
                                full_content += text
                                yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"

                            # Process tool calls
                            tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)

                            # Usage normally arrives on the final chunk.
                            usage = chunk.get("usage", {})
                            if usage:
                                token_count = usage.get("completion_tokens", 0)
                                prompt_tokens = usage.get("prompt_tokens", 0)

                except Exception as e:
                    yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
                    return

                # Tool calls exist - execute and continue
                if tool_calls_list:
                    all_tool_calls.extend(tool_calls_list)
                    yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"

                    tool_results = self.executor.process_tool_calls(tool_calls_list)
                    messages.append({
                        "role": "assistant",
                        "content": full_content or None,
                        "tool_calls": tool_calls_list
                    })
                    messages.extend(tool_results)
                    all_tool_results.extend(tool_results)

                    for tr in tool_results:
                        yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"
                    continue

                # No tool calls - finish
                merged_tool_calls = self._merge_tool_results(all_tool_calls, all_tool_results)

                # Fresh app context: the streaming context above has been
                # exited by the time we persist the final message.
                with app.app_context():
                    msg = Message(
                        id=msg_id,
                        conversation_id=conv_id,
                        role="assistant",
                        content=full_content,
                        token_count=token_count,
                        thinking_content=full_thinking,
                        tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None
                    )
                    db.session.add(msg)
                    db.session.commit()

                    user = get_or_create_default_user()
                    record_token_usage(user.id, conv_model, prompt_tokens, token_count)

                yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
                return

            yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"

        return Response(
            generate(),
            mimetype="text/event-stream",
            # X-Accel-Buffering: no stops nginx from buffering the SSE stream
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
        )

    def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
        """Process tool calls from streaming delta.

        Tool calls arrive fragmented across chunks: the first fragment for
        an index carries id/name, later fragments append argument text.
        Mutates and returns tool_calls_list.
        """
        tool_calls_delta = delta.get("tool_calls", [])
        for tc in tool_calls_delta:
            idx = tc.get("index", 0)
            if idx >= len(tool_calls_list):
                # First fragment for this index: create the slot
                tool_calls_list.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {"name": "", "arguments": ""}
                })
            if tc.get("id"):
                tool_calls_list[idx]["id"] = tc["id"]
            if tc.get("function"):
                if tc["function"].get("name"):
                    tool_calls_list[idx]["function"]["name"] = tc["function"]["name"]
                if tc["function"].get("arguments"):
                    # Arguments stream as JSON text fragments; concatenate
                    tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"]
        return tool_calls_list

    def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list:
        """Merge tool results into tool calls.

        Pairs results with calls positionally; a call without a matching
        result (index beyond tool_results) is kept without a "result" key.
        """
        merged = []
        for i, tc in enumerate(tool_calls):
            merged_tc = dict(tc)
            if i < len(tool_results):
                merged_tc["result"] = tool_results[i]["content"]
            merged.append(merged_tc)
        return merged
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
"""GLM API client"""
|
||||
import requests
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class GLMClient:
    """GLM API client for chat completions"""

    def __init__(self, api_url: str, api_key: str):
        self.api_url = api_url
        self.api_key = api_key

    def call(
        self,
        model: str,
        messages: List[dict],
        max_tokens: int = 65536,
        temperature: float = 1.0,
        thinking_enabled: bool = False,
        tools: Optional[List[dict]] = None,
        stream: bool = False,
        timeout: int = 120,
    ):
        """Call GLM API and return the raw ``requests`` response."""
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if thinking_enabled:
            payload["thinking"] = {"type": "enabled"}
        if tools:
            payload.update({"tools": tools, "tool_choice": "auto"})
        if stream:
            payload["stream"] = True

        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        return requests.post(
            self.api_url,
            headers=request_headers,
            json=payload,
            stream=stream,
            timeout=timeout,
        )
|
||||
|
|
@ -15,9 +15,9 @@ Usage:
|
|||
result = registry.execute("web_search", {"query": "Python"})
|
||||
"""
|
||||
|
||||
from .core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||
from .factory import tool, register_tool
|
||||
from .executor import ToolExecutor
|
||||
from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry
|
||||
from backend.tools.factory import tool, register_tool
|
||||
from backend.tools.executor import ToolExecutor
|
||||
|
||||
|
||||
def init_tools() -> None:
|
||||
|
|
@ -26,7 +26,7 @@ def init_tools() -> None:
|
|||
|
||||
Importing builtin module automatically registers all decorator-defined tools
|
||||
"""
|
||||
from .builtin import crawler, data, weather # noqa: F401
|
||||
from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401
|
||||
|
||||
|
||||
# Public API exports
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
"""Built-in tools"""
|
||||
from .crawler import *
|
||||
from .data import *
|
||||
from backend.tools.builtin.crawler import *
|
||||
from backend.tools.builtin.data import *
|
||||
from backend.tools.builtin.file_ops import *
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Crawler related tools"""
|
||||
from ..factory import tool
|
||||
from ..services import SearchService, FetchService
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.services import SearchService, FetchService
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Data processing related tools"""
|
||||
from ..factory import tool
|
||||
from ..services import CalculatorService
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.services import CalculatorService
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,360 @@
|
|||
"""File operation tools"""
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from backend.tools.factory import tool
|
||||
|
||||
|
||||
# Base directory for file operations (sandbox)
|
||||
# Set to None to allow any path, or set a specific directory for security
|
||||
BASE_DIR = Path(__file__).parent.parent.parent.parent # project root
|
||||
|
||||
|
||||
def _resolve_path(path: str) -> Path:
|
||||
"""Resolve path and ensure it's within allowed directory"""
|
||||
p = Path(path)
|
||||
if not p.is_absolute():
|
||||
p = BASE_DIR / p
|
||||
p = p.resolve()
|
||||
|
||||
# Security check: ensure path is within BASE_DIR
|
||||
if BASE_DIR:
|
||||
try:
|
||||
p.relative_to(BASE_DIR.resolve())
|
||||
except ValueError:
|
||||
raise ValueError(f"Path '{path}' is outside allowed directory")
|
||||
|
||||
return p
|
||||
|
||||
|
||||
@tool(
    name="file_read",
    description="Read content from a file. Use when you need to read file content.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to read (relative to project root or absolute)"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_read(arguments: dict) -> dict:
    """
    Read file tool

    Args:
        arguments: {
            "path": "file.txt",
            "encoding": "utf-8"
        }

    Returns:
        {"content": "...", "size": 100}
        or {"error": "..."} on any failure (missing file, sandbox
        violation, decode error) — errors are returned, never raised.
    """
    try:
        # _resolve_path raises ValueError for paths outside the sandbox;
        # the broad except below converts that into an error payload.
        path = _resolve_path(arguments["path"])
        encoding = arguments.get("encoding", "utf-8")

        if not path.exists():
            return {"error": f"File not found: {path}"}

        if not path.is_file():
            return {"error": f"Path is not a file: {path}"}

        content = path.read_text(encoding=encoding)
        return {
            "content": content,
            # size is in characters (decoded text), not bytes on disk
            "size": len(content),
            "path": str(path)
        }
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
@tool(
    name="file_write",
    description="Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Use when you need to create or update a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to write (relative to project root or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            },
            "mode": {
                "type": "string",
                "description": "Write mode: 'write' (overwrite) or 'append'",
                "enum": ["write", "append"],
                "default": "write"
            }
        },
        "required": ["path", "content"]
    },
    category="file"
)
def file_write(arguments: dict) -> dict:
    """
    Write file tool

    Args:
        arguments: {
            "path": "file.txt",
            "content": "Hello World",
            "encoding": "utf-8",
            "mode": "write"
        }

    Returns:
        {"success": true, "size": 11}
        or {"error": "..."} on any failure — errors are returned, never
        raised.
    """
    try:
        path = _resolve_path(arguments["path"])
        content = arguments["content"]
        encoding = arguments.get("encoding", "utf-8")
        # Any value other than "append" falls through to overwrite mode
        mode = arguments.get("mode", "write")

        # Create parent directories if needed
        path.parent.mkdir(parents=True, exist_ok=True)

        # Write or append
        if mode == "append":
            with open(path, "a", encoding=encoding) as f:
                f.write(content)
        else:
            path.write_text(content, encoding=encoding)

        return {
            "success": True,
            # size of the content written in this call (characters),
            # not the resulting file size when appending
            "size": len(content),
            "path": str(path),
            "mode": mode
        }
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
@tool(
    name="file_delete",
    description="Delete a file. Use when you need to remove a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to delete (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_delete(arguments: dict) -> dict:
    """
    Delete file tool

    Only deletes regular files; directories are rejected with an error
    (use a dedicated directory removal if ever needed).

    Args:
        arguments: {
            "path": "file.txt"
        }

    Returns:
        {"success": true}
        or {"error": "..."} on any failure — errors are returned, never
        raised.
    """
    try:
        path = _resolve_path(arguments["path"])

        if not path.exists():
            return {"error": f"File not found: {path}"}

        if not path.is_file():
            return {"error": f"Path is not a file: {path}"}

        path.unlink()
        return {"success": True, "path": str(path)}
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
@tool(
    name="file_list",
    description="List files and directories in a directory. Use when you need to see what files exist.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to list (relative to project root or absolute)",
                "default": "."
            },
            "pattern": {
                "type": "string",
                "description": "Glob pattern to filter files, e.g. '*.py'",
                "default": "*"
            }
        },
        "required": []
    },
    category="file"
)
def file_list(arguments: dict) -> dict:
    """
    List directory contents

    Args:
        arguments: {
            "path": ".",
            "pattern": "*"
        }

    Returns:
        {"files": [...], "directories": [...]}
        Each entry carries a path relative to BASE_DIR when sandboxed.
        Returns {"error": "..."} on any failure — errors are returned,
        never raised.
    """
    try:
        path = _resolve_path(arguments.get("path", "."))
        pattern = arguments.get("pattern", "*")

        if not path.exists():
            return {"error": f"Directory not found: {path}"}

        if not path.is_dir():
            return {"error": f"Path is not a directory: {path}"}

        files = []
        directories = []

        # glob is non-recursive unless the caller passes a '**' pattern
        for item in path.glob(pattern):
            if item.is_file():
                files.append({
                    "name": item.name,
                    "size": item.stat().st_size,
                    "path": str(item.relative_to(BASE_DIR)) if BASE_DIR else str(item)
                })
            elif item.is_dir():
                directories.append({
                    "name": item.name,
                    "path": str(item.relative_to(BASE_DIR)) if BASE_DIR else str(item)
                })

        return {
            "path": str(path),
            "files": files,
            "directories": directories,
            "total_files": len(files),
            "total_dirs": len(directories)
        }
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
@tool(
    name="file_exists",
    description="Check if a file or directory exists. Use when you need to verify file existence.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to check (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_exists(arguments: dict) -> dict:
    """
    Check if file/directory exists

    Args:
        arguments: {
            "path": "file.txt"
        }

    Returns:
        {"exists": true, "type": "file"}
        "type" is "file", "directory", or "other" (e.g. special files).
        Returns {"error": "..."} on any failure — errors are returned,
        never raised.
    """
    try:
        path = _resolve_path(arguments["path"])

        if not path.exists():
            return {"exists": False, "path": str(path)}

        if path.is_file():
            return {
                "exists": True,
                "type": "file",
                "path": str(path),
                "size": path.stat().st_size
            }
        elif path.is_dir():
            return {
                "exists": True,
                "type": "directory",
                "path": str(path)
            }
        else:
            # Exists but is neither regular file nor directory
            # (socket, device, broken symlink target, etc.)
            return {
                "exists": True,
                "type": "other",
                "path": str(path)
            }
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
@tool(
    name="file_mkdir",
    description="Create a directory. Creates parent directories if needed. Use when you need to create a folder.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to create (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_mkdir(arguments: dict) -> dict:
    """
    Create directory

    Args:
        arguments: {
            "path": "new/folder"
        }

    Returns:
        {"success": true, "created": bool}
        "created" is True only when the directory did not exist before
        this call. Returns {"error": "..."} on any failure — errors are
        returned, never raised.
    """
    try:
        path = _resolve_path(arguments["path"])
        # BUG FIX: the old code computed "created" AFTER mkdir
        # ("not path.exists() or path.is_dir()"), which is always True.
        # Capture existence before creating so "created" is meaningful.
        already_existed = path.exists()
        path.mkdir(parents=True, exist_ok=True)
        return {
            "success": True,
            "path": str(path),
            "created": not already_existed
        }
    except Exception as e:
        return {"error": str(e)}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""Weather related tools"""
|
||||
from ..factory import tool
|
||||
from backend.tools.factory import tool
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
|
|||
|
|
@ -1,22 +1,62 @@
|
|||
"""Tool executor"""
|
||||
"""Tool executor with caching and deduplication"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Optional, Generator, Any
|
||||
from .core import ToolRegistry, registry
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional, Any
|
||||
from backend.tools.core import ToolRegistry, registry
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Tool call executor"""
|
||||
"""Tool call executor with caching and deduplication"""
|
||||
|
||||
def __init__(
    self,
    registry: Optional[ToolRegistry] = None,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    enable_cache: bool = True,
    cache_ttl: int = 300,  # 5 minutes
):
    # registry: tool lookup/dispatch table; a fresh empty registry is
    # created when none is supplied
    self.registry = registry or ToolRegistry()
    # api_url/api_key kept for tools that call external services;
    # not read anywhere in this class itself — TODO confirm callers
    self.api_url = api_url
    self.api_key = api_key
    # Result caching: identical (tool, args) calls within cache_ttl
    # seconds return the cached result instead of re-executing
    self.enable_cache = enable_cache
    self.cache_ttl = cache_ttl
    self._cache: Dict[str, tuple] = {}  # key -> (result, timestamp)
    self._call_history: List[dict] = []  # Track calls in current session
|
||||
|
||||
def _make_cache_key(self, name: str, args: dict) -> str:
    """Generate cache key from tool name and arguments"""
    # sort_keys makes the key independent of argument ordering
    args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
    # md5 is used as a fast fingerprint, not for security
    return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()
|
||||
|
||||
def _get_cached(self, key: str) -> Optional[dict]:
    """Get cached result if valid; returns None on miss or expiry."""
    if not self.enable_cache:
        return None
    if key in self._cache:
        result, timestamp = self._cache[key]
        if time.time() - timestamp < self.cache_ttl:
            return result
        # Entry expired - evict it before reporting a miss
        del self._cache[key]
    return None
|
||||
|
||||
def _set_cache(self, key: str, result: dict) -> None:
    """Cache a result (no-op when caching is disabled)."""
    if self.enable_cache:
        # Store the insertion time so _get_cached can expire the entry
        self._cache[key] = (result, time.time())
|
||||
|
||||
def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]:
    """Check if same tool+args was called before in this session.

    Returns the earlier result when found, else None. Unlike the TTL
    cache, history matches are not time-limited — they last until
    clear_history() is called.
    """
    args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
    for record in self._call_history:
        if record["name"] == name and record["args_str"] == args_str:
            return record["result"]
    return None
|
||||
|
||||
def clear_history(self) -> None:
    """Clear call history (call this at start of new conversation turn).

    Note: only the per-session history is cleared; the TTL result cache
    is left intact.
    """
    self._call_history.clear()
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
|
|
@ -34,6 +74,7 @@ class ToolExecutor:
|
|||
Tool response message list, can be appended to messages
|
||||
"""
|
||||
results = []
|
||||
seen_calls = set() # Track calls within this batch
|
||||
|
||||
for call in tool_calls:
|
||||
name = call["function"]["name"]
|
||||
|
|
@ -48,7 +89,45 @@ class ToolExecutor:
|
|||
))
|
||||
continue
|
||||
|
||||
# Check for duplicate within same batch
|
||||
call_key = f"{name}:{json.dumps(args, sort_keys=True)}"
|
||||
if call_key in seen_calls:
|
||||
# Skip duplicate, but still return a result
|
||||
results.append(self._create_tool_result(
|
||||
call_id, name,
|
||||
{"success": True, "data": None, "cached": True, "duplicate": True}
|
||||
))
|
||||
continue
|
||||
seen_calls.add(call_key)
|
||||
|
||||
# Check history for previous call in this session
|
||||
history_result = self._check_duplicate_in_history(name, args)
|
||||
if history_result is not None:
|
||||
result = {**history_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# Check cache
|
||||
cache_key = self._make_cache_key(name, args)
|
||||
cached_result = self._get_cached(cache_key)
|
||||
if cached_result is not None:
|
||||
result = {**cached_result, "cached": True}
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
continue
|
||||
|
||||
# Execute tool
|
||||
result = self.registry.execute(name, args)
|
||||
|
||||
# Cache the result
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
# Add to history
|
||||
self._call_history.append({
|
||||
"name": name,
|
||||
"args_str": json.dumps(args, sort_keys=True, ensure_ascii=False),
|
||||
"result": result
|
||||
})
|
||||
|
||||
results.append(self._create_tool_result(call_id, name, result))
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tool factory - decorator registration"""
|
||||
from typing import Callable
|
||||
from .core import ToolDefinition, registry
|
||||
from backend.tools.core import ToolDefinition, registry
|
||||
|
||||
|
||||
def tool(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
"""Backend utilities"""
|
||||
from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_glm_messages
|
||||
|
||||
__all__ = [
|
||||
"ok",
|
||||
"err",
|
||||
"to_dict",
|
||||
"get_or_create_default_user",
|
||||
"record_token_usage",
|
||||
"build_glm_messages",
|
||||
]
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
"""Common helper functions"""
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from backend import db
|
||||
from backend.models import Conversation, Message, User, TokenUsage
|
||||
|
||||
|
||||
def get_or_create_default_user():
    """Get or create default user"""
    existing = User.query.filter_by(username="default").first()
    if existing is not None:
        return existing
    # First run: create the single-user account lazily
    created = User(username="default", password="")
    db.session.add(created)
    db.session.commit()
    return created
||||
|
||||
|
||||
def ok(data=None, message=None):
    """Success response helper: {"code": 0} plus optional data/message."""
    # Imported locally to keep this helper importable without an app
    from flask import jsonify

    payload = {"code": 0}
    if data is not None:
        payload["data"] = data
    if message is not None:
        payload["message"] = message
    return jsonify(payload)
|
||||
|
||||
|
||||
def err(code, message):
    """Error response helper: JSON body plus matching HTTP status code."""
    from flask import jsonify

    body = {"code": code, "message": message}
    return jsonify(body), code
|
||||
|
||||
|
||||
def to_dict(inst, **extra):
    """Convert model instance to dict.

    Serializes every mapped column, formats created_at/updated_at as
    ISO-8601 UTC strings, decodes the tool_calls JSON column, drops
    None values for a cleaner API payload, then merges **extra on top.
    """
    d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns}
    for k in ("created_at", "updated_at"):
        if k in d and hasattr(d[k], "strftime"):
            d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")

    # Parse tool_calls JSON if present.
    # BUG FIX: the bare `except:` also swallowed SystemExit /
    # KeyboardInterrupt and hid real bugs; only catch what json.loads
    # can raise (json.JSONDecodeError is a ValueError subclass).
    if "tool_calls" in d and d["tool_calls"]:
        try:
            d["tool_calls"] = json.loads(d["tool_calls"])
        except (ValueError, TypeError):
            # Malformed/legacy data: keep the raw string rather than crash
            pass

    # Filter out None values for cleaner API response
    d = {k: v for k, v in d.items() if v is not None}

    # extra is applied last so callers can override serialized fields
    d.update(extra)
    return d
|
||||
|
||||
|
||||
def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
    """Record token usage.

    Accumulates prompt/completion counts into the per-(user, day, model)
    TokenUsage row, creating the row on first use that day, and commits.
    """
    # Server-local date; the stats API buckets by this value
    today = date.today()
    usage = TokenUsage.query.filter_by(
        user_id=user_id, date=today, model=model
    ).first()
    if usage:
        # Existing row for today: accumulate in place
        usage.prompt_tokens += prompt_tokens
        usage.completion_tokens += completion_tokens
        usage.total_tokens += prompt_tokens + completion_tokens
    else:
        usage = TokenUsage(
            user_id=user_id,
            date=today,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        db.session.add(usage)
    db.session.commit()
|
||||
|
||||
|
||||
def build_glm_messages(conv):
    """Build messages list for GLM API from conversation"""
    history = []
    if conv.system_prompt:
        history.append({"role": "system", "content": conv.system_prompt})
    # Query messages directly to avoid detached instance warning
    stored = (
        Message.query.filter_by(conversation_id=conv.id)
        .order_by(Message.created_at.asc())
        .all()
    )
    history.extend({"role": m.role, "content": m.content} for m in stored)
    return history
|
||||
|
|
@ -450,6 +450,12 @@ init_tools()
|
|||
| data | `text_process` | 文本处理 | - |
|
||||
| data | `json_process` | JSON处理 | - |
|
||||
| weather | `get_weather` | 天气查询 | - (模拟数据) |
|
||||
| file | `file_read` | 读取文件 | - |
|
||||
| file | `file_write` | 写入文件 | - |
|
||||
| file | `file_delete` | 删除文件 | - |
|
||||
| file | `file_list` | 列出目录 | - |
|
||||
| file | `file_exists` | 检查存在 | - |
|
||||
| file | `file_mkdir` | 创建目录 | - |
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue