From 362ab15338c8b8fc3bef2cec2c73bc11f0d4b346 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 24 Mar 2026 23:41:27 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/__init__.py | 6 +- backend/config.py | 9 + backend/models.py | 2 +- backend/routes.py | 588 ------------------------------ backend/routes/__init__.py | 23 ++ backend/routes/conversations.py | 72 ++++ backend/routes/messages.py | 75 ++++ backend/routes/models.py | 12 + backend/routes/stats.py | 84 +++++ backend/routes/tools.py | 16 + backend/services/__init__.py | 8 + backend/services/chat.py | 254 +++++++++++++ backend/services/glm_client.py | 48 +++ backend/tools/__init__.py | 8 +- backend/tools/builtin/__init__.py | 5 +- backend/tools/builtin/crawler.py | 4 +- backend/tools/builtin/data.py | 4 +- backend/tools/builtin/file_ops.py | 360 ++++++++++++++++++ backend/tools/builtin/weather.py | 2 +- backend/tools/executor.py | 89 ++++- backend/tools/factory.py | 2 +- backend/utils/__init__.py | 11 + backend/utils/helpers.py | 88 +++++ docs/ToolSystemDesign.md | 6 + 24 files changed, 1167 insertions(+), 609 deletions(-) create mode 100644 backend/config.py delete mode 100644 backend/routes.py create mode 100644 backend/routes/__init__.py create mode 100644 backend/routes/conversations.py create mode 100644 backend/routes/messages.py create mode 100644 backend/routes/models.py create mode 100644 backend/routes/stats.py create mode 100644 backend/routes/tools.py create mode 100644 backend/services/__init__.py create mode 100644 backend/services/chat.py create mode 100644 backend/services/glm_client.py create mode 100644 backend/tools/builtin/file_ops.py create mode 100644 backend/utils/__init__.py create mode 100644 backend/utils/helpers.py diff --git a/backend/__init__.py b/backend/__init__.py index f603d89..487ef59 
100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -32,9 +32,9 @@ def create_app(): db.init_app(app) # Import after db is initialized - from .models import User, Conversation, Message, TokenUsage - from .routes import register_routes - from .tools import init_tools + from backend.models import User, Conversation, Message, TokenUsage + from backend.routes import register_routes + from backend.tools import init_tools register_routes(app) init_tools() diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 0000000..2603ca5 --- /dev/null +++ b/backend/config.py @@ -0,0 +1,9 @@ +"""Configuration management""" +from backend import load_config + +_cfg = load_config() + +API_URL = _cfg.get("api_url") +API_KEY = _cfg["api_key"] +MODELS = _cfg.get("models", []) +DEFAULT_MODEL = _cfg.get("default_model", "glm-5") diff --git a/backend/models.py b/backend/models.py index 3186f40..3712993 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from sqlalchemy.dialects.mysql import LONGTEXT -from . import db +from backend import db class User(db.Model): diff --git a/backend/routes.py b/backend/routes.py deleted file mode 100644 index 08f32fd..0000000 --- a/backend/routes.py +++ /dev/null @@ -1,588 +0,0 @@ -import uuid -import json -import os -import requests -from datetime import datetime -from flask import request, jsonify, Response, Blueprint, current_app -from . import db -from .models import Conversation, Message, User, TokenUsage -from . 
import load_config -from .tools import registry, ToolExecutor - -bp = Blueprint("api", __name__) - -cfg = load_config() -API_URL = cfg.get("api_url") -API_KEY = cfg["api_key"] -MODELS = cfg.get("models", []) -DEFAULT_MODEL = cfg.get("default_model", "glm-5") - - -# -- Helpers ---------------------------------------------- - -def get_or_create_default_user(): - user = User.query.filter_by(username="default").first() - if not user: - user = User(username="default", password="") - db.session.add(user) - db.session.commit() - return user - - -def ok(data=None, message=None): - body = {"code": 0} - if data is not None: - body["data"] = data - if message is not None: - body["message"] = message - return jsonify(body) - - -def err(code, message): - return jsonify({"code": code, "message": message}), code - - -def to_dict(inst, **extra): - d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns} - for k in ("created_at", "updated_at"): - if k in d and hasattr(d[k], "strftime"): - d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ") - - # Parse tool_calls JSON if present - if "tool_calls" in d and d["tool_calls"]: - try: - d["tool_calls"] = json.loads(d["tool_calls"]) - except: - pass - - # Filter out None values for cleaner API response - d = {k: v for k, v in d.items() if v is not None} - - d.update(extra) - return d - - -def record_token_usage(user_id, model, prompt_tokens, completion_tokens): - """Record token usage""" - from datetime import date - today = date.today() - usage = TokenUsage.query.filter_by( - user_id=user_id, date=today, model=model - ).first() - if usage: - usage.prompt_tokens += prompt_tokens - usage.completion_tokens += completion_tokens - usage.total_tokens += prompt_tokens + completion_tokens - else: - usage = TokenUsage( - user_id=user_id, - date=today, - model=model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - db.session.add(usage) - db.session.commit() - - -def 
build_glm_messages(conv): - """Build messages list for GLM API from conversation""" - msgs = [] - if conv.system_prompt: - msgs.append({"role": "system", "content": conv.system_prompt}) - # Query messages directly to avoid detached instance warning - messages = Message.query.filter_by(conversation_id=conv.id).order_by(Message.created_at.asc()).all() - for m in messages: - msgs.append({"role": m.role, "content": m.content}) - return msgs - - -# -- Models API ------------------------------------------- - -@bp.route("/api/models", methods=["GET"]) -def list_models(): - """Get available model list""" - return ok(MODELS) - - -# -- Tools API -------------------------------------------- - -@bp.route("/api/tools", methods=["GET"]) -def list_tools(): - """Get available tool list""" - tools = registry.list_all() - return ok({ - "tools": tools, - "total": len(tools) - }) - - -# -- Token Usage Statistics -------------------------------- - -@bp.route("/api/stats/tokens", methods=["GET"]) -def token_stats(): - """Get token usage statistics""" - from sqlalchemy import func - from datetime import date, timedelta - - user = get_or_create_default_user() - period = request.args.get("period", "daily") # daily, weekly, monthly - - today = date.today() - - if period == "daily": - # Today's statistics - stats = TokenUsage.query.filter_by(user_id=user.id, date=today).all() - result = { - "period": "daily", - "date": today.isoformat(), - "prompt_tokens": sum(s.prompt_tokens for s in stats), - "completion_tokens": sum(s.completion_tokens for s in stats), - "total_tokens": sum(s.total_tokens for s in stats), - "by_model": {s.model: {"prompt": s.prompt_tokens, "completion": s.completion_tokens, "total": s.total_tokens} for s in stats} - } - elif period == "weekly": - # Weekly statistics (last 7 days) - start_date = today - timedelta(days=6) - stats = TokenUsage.query.filter( - TokenUsage.user_id == user.id, - TokenUsage.date >= start_date, - TokenUsage.date <= today - ).all() - - daily_data = 
{} - for s in stats: - d = s.date.isoformat() - if d not in daily_data: - daily_data[d] = {"prompt": 0, "completion": 0, "total": 0} - daily_data[d]["prompt"] += s.prompt_tokens - daily_data[d]["completion"] += s.completion_tokens - daily_data[d]["total"] += s.total_tokens - - # Fill missing dates - for i in range(7): - d = (today - timedelta(days=6-i)).isoformat() - if d not in daily_data: - daily_data[d] = {"prompt": 0, "completion": 0, "total": 0} - - result = { - "period": "weekly", - "start_date": start_date.isoformat(), - "end_date": today.isoformat(), - "prompt_tokens": sum(s.prompt_tokens for s in stats), - "completion_tokens": sum(s.completion_tokens for s in stats), - "total_tokens": sum(s.total_tokens for s in stats), - "daily": daily_data - } - elif period == "monthly": - # Monthly statistics (last 30 days) - start_date = today - timedelta(days=29) - stats = TokenUsage.query.filter( - TokenUsage.user_id == user.id, - TokenUsage.date >= start_date, - TokenUsage.date <= today - ).all() - - daily_data = {} - for s in stats: - d = s.date.isoformat() - if d not in daily_data: - daily_data[d] = {"prompt": 0, "completion": 0, "total": 0} - daily_data[d]["prompt"] += s.prompt_tokens - daily_data[d]["completion"] += s.completion_tokens - daily_data[d]["total"] += s.total_tokens - - # Fill missing dates - for i in range(30): - d = (today - timedelta(days=29-i)).isoformat() - if d not in daily_data: - daily_data[d] = {"prompt": 0, "completion": 0, "total": 0} - - result = { - "period": "monthly", - "start_date": start_date.isoformat(), - "end_date": today.isoformat(), - "prompt_tokens": sum(s.prompt_tokens for s in stats), - "completion_tokens": sum(s.completion_tokens for s in stats), - "total_tokens": sum(s.total_tokens for s in stats), - "daily": daily_data - } - else: - return err(400, "invalid period") - - return ok(result) - - -# -- Conversation CRUD ------------------------------------ - -@bp.route("/api/conversations", methods=["GET", "POST"]) -def 
conversation_list(): - if request.method == "POST": - d = request.json or {} - user = get_or_create_default_user() - conv = Conversation( - id=str(uuid.uuid4()), - user_id=user.id, - title=d.get("title", ""), - model=d.get("model", DEFAULT_MODEL), - system_prompt=d.get("system_prompt", ""), - temperature=d.get("temperature", 1.0), - max_tokens=d.get("max_tokens", 65536), - thinking_enabled=d.get("thinking_enabled", False), - ) - db.session.add(conv) - db.session.commit() - return ok(to_dict(conv)) - - cursor = request.args.get("cursor") - limit = min(int(request.args.get("limit", 20)), 100) - user = get_or_create_default_user() - q = Conversation.query.filter_by(user_id=user.id) - if cursor: - q = q.filter(Conversation.updated_at < ( - db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow)) - rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all() - - items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]] - return ok({ - "items": items, - "next_cursor": items[-1]["id"] if len(rows) > limit else None, - "has_more": len(rows) > limit, - }) - - -@bp.route("/api/conversations/", methods=["GET", "PATCH", "DELETE"]) -def conversation_detail(conv_id): - conv = db.session.get(Conversation, conv_id) - if not conv: - return err(404, "conversation not found") - - if request.method == "GET": - return ok(to_dict(conv)) - - if request.method == "DELETE": - db.session.delete(conv) - db.session.commit() - return ok(message="deleted") - - d = request.json or {} - for k in ("title", "model", "system_prompt", "temperature", "max_tokens", "thinking_enabled"): - if k in d: - setattr(conv, k, d[k]) - db.session.commit() - return ok(to_dict(conv)) - - -# -- Messages --------------------------------------------- - -@bp.route("/api/conversations//messages", methods=["GET", "POST"]) -def message_list(conv_id): - conv = db.session.get(Conversation, conv_id) - if not conv: - return err(404, "conversation not found") - 
- if request.method == "GET": - cursor = request.args.get("cursor") - limit = min(int(request.args.get("limit", 50)), 100) - q = Message.query.filter_by(conversation_id=conv_id) - if cursor: - q = q.filter(Message.created_at < ( - db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow)) - rows = q.order_by(Message.created_at.asc()).limit(limit + 1).all() - - items = [to_dict(r) for r in rows[:limit]] - return ok({ - "items": items, - "next_cursor": items[-1]["id"] if len(rows) > limit else None, - "has_more": len(rows) > limit, - }) - - d = request.json or {} - content = (d.get("content") or "").strip() - if not content: - return err(400, "content is required") - - user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content) - db.session.add(user_msg) - db.session.commit() - - tools_enabled = d.get("tools_enabled", True) - - if d.get("stream", False): - return _stream_response(conv, tools_enabled) - - return _sync_response(conv, tools_enabled) - - -@bp.route("/api/conversations//messages/", methods=["DELETE"]) -def delete_message(conv_id, msg_id): - conv = db.session.get(Conversation, conv_id) - if not conv: - return err(404, "conversation not found") - msg = db.session.get(Message, msg_id) - if not msg or msg.conversation_id != conv_id: - return err(404, "message not found") - db.session.delete(msg) - db.session.commit() - return ok(message="deleted") - - -# -- Chat Completion ---------------------------------- - -def _call_glm(conv, stream=False, tools=None, messages=None): - """Call GLM API""" - body = { - "model": conv.model, - "messages": messages if messages is not None else build_glm_messages(conv), - "max_tokens": conv.max_tokens, - "temperature": conv.temperature, - } - if conv.thinking_enabled: - body["thinking"] = {"type": "enabled"} - if tools: - body["tools"] = tools - body["tool_choice"] = "auto" - if stream: - body["stream"] = True - return requests.post( - API_URL, - 
headers={"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}, - json=body, stream=stream, timeout=120, - ) - - -def _sync_response(conv, tools_enabled=True): - """Sync response with tool call support""" - executor = ToolExecutor(registry=registry) - tools = registry.list_all() if tools_enabled else None - messages = build_glm_messages(conv) - max_iterations = 5 # Max tool call iterations - - # Collect all tool calls and results - all_tool_calls = [] - all_tool_results = [] - - for _ in range(max_iterations): - try: - resp = _call_glm(conv, tools=tools, messages=messages) - resp.raise_for_status() - result = resp.json() - except Exception as e: - return err(500, f"upstream error: {e}") - - choice = result["choices"][0] - message = choice["message"] - - # If no tool calls, return final result - if not message.get("tool_calls"): - usage = result.get("usage", {}) - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - - # Merge tool results into tool_calls - merged_tool_calls = [] - for i, tc in enumerate(all_tool_calls): - merged_tc = dict(tc) - if i < len(all_tool_results): - merged_tc["result"] = all_tool_results[i]["content"] - merged_tool_calls.append(merged_tc) - - # Save assistant message with all tool calls (including results) - msg = Message( - id=str(uuid.uuid4()), - conversation_id=conv.id, - role="assistant", - content=message.get("content", ""), - token_count=completion_tokens, - thinking_content=message.get("reasoning_content", ""), - tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None - ) - db.session.add(msg) - db.session.commit() - - user = get_or_create_default_user() - record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens) - - return ok({ - "message": to_dict(msg, thinking_content=msg.thinking_content or None), - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": usage.get("total_tokens", 
0) - }, - }) - - # Process tool calls - tool_calls = message["tool_calls"] - all_tool_calls.extend(tool_calls) - messages.append(message) - - # Execute tools and add results - tool_results = executor.process_tool_calls(tool_calls) - all_tool_results.extend(tool_results) - messages.extend(tool_results) - - return err(500, "exceeded maximum tool call iterations") - - -def _stream_response(conv, tools_enabled=True): - """Stream response with tool call support""" - conv_id = conv.id - conv_model = conv.model - app = current_app._get_current_object() - executor = ToolExecutor(registry=registry) - tools = registry.list_all() if tools_enabled else None - # Build messages BEFORE entering generator (in request context) - initial_messages = build_glm_messages(conv) - - def generate(): - messages = list(initial_messages) # Copy to avoid mutation - max_iterations = 5 - - # Collect all tool calls and results - all_tool_calls = [] - all_tool_results = [] - total_content = "" - total_thinking = "" - total_tokens = 0 - total_prompt_tokens = 0 - - for iteration in range(max_iterations): - full_content = "" - full_thinking = "" - token_count = 0 - prompt_tokens = 0 - msg_id = str(uuid.uuid4()) - tool_calls_list = [] - current_tool_call = None - - try: - with app.app_context(): - active_conv = db.session.get(Conversation, conv_id) - resp = _call_glm(active_conv, stream=True, tools=tools, messages=messages) - resp.raise_for_status() - - for line in resp.iter_lines(): - if not line: - continue - line = line.decode("utf-8") - if not line.startswith("data: "): - continue - data_str = line[6:] - if data_str == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - - delta = chunk["choices"][0].get("delta", {}) - - # Process thinking chain - reasoning = delta.get("reasoning_content", "") - if reasoning: - full_thinking += reasoning - yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n" - - # Process text 
content - text = delta.get("content", "") - if text: - full_content += text - yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n" - - # Process tool calls - tool_calls_delta = delta.get("tool_calls", []) - for tc in tool_calls_delta: - idx = tc.get("index", 0) - if idx >= len(tool_calls_list): - tool_calls_list.append({ - "id": tc.get("id", ""), - "type": tc.get("type", "function"), - "function": {"name": "", "arguments": ""} - }) - if tc.get("id"): - tool_calls_list[idx]["id"] = tc["id"] - if tc.get("function"): - if tc["function"].get("name"): - tool_calls_list[idx]["function"]["name"] = tc["function"]["name"] - if tc["function"].get("arguments"): - tool_calls_list[idx]["function"]["arguments"] += tc["function"]["arguments"] - - usage = chunk.get("usage", {}) - if usage: - token_count = usage.get("completion_tokens", 0) - prompt_tokens = usage.get("prompt_tokens", 0) - - except Exception as e: - yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n" - return - - # If tool calls exist, execute and continue loop - if tool_calls_list: - # Collect tool calls - all_tool_calls.extend(tool_calls_list) - - # Send tool call info - yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n" - - # Execute tools - tool_results = executor.process_tool_calls(tool_calls_list) - messages.append({ - "role": "assistant", - "content": full_content or None, - "tool_calls": tool_calls_list - }) - messages.extend(tool_results) - - # Collect tool results - all_tool_results.extend(tool_results) - - # Send tool results - for tr in tool_results: - yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n" - - continue - - # No tool calls, finish - save everything - total_content = full_content - total_thinking = full_thinking - total_tokens = token_count - total_prompt_tokens = prompt_tokens - - # Merge tool results into 
tool_calls - merged_tool_calls = [] - for i, tc in enumerate(all_tool_calls): - merged_tc = dict(tc) - if i < len(all_tool_results): - merged_tc["result"] = all_tool_results[i]["content"] - merged_tool_calls.append(merged_tc) - - with app.app_context(): - # Save assistant message with all tool calls (including results) - msg = Message( - id=msg_id, - conversation_id=conv_id, - role="assistant", - content=total_content, - token_count=total_tokens, - thinking_content=total_thinking, - tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None - ) - db.session.add(msg) - db.session.commit() - - user = get_or_create_default_user() - record_token_usage(user.id, conv_model, total_prompt_tokens, total_tokens) - - yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_tokens})}\n\n" - return - - yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n" - - return Response(generate(), mimetype="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) - - -def register_routes(app): - app.register_blueprint(bp) diff --git a/backend/routes/__init__.py b/backend/routes/__init__.py new file mode 100644 index 0000000..74e89b6 --- /dev/null +++ b/backend/routes/__init__.py @@ -0,0 +1,23 @@ +"""Route registration""" +from flask import Flask +from backend.routes.conversations import bp as conversations_bp +from backend.routes.messages import bp as messages_bp, init_chat_service +from backend.routes.models import bp as models_bp +from backend.routes.tools import bp as tools_bp +from backend.routes.stats import bp as stats_bp +from backend.services.glm_client import GLMClient +from backend.config import API_URL, API_KEY + + +def register_routes(app: Flask): + """Register all route blueprints""" + # Initialize GLM client and chat service + glm_client = GLMClient(API_URL, API_KEY) + init_chat_service(glm_client) + + # Register blueprints + 
app.register_blueprint(conversations_bp) + app.register_blueprint(messages_bp) + app.register_blueprint(models_bp) + app.register_blueprint(tools_bp) + app.register_blueprint(stats_bp) diff --git a/backend/routes/conversations.py b/backend/routes/conversations.py new file mode 100644 index 0000000..909bd58 --- /dev/null +++ b/backend/routes/conversations.py @@ -0,0 +1,72 @@ +"""Conversation API routes""" +import uuid +from datetime import datetime +from flask import Blueprint, request +from backend import db +from backend.models import Conversation +from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user +from backend.config import DEFAULT_MODEL + +bp = Blueprint("conversations", __name__) + + +@bp.route("/api/conversations", methods=["GET", "POST"]) +def conversation_list(): + """List or create conversations""" + if request.method == "POST": + d = request.json or {} + user = get_or_create_default_user() + conv = Conversation( + id=str(uuid.uuid4()), + user_id=user.id, + title=d.get("title", ""), + model=d.get("model", DEFAULT_MODEL), + system_prompt=d.get("system_prompt", ""), + temperature=d.get("temperature", 1.0), + max_tokens=d.get("max_tokens", 65536), + thinking_enabled=d.get("thinking_enabled", False), + ) + db.session.add(conv) + db.session.commit() + return ok(to_dict(conv)) + + # GET - list conversations + cursor = request.args.get("cursor") + limit = min(int(request.args.get("limit", 20)), 100) + user = get_or_create_default_user() + q = Conversation.query.filter_by(user_id=user.id) + if cursor: + q = q.filter(Conversation.updated_at < ( + db.session.query(Conversation.updated_at).filter_by(id=cursor).scalar() or datetime.utcnow)) + rows = q.order_by(Conversation.updated_at.desc()).limit(limit + 1).all() + + items = [to_dict(r, message_count=r.messages.count()) for r in rows[:limit]] + return ok({ + "items": items, + "next_cursor": items[-1]["id"] if len(rows) > limit else None, + "has_more": len(rows) > limit, + }) + + 
+@bp.route("/api/conversations/", methods=["GET", "PATCH", "DELETE"]) +def conversation_detail(conv_id): + """Get, update or delete a conversation""" + conv = db.session.get(Conversation, conv_id) + if not conv: + return err(404, "conversation not found") + + if request.method == "GET": + return ok(to_dict(conv)) + + if request.method == "DELETE": + db.session.delete(conv) + db.session.commit() + return ok(message="deleted") + + # PATCH - update conversation + d = request.json or {} + for k in ("title", "model", "system_prompt", "temperature", "max_tokens", "thinking_enabled"): + if k in d: + setattr(conv, k, d[k]) + db.session.commit() + return ok(to_dict(conv)) diff --git a/backend/routes/messages.py b/backend/routes/messages.py new file mode 100644 index 0000000..2f71d99 --- /dev/null +++ b/backend/routes/messages.py @@ -0,0 +1,75 @@ +"""Message API routes""" +import uuid +from datetime import datetime +from flask import Blueprint, request +from backend import db +from backend.models import Conversation, Message +from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user +from backend.services.chat import ChatService + + +bp = Blueprint("messages", __name__) + +# ChatService will be injected during registration +_chat_service = None + + +def init_chat_service(glm_client): + """Initialize chat service with GLM client""" + global _chat_service + _chat_service = ChatService(glm_client) + + +@bp.route("/api/conversations//messages", methods=["GET", "POST"]) +def message_list(conv_id): + """List or create messages""" + conv = db.session.get(Conversation, conv_id) + if not conv: + return err(404, "conversation not found") + + if request.method == "GET": + cursor = request.args.get("cursor") + limit = min(int(request.args.get("limit", 50)), 100) + q = Message.query.filter_by(conversation_id=conv_id) + if cursor: + q = q.filter(Message.created_at < ( + db.session.query(Message.created_at).filter_by(id=cursor).scalar() or datetime.utcnow)) + rows = 
q.order_by(Message.created_at.asc()).limit(limit + 1).all() + + items = [to_dict(r) for r in rows[:limit]] + return ok({ + "items": items, + "next_cursor": items[-1]["id"] if len(rows) > limit else None, + "has_more": len(rows) > limit, + }) + + # POST - create message and get AI response + d = request.json or {} + content = (d.get("content") or "").strip() + if not content: + return err(400, "content is required") + + user_msg = Message(id=str(uuid.uuid4()), conversation_id=conv_id, role="user", content=content) + db.session.add(user_msg) + db.session.commit() + + tools_enabled = d.get("tools_enabled", True) + + if d.get("stream", False): + return _chat_service.stream_response(conv, tools_enabled) + + return _chat_service.sync_response(conv, tools_enabled) + + +@bp.route("/api/conversations//messages/", methods=["DELETE"]) +def delete_message(conv_id, msg_id): + """Delete a message""" + conv = db.session.get(Conversation, conv_id) + if not conv: + return err(404, "conversation not found") + msg = db.session.get(Message, msg_id) + if not msg or msg.conversation_id != conv_id: + return err(404, "message not found") + db.session.delete(msg) + db.session.commit() + return ok(message="deleted") diff --git a/backend/routes/models.py b/backend/routes/models.py new file mode 100644 index 0000000..14cf6f2 --- /dev/null +++ b/backend/routes/models.py @@ -0,0 +1,12 @@ +"""Model list API routes""" +from flask import Blueprint +from backend.utils.helpers import ok +from backend.config import MODELS + +bp = Blueprint("models", __name__) + + +@bp.route("/api/models", methods=["GET"]) +def list_models(): + """Get available model list""" + return ok(MODELS) diff --git a/backend/routes/stats.py b/backend/routes/stats.py new file mode 100644 index 0000000..5454cf8 --- /dev/null +++ b/backend/routes/stats.py @@ -0,0 +1,84 @@ +"""Token statistics API routes""" +from datetime import date, timedelta +from flask import Blueprint, request +from sqlalchemy import func +from backend.models 
@bp.route("/api/stats/tokens", methods=["GET"])
def token_stats():
    """Return token-usage statistics for the default user.

    Query parameter ``period`` selects the window: ``daily`` (default),
    ``weekly`` (last 7 days) or ``monthly`` (last 30 days); anything else
    yields a 400 response.
    """
    user = get_or_create_default_user()
    period = request.args.get("period", "daily")
    today = date.today()

    if period == "daily":
        rows = TokenUsage.query.filter_by(user_id=user.id, date=today).all()
        return ok({
            "period": "daily",
            "date": today.isoformat(),
            "prompt_tokens": sum(r.prompt_tokens for r in rows),
            "completion_tokens": sum(r.completion_tokens for r in rows),
            "total_tokens": sum(r.total_tokens for r in rows),
            "by_model": {
                r.model: {
                    "prompt": r.prompt_tokens,
                    "completion": r.completion_tokens,
                    "total": r.total_tokens,
                }
                for r in rows
            },
        })

    if period in ("weekly", "monthly"):
        days = 7 if period == "weekly" else 30
        start_date = today - timedelta(days=days - 1)
        rows = TokenUsage.query.filter(
            TokenUsage.user_id == user.id,
            TokenUsage.date >= start_date,
            TokenUsage.date <= today,
        ).all()
        return ok(_build_period_result(rows, period, start_date, today, days))

    return err(400, "invalid period")
daily_data: + daily_data[d] = {"prompt": 0, "completion": 0, "total": 0} + + return { + "period": period, + "start_date": start_date.isoformat(), + "end_date": end_date.isoformat(), + "prompt_tokens": sum(s.prompt_tokens for s in stats), + "completion_tokens": sum(s.completion_tokens for s in stats), + "total_tokens": sum(s.total_tokens for s in stats), + "daily": daily_data + } diff --git a/backend/routes/tools.py b/backend/routes/tools.py new file mode 100644 index 0000000..a54bedf --- /dev/null +++ b/backend/routes/tools.py @@ -0,0 +1,16 @@ +"""Tool list API routes""" +from flask import Blueprint +from backend.tools import registry +from backend.utils.helpers import ok + +bp = Blueprint("tools", __name__) + + +@bp.route("/api/tools", methods=["GET"]) +def list_tools(): + """Get available tool list""" + tools = registry.list_all() + return ok({ + "tools": tools, + "total": len(tools) + }) diff --git a/backend/services/__init__.py b/backend/services/__init__.py new file mode 100644 index 0000000..6a94280 --- /dev/null +++ b/backend/services/__init__.py @@ -0,0 +1,8 @@ +"""Backend services""" +from backend.services.glm_client import GLMClient +from backend.services.chat import ChatService + +__all__ = [ + "GLMClient", + "ChatService", +] diff --git a/backend/services/chat.py b/backend/services/chat.py new file mode 100644 index 0000000..193ef6d --- /dev/null +++ b/backend/services/chat.py @@ -0,0 +1,254 @@ +"""Chat completion service""" +import json +import uuid +from flask import current_app, Response +from backend import db +from backend.models import Conversation, Message +from backend.tools import registry, ToolExecutor +from backend.utils.helpers import ( + get_or_create_default_user, + record_token_usage, + build_glm_messages, + ok, + err, + to_dict, +) +from backend.services.glm_client import GLMClient + + +class ChatService: + """Chat completion service with tool support""" + + MAX_ITERATIONS = 5 + + def __init__(self, glm_client: GLMClient): + 
class ChatService:
    """Chat completion service with tool-call support.

    Drives the model <-> tool loop for both blocking and SSE-streamed
    completions, persisting the final assistant message and token usage.
    """

    MAX_ITERATIONS = 5  # hard cap on model <-> tool round-trips per request

    def __init__(self, glm_client: "GLMClient"):
        self.glm_client = glm_client
        self.executor = ToolExecutor(registry=registry)

    def sync_response(self, conv: "Conversation", tools_enabled: bool = True):
        """Blocking completion; loops through tool calls until the model answers."""
        tools = registry.list_all() if tools_enabled else None
        messages = build_glm_messages(conv)

        # Each request starts with a clean per-session tool-call history.
        self.executor.clear_history()

        collected_calls = []
        collected_results = []

        for _ in range(self.MAX_ITERATIONS):
            try:
                resp = self.glm_client.call(
                    model=conv.model,
                    messages=messages,
                    max_tokens=conv.max_tokens,
                    temperature=conv.temperature,
                    thinking_enabled=conv.thinking_enabled,
                    tools=tools,
                )
                resp.raise_for_status()
                result = resp.json()
            except Exception as e:
                return err(500, f"upstream error: {e}")

            message = result["choices"][0]["message"]

            if not message.get("tool_calls"):
                # Final answer: persist it and report usage.
                usage = result.get("usage", {})
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)

                merged = self._merge_tool_results(collected_calls, collected_results)

                msg = Message(
                    id=str(uuid.uuid4()),
                    conversation_id=conv.id,
                    role="assistant",
                    content=message.get("content", ""),
                    token_count=completion_tokens,
                    thinking_content=message.get("reasoning_content", ""),
                    tool_calls=json.dumps(merged) if merged else None,
                )
                db.session.add(msg)
                db.session.commit()

                user = get_or_create_default_user()
                record_token_usage(user.id, conv.model, prompt_tokens, completion_tokens)

                return ok({
                    "message": to_dict(msg, thinking_content=msg.thinking_content or None),
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": usage.get("total_tokens", 0),
                    },
                })

            # The model requested tools: execute them and feed results back.
            tool_calls = message["tool_calls"]
            collected_calls.extend(tool_calls)
            messages.append(message)

            tool_results = self.executor.process_tool_calls(tool_calls)
            collected_results.extend(tool_results)
            messages.extend(tool_results)

        return err(500, "exceeded maximum tool call iterations")

    def stream_response(self, conv: "Conversation", tools_enabled: bool = True):
        """SSE streaming completion; tool calls are executed between rounds."""
        conv_id = conv.id
        conv_model = conv.model
        # Capture the real app object: the generator runs outside the request context.
        app = current_app._get_current_object()
        tools = registry.list_all() if tools_enabled else None
        initial_messages = build_glm_messages(conv)

        # Each request starts with a clean per-session tool-call history.
        self.executor.clear_history()

        def generate():
            messages = list(initial_messages)
            all_tool_calls = []
            all_tool_results = []

            for _ in range(self.MAX_ITERATIONS):
                full_content = ""
                full_thinking = ""
                token_count = 0
                prompt_tokens = 0
                msg_id = str(uuid.uuid4())
                tool_calls_list = []

                try:
                    with app.app_context():
                        active_conv = db.session.get(Conversation, conv_id)
                        resp = self.glm_client.call(
                            model=active_conv.model,
                            messages=messages,
                            max_tokens=active_conv.max_tokens,
                            temperature=active_conv.temperature,
                            thinking_enabled=active_conv.thinking_enabled,
                            tools=tools,
                            stream=True,
                        )
                        resp.raise_for_status()

                        for raw in resp.iter_lines():
                            if not raw:
                                continue
                            line = raw.decode("utf-8")
                            if not line.startswith("data: "):
                                continue
                            payload = line[6:]
                            if payload == "[DONE]":
                                break
                            try:
                                chunk = json.loads(payload)
                            except json.JSONDecodeError:
                                continue

                            delta = chunk["choices"][0].get("delta", {})

                            # Reasoning tokens stream as a separate event type.
                            reasoning = delta.get("reasoning_content", "")
                            if reasoning:
                                full_thinking += reasoning
                                yield f"event: thinking\ndata: {json.dumps({'content': reasoning}, ensure_ascii=False)}\n\n"

                            text = delta.get("content", "")
                            if text:
                                full_content += text
                                yield f"event: message\ndata: {json.dumps({'content': text}, ensure_ascii=False)}\n\n"

                            # Tool-call fragments accumulate across deltas.
                            tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)

                            usage = chunk.get("usage", {})
                            if usage:
                                token_count = usage.get("completion_tokens", 0)
                                prompt_tokens = usage.get("prompt_tokens", 0)

                except Exception as e:
                    yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n"
                    return

                if tool_calls_list:
                    # Run the requested tools, stream their results, then loop again.
                    all_tool_calls.extend(tool_calls_list)
                    yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n"

                    tool_results = self.executor.process_tool_calls(tool_calls_list)
                    messages.append({
                        "role": "assistant",
                        "content": full_content or None,
                        "tool_calls": tool_calls_list,
                    })
                    messages.extend(tool_results)
                    all_tool_results.extend(tool_results)

                    for tr in tool_results:
                        yield f"event: tool_result\ndata: {json.dumps({'name': tr['name'], 'content': tr['content']}, ensure_ascii=False)}\n\n"
                    continue

                # No tool calls: persist the finished assistant message.
                merged = self._merge_tool_results(all_tool_calls, all_tool_results)

                with app.app_context():
                    msg = Message(
                        id=msg_id,
                        conversation_id=conv_id,
                        role="assistant",
                        content=full_content,
                        token_count=token_count,
                        thinking_content=full_thinking,
                        tool_calls=json.dumps(merged) if merged else None,
                    )
                    db.session.add(msg)
                    db.session.commit()

                    user = get_or_create_default_user()
                    record_token_usage(user.id, conv_model, prompt_tokens, token_count)

                yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n"
                return

            yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"

        return Response(
            generate(),
            mimetype="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
        """Fold a streaming delta's partial tool calls into the accumulator.

        Names arrive whole; argument strings arrive in fragments and are
        concatenated onto the entry at the delta's ``index``.
        """
        for tc in delta.get("tool_calls", []):
            idx = tc.get("index", 0)
            if idx >= len(tool_calls_list):
                tool_calls_list.append({
                    "id": tc.get("id", ""),
                    "type": tc.get("type", "function"),
                    "function": {"name": "", "arguments": ""},
                })
            entry = tool_calls_list[idx]
            if tc.get("id"):
                entry["id"] = tc["id"]
            fn = tc.get("function")
            if fn:
                if fn.get("name"):
                    entry["function"]["name"] = fn["name"]
                if fn.get("arguments"):
                    entry["function"]["arguments"] += fn["arguments"]
        return tool_calls_list

    def _merge_tool_results(self, tool_calls: list, tool_results: list) -> list:
        """Attach each tool result's content to a copy of its originating call."""
        merged = []
        for i, call in enumerate(tool_calls):
            item = dict(call)
            if i < len(tool_results):
                item["result"] = tool_results[i]["content"]
            merged.append(item)
        return merged
"Authorization": f"Bearer {self.api_key}" + }, + json=body, + stream=stream, + timeout=timeout, + ) diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py index bc3a71b..be47237 100644 --- a/backend/tools/__init__.py +++ b/backend/tools/__init__.py @@ -15,9 +15,9 @@ Usage: result = registry.execute("web_search", {"query": "Python"}) """ -from .core import ToolDefinition, ToolResult, ToolRegistry, registry -from .factory import tool, register_tool -from .executor import ToolExecutor +from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry +from backend.tools.factory import tool, register_tool +from backend.tools.executor import ToolExecutor def init_tools() -> None: @@ -26,7 +26,7 @@ def init_tools() -> None: Importing builtin module automatically registers all decorator-defined tools """ - from .builtin import crawler, data, weather # noqa: F401 + from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401 # Public API exports diff --git a/backend/tools/builtin/__init__.py b/backend/tools/builtin/__init__.py index 4baa390..3f826da 100644 --- a/backend/tools/builtin/__init__.py +++ b/backend/tools/builtin/__init__.py @@ -1,3 +1,4 @@ """Built-in tools""" -from .crawler import * -from .data import * +from backend.tools.builtin.crawler import * +from backend.tools.builtin.data import * +from backend.tools.builtin.file_ops import * diff --git a/backend/tools/builtin/crawler.py b/backend/tools/builtin/crawler.py index 04bcbae..49d8f39 100644 --- a/backend/tools/builtin/crawler.py +++ b/backend/tools/builtin/crawler.py @@ -1,6 +1,6 @@ """Crawler related tools""" -from ..factory import tool -from ..services import SearchService, FetchService +from backend.tools.factory import tool +from backend.tools.services import SearchService, FetchService @tool( diff --git a/backend/tools/builtin/data.py b/backend/tools/builtin/data.py index caaa7ca..f2d98aa 100644 --- a/backend/tools/builtin/data.py +++ 
b/backend/tools/builtin/data.py @@ -1,6 +1,6 @@ """Data processing related tools""" -from ..factory import tool -from ..services import CalculatorService +from backend.tools.factory import tool +from backend.tools.services import CalculatorService @tool( diff --git a/backend/tools/builtin/file_ops.py b/backend/tools/builtin/file_ops.py new file mode 100644 index 0000000..798a9b2 --- /dev/null +++ b/backend/tools/builtin/file_ops.py @@ -0,0 +1,360 @@ +"""File operation tools""" +import os +import json +from pathlib import Path +from typing import Optional +from backend.tools.factory import tool + + +# Base directory for file operations (sandbox) +# Set to None to allow any path, or set a specific directory for security +BASE_DIR = Path(__file__).parent.parent.parent.parent # project root + + +def _resolve_path(path: str) -> Path: + """Resolve path and ensure it's within allowed directory""" + p = Path(path) + if not p.is_absolute(): + p = BASE_DIR / p + p = p.resolve() + + # Security check: ensure path is within BASE_DIR + if BASE_DIR: + try: + p.relative_to(BASE_DIR.resolve()) + except ValueError: + raise ValueError(f"Path '{path}' is outside allowed directory") + + return p + + +@tool( + name="file_read", + description="Read content from a file. 
@tool(
    name="file_read",
    description="Read content from a file. Use when you need to read file content.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to read (relative to project root or absolute)"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_read(arguments: dict) -> dict:
    """Read a text file inside the sandbox.

    Expects ``arguments`` like ``{"path": "file.txt", "encoding": "utf-8"}``;
    returns ``{"content", "size", "path"}`` on success, ``{"error": ...}``
    otherwise (never raises to the caller).
    """
    try:
        target = _resolve_path(arguments["path"])
        if not target.exists():
            return {"error": f"File not found: {target}"}
        if not target.is_file():
            return {"error": f"Path is not a file: {target}"}
        text = target.read_text(encoding=arguments.get("encoding", "utf-8"))
        return {"content": text, "size": len(text), "path": str(target)}
    except Exception as e:
        return {"error": str(e)}
@tool(
    name="file_write",
    description="Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Use when you need to create or update a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to write (relative to project root or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            },
            "encoding": {
                "type": "string",
                "description": "File encoding, default utf-8",
                "default": "utf-8"
            },
            "mode": {
                "type": "string",
                "description": "Write mode: 'write' (overwrite) or 'append'",
                "enum": ["write", "append"],
                "default": "write"
            }
        },
        "required": ["path", "content"]
    },
    category="file"
)
def file_write(arguments: dict) -> dict:
    """Write or append text to a file inside the sandbox.

    Parent directories are created as needed. Returns
    ``{"success", "size", "path", "mode"}`` on success, ``{"error": ...}``
    otherwise (never raises to the caller).
    """
    try:
        target = _resolve_path(arguments["path"])
        content = arguments["content"]
        encoding = arguments.get("encoding", "utf-8")
        mode = arguments.get("mode", "write")

        # Make sure the destination directory exists.
        target.parent.mkdir(parents=True, exist_ok=True)

        if mode == "append":
            with open(target, "a", encoding=encoding) as fh:
                fh.write(content)
        else:
            target.write_text(content, encoding=encoding)

        return {
            "success": True,
            "size": len(content),
            "path": str(target),
            "mode": mode,
        }
    except Exception as e:
        return {"error": str(e)}
@tool(
    name="file_delete",
    description="Delete a file. Use when you need to remove a file.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "File path to delete (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_delete(arguments: dict) -> dict:
    """Delete a single file inside the sandbox.

    Refuses directories. Returns ``{"success", "path"}`` on success,
    ``{"error": ...}`` otherwise (never raises to the caller).
    """
    try:
        target = _resolve_path(arguments["path"])
        if not target.exists():
            return {"error": f"File not found: {target}"}
        if not target.is_file():
            return {"error": f"Path is not a file: {target}"}
        target.unlink()
        return {"success": True, "path": str(target)}
    except Exception as e:
        return {"error": str(e)}
@tool(
    name="file_list",
    description="List files and directories in a directory. Use when you need to see what files exist.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to list (relative to project root or absolute)",
                "default": "."
            },
            "pattern": {
                "type": "string",
                "description": "Glob pattern to filter files, e.g. '*.py'",
                "default": "*"
            }
        },
        "required": []
    },
    category="file"
)
def file_list(arguments: dict) -> dict:
    """List directory contents, optionally filtered by a glob pattern.

    Returns ``{"path", "files", "directories", "total_files", "total_dirs"}``
    on success, ``{"error": ...}`` otherwise (never raises to the caller).
    """
    try:
        path = _resolve_path(arguments.get("path", "."))
        pattern = arguments.get("pattern", "*")

        if not path.exists():
            return {"error": f"Directory not found: {path}"}
        if not path.is_dir():
            return {"error": f"Path is not a directory: {path}"}

        # Fix: `path` comes back fully resolved from _resolve_path, but the
        # old code called item.relative_to(BASE_DIR) against the UNRESOLVED
        # base, which raises ValueError when the project root sits behind a
        # symlink. Resolve the base once and reuse it.
        base = BASE_DIR.resolve() if BASE_DIR is not None else None

        def _display_path(item):
            # Sandbox-relative path when sandboxing is on, absolute otherwise.
            return str(item.relative_to(base)) if base is not None else str(item)

        files = []
        directories = []
        for item in path.glob(pattern):
            if item.is_file():
                files.append({
                    "name": item.name,
                    "size": item.stat().st_size,
                    "path": _display_path(item),
                })
            elif item.is_dir():
                directories.append({
                    "name": item.name,
                    "path": _display_path(item),
                })

        return {
            "path": str(path),
            "files": files,
            "directories": directories,
            "total_files": len(files),
            "total_dirs": len(directories),
        }
    except Exception as e:
        return {"error": str(e)}
@tool(
    name="file_exists",
    description="Check if a file or directory exists. Use when you need to verify file existence.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to check (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_exists(arguments: dict) -> dict:
    """Report whether a path exists and what kind of entry it is.

    Returns ``{"exists": False, "path"}`` for missing paths; otherwise
    ``{"exists": True, "type": "file"|"directory"|"other", "path"}``, with
    ``size`` included for regular files. Errors come back as ``{"error": ...}``.
    """
    try:
        target = _resolve_path(arguments["path"])

        if not target.exists():
            return {"exists": False, "path": str(target)}

        if target.is_file():
            return {
                "exists": True,
                "type": "file",
                "path": str(target),
                "size": target.stat().st_size,
            }
        if target.is_dir():
            return {"exists": True, "type": "directory", "path": str(target)}
        # Sockets, device nodes, etc.
        return {"exists": True, "type": "other", "path": str(target)}
    except Exception as e:
        return {"error": str(e)}
@tool(
    name="file_mkdir",
    description="Create a directory. Creates parent directories if needed. Use when you need to create a folder.",
    parameters={
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory path to create (relative to project root or absolute)"
            }
        },
        "required": ["path"]
    },
    category="file"
)
def file_mkdir(arguments: dict) -> dict:
    """Create a directory (including parents) inside the sandbox.

    Returns ``{"success", "path", "created"}`` where ``created`` is True only
    when the directory did not already exist; errors come back as
    ``{"error": ...}`` (never raises to the caller).
    """
    try:
        path = _resolve_path(arguments["path"])
        # Fix: the old code computed `created` AFTER mkdir as
        # `not path.exists() or path.is_dir()`, which is always True once the
        # directory exists — the flag never reported "already existed".
        # Capture existence before creating.
        already_existed = path.exists()
        path.mkdir(parents=True, exist_ok=True)
        return {
            "success": True,
            "path": str(path),
            "created": not already_existed,
        }
    except Exception as e:
        return {"error": str(e)}
(result, timestamp) + self._call_history: List[dict] = [] # Track calls in current session + + def _make_cache_key(self, name: str, args: dict) -> str: + """Generate cache key from tool name and arguments""" + args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) + return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest() + + def _get_cached(self, key: str) -> Optional[dict]: + """Get cached result if valid""" + if not self.enable_cache: + return None + if key in self._cache: + result, timestamp = self._cache[key] + if time.time() - timestamp < self.cache_ttl: + return result + del self._cache[key] + return None + + def _set_cache(self, key: str, result: dict) -> None: + """Cache a result""" + if self.enable_cache: + self._cache[key] = (result, time.time()) + + def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]: + """Check if same tool+args was called before in this session""" + args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) + for record in self._call_history: + if record["name"] == name and record["args_str"] == args_str: + return record["result"] + return None + + def clear_history(self) -> None: + """Clear call history (call this at start of new conversation turn)""" + self._call_history.clear() def process_tool_calls( self, @@ -34,6 +74,7 @@ class ToolExecutor: Tool response message list, can be appended to messages """ results = [] + seen_calls = set() # Track calls within this batch for call in tool_calls: name = call["function"]["name"] @@ -48,7 +89,45 @@ class ToolExecutor: )) continue + # Check for duplicate within same batch + call_key = f"{name}:{json.dumps(args, sort_keys=True)}" + if call_key in seen_calls: + # Skip duplicate, but still return a result + results.append(self._create_tool_result( + call_id, name, + {"success": True, "data": None, "cached": True, "duplicate": True} + )) + continue + seen_calls.add(call_key) + + # Check history for previous call in this session + history_result = 
def get_or_create_default_user():
    """Return the singleton "default" user, creating it on first use."""
    user = User.query.filter_by(username="default").first()
    if not user:
        user = User(username="default", password="")
        db.session.add(user)
        db.session.commit()
    return user


def ok(data=None, message=None):
    """Build a success JSON response: ``{"code": 0}`` plus optional data/message."""
    from flask import jsonify
    body = {"code": 0}
    if data is not None:
        body["data"] = data
    if message is not None:
        body["message"] = message
    return jsonify(body)


def err(code, message):
    """Build an error JSON response; the HTTP status code mirrors ``code``."""
    from flask import jsonify
    return jsonify({"code": code, "message": message}), code


def to_dict(inst, **extra):
    """Serialize a SQLAlchemy model row to a JSON-friendly dict.

    Timestamps are rendered as ``%Y-%m-%dT%H:%M:%SZ`` strings, the
    ``tool_calls`` JSON column is parsed into a structure, and ``None``
    values are dropped for a cleaner API payload. ``extra`` keyword
    arguments are merged in last (so they win on key clashes and may be
    ``None``).
    """
    d = {c.name: getattr(inst, c.name) for c in inst.__table__.columns}
    for k in ("created_at", "updated_at"):
        if k in d and hasattr(d[k], "strftime"):
            d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ")

    # Parse tool_calls JSON if present
    if "tool_calls" in d and d["tool_calls"]:
        # Fix: this was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; only JSON parse failures should be ignored
        # (JSONDecodeError is a ValueError subclass; TypeError covers
        # non-string column values).
        try:
            d["tool_calls"] = json.loads(d["tool_calls"])
        except (ValueError, TypeError):
            pass

    # Filter out None values for a cleaner API response
    d = {k: v for k, v in d.items() if v is not None}

    d.update(extra)
    return d


def record_token_usage(user_id, model, prompt_tokens, completion_tokens):
    """Accumulate today's token usage for (user, model); one row per day."""
    today = date.today()
    usage = TokenUsage.query.filter_by(
        user_id=user_id, date=today, model=model
    ).first()
    if usage:
        usage.prompt_tokens += prompt_tokens
        usage.completion_tokens += completion_tokens
        usage.total_tokens += prompt_tokens + completion_tokens
    else:
        usage = TokenUsage(
            user_id=user_id,
            date=today,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        db.session.add(usage)
    db.session.commit()


def build_glm_messages(conv):
    """Build the GLM chat ``messages`` list from a conversation's history."""
    msgs = []
    if conv.system_prompt:
        msgs.append({"role": "system", "content": conv.system_prompt})
    # Query messages directly to avoid detached instance warning
    messages = Message.query.filter_by(
        conversation_id=conv.id
    ).order_by(Message.created_at.asc()).all()
    for m in messages:
        msgs.append({"role": m.role, "content": m.content})
    return msgs