From 8325100c9006d3108de5984fdfc7a9f75f145f5e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 26 Mar 2026 11:48:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 93 +- backend/__init__.py | 2 +- backend/models.py | 23 + backend/routes/__init__.py | 2 + backend/routes/messages.py | 8 +- backend/routes/projects.py | 331 ++++++ backend/services/chat.py | 48 +- backend/tools/builtin/__init__.py | 2 + backend/tools/builtin/file_ops.py | 193 ++-- backend/tools/executor.py | 8 +- backend/utils/__init__.py | 4 +- backend/utils/helpers.py | 11 +- backend/utils/workspace.py | 180 ++++ docs/Design.md | 176 +++- docs/ToolSystemDesign.md | 1052 +++++--------------- frontend/src/App.vue | 11 +- frontend/src/api/index.js | 48 +- frontend/src/components/ProcessBlock.vue | 10 +- frontend/src/components/ProjectManager.vue | 518 ++++++++++ frontend/src/components/Sidebar.vue | 100 +- 20 files changed, 1815 insertions(+), 1005 deletions(-) create mode 100644 backend/routes/projects.py create mode 100644 backend/utils/workspace.py create mode 100644 frontend/src/components/ProjectManager.vue diff --git a/README.md b/README.md index 4062482..cd19642 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ # NanoClaw -基于 GLM 大语言模型的 AI 对话应用,支持工具调用、思维链和流式回复。 +基于 LLM 大语言模型的 AI 对话应用,支持工具调用、思维链、流式回复和工作目录隔离。 ## 功能特性 - 💬 **多轮对话** - 支持上下文管理的多轮对话 - 🔧 **工具调用** - 网页搜索、代码执行、文件操作等 - 🧠 **思维链** - 支持链式思考推理 +- 📁 **工作目录** - 项目级文件隔离,安全操作 - 📊 **Token 统计** - 按日/周/月统计使用量 - 🔄 **流式响应** - 实时 SSE 流式输出 - 💾 **多数据库** - 支持 MySQL、SQLite、PostgreSQL @@ -28,7 +29,7 @@ pip install -e . backend_port: 3000 frontend_port: 4000 -# AI API +# LLM API api_key: {{your-api-key}} api_url: https://open.bigmodel.cn/api/paas/v4/chat/completions @@ -36,40 +37,32 @@ api_url: https://open.bigmodel.cn/api/paas/v4/chat/completions models: - id: glm-5 name: GLM-5 - - id: glm-5-turbo - name: GLM-5 Turbo - - id: glm-4.5 - name: GLM-4.5 - - id: glm-4.6 - name: GLM-4.6 - - id: glm-4.7 - name: GLM-4.7 + - id: glm-4-plus + name: GLM-4 Plus default_model: glm-5 +# Workspace root directory +workspace_root: ./workspaces + # Database Configuration -# Supported types: mysql, sqlite, postgresql db_type: sqlite - -# MySQL/PostgreSQL Settings (ignored for sqlite) -db_host: localhost -db_port: 3306 -db_user: root -db_password: "123456" -db_name: nano_claw - -# SQLite Settings (ignored for mysql/postgresql) db_sqlite_file: nano_claw.db - ``` -### 3. 启动后端 +### 3. 数据库迁移(首次运行或升级) + +```bash +python -m backend.migrations.add_project_support +``` + +### 4. 启动后端 ```bash python -m backend.run ``` -### 4. 启动前端 +### 5. 启动前端 ```bash cd frontend @@ -83,6 +76,10 @@ npm run dev backend/ ├── models.py # SQLAlchemy 数据模型 ├── routes/ # API 路由 +│ ├── conversations.py +│ ├── messages.py +│ ├── projects.py # 项目管理 +│ └── ... ├── services/ # 业务逻辑 │ ├── chat.py # 聊天补全服务 │ └── glm_client.py @@ -90,7 +87,10 @@ backend/ │ ├── core.py # 核心类 │ ├── executor.py # 工具执行器 │ └── builtin/ # 内置工具 -└── utils/ # 辅助函数 +├── utils/ # 辅助函数 +│ ├── helpers.py +│ └── workspace.py # 工作目录工具 +└── migrations/ # 数据库迁移 frontend/ └── src/ @@ -99,6 +99,24 @@ frontend/ └── views/ # 页面 ``` +## 工作目录系统 + +### 概述 + +工作目录系统为文件操作提供安全隔离,确保 AI 只能访问指定项目目录内的文件。 + +### 使用流程 + +1. **创建项目** - 在侧边栏点击"新建项目"或上传文件夹 +2. **选择项目** - 在对话中选择当前工作目录 +3. **文件操作** - AI 自动在项目目录内执行文件操作 + +### 安全机制 + +- 所有文件操作需要 `project_id` 参数 +- 后端强制验证路径在项目目录内 +- 阻止目录遍历攻击(如 `../../../etc/passwd`) + ## API 概览 | 方法 | 路径 | 说明 | @@ -107,26 +125,29 @@ frontend/ | `GET` | `/api/conversations` | 会话列表 | | `GET` | `/api/conversations/:id/messages` | 消息列表 | | `POST` | `/api/conversations/:id/messages` | 发送消息(SSE) | +| `GET` | `/api/projects` | 项目列表 | +| `POST` | `/api/projects` | 创建项目 | +| `POST` | `/api/projects/upload` | 上传文件夹 | | `GET` | `/api/tools` | 工具列表 | | `GET` | `/api/stats/tokens` | Token 统计 | ## 内置工具 -| 分类 | 工具 | -|------|------| -| **爬虫** | web_search, fetch_page, crawl_batch | -| **数据处理** | calculator, text_process, json_process | -| **代码执行** | execute_python(沙箱环境) | -| **文件操作** | file_read, file_write, file_list 等 | -| **天气** | get_weather | +| 分类 | 工具 | 说明 | +|------|------|------| +| **爬虫** | web_search, fetch_page, crawl_batch | 网页搜索和抓取 | +| **数据处理** | calculator, text_process, json_process | 数学计算和文本处理 | +| **代码执行** | execute_python | 沙箱环境执行 Python | +| **文件操作** | file_read, file_write, file_list 等 | **需要 project_id** | +| **天气** | get_weather | 天气查询(模拟) | ## 文档 -- [后端设计](docs/Design.md) - 架构设计、类图、API 文档 -- [工具系统](docs/ToolSystemDesign.md) - 工具开发指南 +- [后端设计](docs/Design.md) - 架构设计、数据模型、API 文档 +- [工具系统](docs/ToolSystemDesign.md) - 工具开发指南、安全设计 ## 技术栈 -- **后端**: Python 3.11+, Flask -- **前端**: Vue 3 -- **大模型**: GLM API(智谱AI) +- **后端**: Python 3.11+, Flask, SQLAlchemy +- **前端**: Vue 3, Vite +- **LLM**: 支持 GLM 等大语言模型 diff --git a/backend/__init__.py b/backend/__init__.py index 50a2170..e5eef25 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -58,7 +58,7 @@ def create_app(): db.init_app(app) # Import after db is initialized - from backend.models import User, Conversation, Message, TokenUsage + from backend.models import User, Conversation, Message, TokenUsage, Project from backend.routes import register_routes from backend.tools import init_tools diff --git a/backend/models.py b/backend/models.py index 87842f7..d5fb24c 100644 --- a/backend/models.py +++ b/backend/models.py @@ -44,6 +44,7 @@ class Conversation(db.Model): id = db.Column(db.String(64), primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True) + project_id = db.Column(db.String(64), db.ForeignKey("projects.id"), nullable=True, index=True) title = db.Column(db.String(255), nullable=False, default="") model = db.Column(db.String(64), nullable=False, default="glm-5") system_prompt = db.Column(db.Text, default="") @@ -91,3 +92,25 @@ class TokenUsage(db.Model): db.UniqueConstraint("user_id", "date", "model", name="uq_user_date_model"), db.Index("ix_token_usage_date_model", "date", "model"), # Composite index ) + + +class Project(db.Model): + """Project model for workspace isolation""" + __tablename__ = "projects" + + id = db.Column(db.String(64), primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False, index=True) + name = db.Column(db.String(255), nullable=False) + path = db.Column(db.String(512), nullable=False) # Relative path within workspace root + description = db.Column(db.Text, default="") + created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), index=True) + updated_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc)) + + # Relationship: one project can have multiple conversations + conversations = db.relationship("Conversation", backref="project", lazy="dynamic", + cascade="all, delete-orphan") + + __table_args__ = ( + db.UniqueConstraint("user_id", "name", name="uq_user_project_name"), + ) diff --git a/backend/routes/__init__.py b/backend/routes/__init__.py index 74e89b6..24de65e 100644 --- a/backend/routes/__init__.py +++ b/backend/routes/__init__.py @@ -5,6 +5,7 @@ from backend.routes.messages import bp as messages_bp, init_chat_service from backend.routes.models import bp as models_bp from backend.routes.tools import bp as tools_bp from backend.routes.stats import bp as stats_bp +from backend.routes.projects import bp as projects_bp from backend.services.glm_client import GLMClient from backend.config import API_URL, API_KEY @@ -21,3 +22,4 @@ def register_routes(app: Flask): app.register_blueprint(models_bp) app.register_blueprint(tools_bp) app.register_blueprint(stats_bp) + app.register_blueprint(projects_bp) diff --git a/backend/routes/messages.py b/backend/routes/messages.py index e9fbf24..ca03591 100644 --- a/backend/routes/messages.py +++ b/backend/routes/messages.py @@ -48,6 +48,7 @@ def message_list(conv_id): d = request.json or {} text = (d.get("text") or "").strip() attachments = d.get("attachments") # [{"name": "a.py", "extension": "py", "content": "..."}] + project_id = d.get("project_id") # Get project_id from request if not text and not attachments: return err(400, "text or attachments is required") @@ -69,9 +70,9 @@ def message_list(conv_id): tools_enabled = d.get("tools_enabled", True) if d.get("stream", False): - return _chat_service.stream_response(conv, tools_enabled) + return _chat_service.stream_response(conv, tools_enabled, project_id) - return _chat_service.sync_response(conv, tools_enabled) + return _chat_service.sync_response(conv, tools_enabled, project_id) @bp.route("/api/conversations//messages/", methods=["DELETE"]) @@ -113,6 +114,7 @@ def regenerate_message(conv_id, msg_id): # 获取工具启用状态 d = request.json or {} tools_enabled = d.get("tools_enabled", True) + project_id = d.get("project_id") # Get project_id from request # 流式重新生成 - return _chat_service.stream_response(conv, tools_enabled) + return _chat_service.stream_response(conv, tools_enabled, project_id) diff --git a/backend/routes/projects.py b/backend/routes/projects.py new file mode 100644 index 0000000..bdb6245 --- /dev/null +++ b/backend/routes/projects.py @@ -0,0 +1,331 @@ +"""Project management API routes""" +import os +import uuid +import shutil +from flask import Blueprint, request, jsonify +from werkzeug.utils import secure_filename + +from backend import db +from backend.models import Project, User +from backend.utils.helpers import ok, err +from backend.utils.workspace import ( + create_project_directory, + delete_project_directory, + get_project_path, + copy_folder_to_project +) + +bp = Blueprint("projects", __name__) + + +@bp.route("/api/projects", methods=["GET"]) +def list_projects(): + """List all projects for a user""" + user_id = request.args.get("user_id", type=int) + + if not user_id: + return err(400, "Missing user_id parameter") + + projects = Project.query.filter_by(user_id=user_id).order_by(Project.updated_at.desc()).all() + + return ok({ + "projects": [ + { + "id": p.id, + "name": p.name, + "path": p.path, + "description": p.description, + "created_at": p.created_at.isoformat() if p.created_at else None, + "updated_at": p.updated_at.isoformat() if p.updated_at else None, + "conversation_count": p.conversations.count() + } + for p in projects + ], + "total": len(projects) + }) + + +@bp.route("/api/projects", methods=["POST"]) +def create_project(): + """Create a new project""" + data = request.get_json() + + if not data: + return err(400, "No data provided") + + user_id = data.get("user_id") + name = data.get("name", "").strip() + description = data.get("description", "") + + if not user_id: + return err(400, "Missing user_id") + + if not name: + return err(400, "Project name is required") + + # Check if user exists + user = User.query.get(user_id) + if not user: + return err(404, "User not found") + + # Check if project name already exists for this user + existing = Project.query.filter_by(user_id=user_id, name=name).first() + if existing: + return err(400, f"Project '{name}' already exists") + + # Create project directory + try: + relative_path, absolute_path = create_project_directory(name, user_id) + except Exception as e: + return err(500, f"Failed to create project directory: {str(e)}") + + # Create project record + project = Project( + id=str(uuid.uuid4()), + user_id=user_id, + name=name, + path=relative_path, + description=description + ) + + db.session.add(project) + db.session.commit() + + return ok({ + "id": project.id, + "name": project.name, + "path": project.path, + "description": project.description, + "created_at": project.created_at.isoformat() + }) + + +@bp.route("/api/projects/", methods=["GET"]) +def get_project(project_id): + """Get project details""" + project = Project.query.get(project_id) + + if not project: + return err(404, "Project not found") + + # Get absolute path + absolute_path = get_project_path(project.id, project.path) + + # Get directory statistics + file_count = sum(1 for _ in absolute_path.rglob("*") if _.is_file()) + dir_count = sum(1 for _ in absolute_path.rglob("*") if _.is_dir()) + total_size = sum(f.stat().st_size for f in absolute_path.rglob("*") if f.is_file()) + + return ok({ + "id": project.id, + "name": project.name, + "path": project.path, + "absolute_path": str(absolute_path), + "description": project.description, + "created_at": project.created_at.isoformat() if project.created_at else None, + "updated_at": project.updated_at.isoformat() if project.updated_at else None, + "conversation_count": project.conversations.count(), + "stats": { + "files": file_count, + "directories": dir_count, + "total_size": total_size + } + }) + + +@bp.route("/api/projects/", methods=["PUT"]) +def update_project(project_id): + """Update project details""" + project = Project.query.get(project_id) + + if not project: + return err(404, "Project not found") + + data = request.get_json() + + if not data: + return err(400, "No data provided") + + # Update name if provided + if "name" in data: + name = data["name"].strip() + if not name: + return err(400, "Project name cannot be empty") + + # Check if new name conflicts with existing project + existing = Project.query.filter( + Project.user_id == project.user_id, + Project.name == name, + Project.id != project_id + ).first() + + if existing: + return err(400, f"Project '{name}' already exists") + + project.name = name + + # Update description if provided + if "description" in data: + project.description = data["description"] + + db.session.commit() + + return ok({ + "id": project.id, + "name": project.name, + "description": project.description, + "updated_at": project.updated_at.isoformat() + }) + + +@bp.route("/api/projects/", methods=["DELETE"]) +def delete_project(project_id): + """Delete a project""" + project = Project.query.get(project_id) + + if not project: + return err(404, "Project not found") + + # Delete project directory + try: + delete_project_directory(project.path) + except Exception as e: + return err(500, f"Failed to delete project directory: {str(e)}") + + # Delete project record (cascades to conversations and messages) + db.session.delete(project) + db.session.commit() + + return ok({"message": "Project deleted successfully"}) + + +@bp.route("/api/projects/upload", methods=["POST"]) +def upload_project_folder(): + """Upload a folder as a new project (via temporary upload)""" + if "folder_path" not in request.json: + return err(400, "Missing folder_path in request body") + + user_id = request.json.get("user_id") + folder_path = request.json.get("folder_path") + project_name = request.json.get("name") + description = request.json.get("description", "") + + if not user_id: + return err(400, "Missing user_id") + + if not folder_path: + return err(400, "Missing folder_path") + + if not project_name: + # Use folder name as project name + project_name = os.path.basename(folder_path) + + # Check if user exists + user = User.query.get(user_id) + if not user: + return err(404, "User not found") + + # Check if project name already exists + existing = Project.query.filter_by(user_id=user_id, name=project_name).first() + if existing: + return err(400, f"Project '{project_name}' already exists") + + # Create project directory first + try: + relative_path, absolute_path = create_project_directory(project_name, user_id) + except Exception as e: + return err(500, f"Failed to create project directory: {str(e)}") + + # Copy folder contents to project directory + try: + stats = copy_folder_to_project(folder_path, absolute_path, project_name) + except Exception as e: + # Clean up created directory on error + shutil.rmtree(absolute_path, ignore_errors=True) + return err(500, f"Failed to copy folder: {str(e)}") + + # Create project record + project = Project( + id=str(uuid.uuid4()), + user_id=user_id, + name=project_name, + path=relative_path, + description=description + ) + + db.session.add(project) + db.session.commit() + + return ok({ + "id": project.id, + "name": project.name, + "path": project.path, + "description": project.description, + "created_at": project.created_at.isoformat(), + "stats": stats + }) + + +@bp.route("/api/projects//files", methods=["GET"]) +def list_project_files(project_id): + """List files in a project directory""" + project = Project.query.get(project_id) + + if not project: + return err(404, "Project not found") + + project_dir = get_project_path(project.id, project.path) + + # Get subdirectory parameter + subdir = request.args.get("path", "") + + try: + target_dir = project_dir / subdir if subdir else project_dir + target_dir = target_dir.resolve() + + # Validate path is within project + target_dir.relative_to(project_dir.resolve()) + except ValueError: + return err(403, "Invalid path: outside project directory") + + if not target_dir.exists(): + return err(404, "Directory not found") + + if not target_dir.is_dir(): + return err(400, "Path is not a directory") + + # List files + files = [] + directories = [] + + try: + for item in target_dir.iterdir(): + # Skip hidden files + if item.name.startswith("."): + continue + + relative_path = item.relative_to(project_dir) + + if item.is_file(): + files.append({ + "name": item.name, + "path": str(relative_path), + "size": item.stat().st_size, + "extension": item.suffix + }) + elif item.is_dir(): + directories.append({ + "name": item.name, + "path": str(relative_path) + }) + except Exception as e: + return err(500, f"Failed to list directory: {str(e)}") + + return ok({ + "project_id": project_id, + "current_path": str(subdir) if subdir else "/", + "files": files, + "directories": directories, + "total_files": len(files), + "total_dirs": len(directories) + }) diff --git a/backend/services/chat.py b/backend/services/chat.py index 872b34d..3bf8d67 100644 --- a/backend/services/chat.py +++ b/backend/services/chat.py @@ -8,7 +8,7 @@ from backend.tools import registry, ToolExecutor from backend.utils.helpers import ( get_or_create_default_user, record_token_usage, - build_glm_messages, + build_messages, ok, err, to_dict, @@ -26,14 +26,23 @@ class ChatService: self.executor = ToolExecutor(registry=registry) - def sync_response(self, conv: Conversation, tools_enabled: bool = True): - """Sync response with tool call support""" + def sync_response(self, conv: Conversation, tools_enabled: bool = True, project_id: str = None): + """Sync response with tool call support + + Args: + conv: Conversation object + tools_enabled: Whether to enable tools + project_id: Project ID for workspace isolation + """ tools = registry.list_all() if tools_enabled else None - messages = build_glm_messages(conv) + messages = build_messages(conv, project_id) # Clear tool call history for new request self.executor.clear_history() + # Build context for tool execution + context = {"project_id": project_id} if project_id else None + all_tool_calls = [] all_tool_results = [] @@ -119,27 +128,35 @@ class ChatService: all_tool_calls.extend(tool_calls) messages.append(message) - tool_results = self.executor.process_tool_calls(tool_calls) + tool_results = self.executor.process_tool_calls(tool_calls, context) all_tool_results.extend(tool_results) messages.extend(tool_results) return err(500, "exceeded maximum tool call iterations") - def stream_response(self, conv: Conversation, tools_enabled: bool = True): + def stream_response(self, conv: Conversation, tools_enabled: bool = True, project_id: str = None): """Stream response with tool call support Uses 'process_step' events to send thinking and tool calls in order, allowing them to be interleaved properly in the frontend. + + Args: + conv: Conversation object + tools_enabled: Whether to enable tools + project_id: Project ID for workspace isolation """ conv_id = conv.id conv_model = conv.model app = current_app._get_current_object() tools = registry.list_all() if tools_enabled else None - initial_messages = build_glm_messages(conv) + initial_messages = build_messages(conv, project_id) # Clear tool call history for new request self.executor.clear_history() + # Build context for tool execution + context = {"project_id": project_id} if project_id else None + def generate(): messages = list(initial_messages) all_tool_calls = [] @@ -232,17 +249,16 @@ class ChatService: step_index += 1 # Execute this single tool call - single_result = self.executor.process_tool_calls([tc]) + single_result = self.executor.process_tool_calls([tc], context) tool_results.extend(single_result) # Send tool result step immediately tr = single_result[0] try: - result_data = json.loads(tr["content"]) - skipped = result_data.get("skipped", False) + result_content = json.loads(tr["content"]) + skipped = result_content.get("skipped", False) except: skipped = False - yield f"event: process_step\ndata: {json.dumps({'index': step_index, 'type': 'tool_result', 'id': tr['tool_call_id'], 'name': tr['name'], 'content': tr['content'], 'skipped': skipped}, ensure_ascii=False)}\n\n" step_index += 1 @@ -330,7 +346,7 @@ class ChatService: ) def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list: - """Build tool calls JSON structure""" + """Build tool calls JSON structure - matches streaming format""" result = [] for i, tc in enumerate(tool_calls): result_content = tool_results[i]["content"] if i < len(tool_results) else None @@ -348,10 +364,14 @@ class ChatService: except: pass + # Keep same structure as streaming format result.append({ "id": tc.get("id", ""), - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], + "type": tc.get("type", "function"), + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + }, "result": result_content, "success": success, "skipped": skipped, diff --git a/backend/tools/builtin/__init__.py b/backend/tools/builtin/__init__.py index 3f826da..f97685d 100644 --- a/backend/tools/builtin/__init__.py +++ b/backend/tools/builtin/__init__.py @@ -1,4 +1,6 @@ """Built-in tools""" +from backend.tools.builtin.code import * from backend.tools.builtin.crawler import * from backend.tools.builtin.data import * from backend.tools.builtin.file_ops import * +from backend.tools.builtin.weather import * diff --git a/backend/tools/builtin/file_ops.py b/backend/tools/builtin/file_ops.py index 798a9b2..af9bcf4 100644 --- a/backend/tools/builtin/file_ops.py +++ b/backend/tools/builtin/file_ops.py @@ -2,41 +2,55 @@ import os import json from pathlib import Path -from typing import Optional +from typing import Optional, Tuple from backend.tools.factory import tool +from backend import db +from backend.models import Project +from backend.utils.workspace import get_project_path, validate_path_in_project -# Base directory for file operations (sandbox) -# Set to None to allow any path, or set a specific directory for security -BASE_DIR = Path(__file__).parent.parent.parent.parent # project root - - -def _resolve_path(path: str) -> Path: - """Resolve path and ensure it's within allowed directory""" - p = Path(path) - if not p.is_absolute(): - p = BASE_DIR / p - p = p.resolve() +def _resolve_path(path: str, project_id: str = None) -> Tuple[Path, Path]: + """ + Resolve path and ensure it's within project directory - # Security check: ensure path is within BASE_DIR - if BASE_DIR: - try: - p.relative_to(BASE_DIR.resolve()) - except ValueError: - raise ValueError(f"Path '{path}' is outside allowed directory") + Args: + path: File path (relative or absolute) + project_id: Project ID for workspace isolation + + Returns: + Tuple of (resolved absolute path, project directory) + + Raises: + ValueError: If project_id is missing or path is outside project + """ + if not project_id: + raise ValueError("project_id is required for file operations") - return p + # Get project from database + project = db.session.get(Project, project_id) + if not project: + raise ValueError(f"Project not found: {project_id}") + + # Get project directory + project_dir = get_project_path(project.id, project.path) + + # Validate and resolve path + return validate_path_in_project(path, project_dir), project_dir @tool( name="file_read", - description="Read content from a file. Use when you need to read file content.", + description="Read content from a file within the project workspace. Use when you need to read file content.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "File path to read (relative to project root or absolute)" + "description": "File path to read (relative to project root or absolute within project)" + }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" }, "encoding": { "type": "string", @@ -44,7 +58,7 @@ def _resolve_path(path: str) -> Path: "default": "utf-8" } }, - "required": ["path"] + "required": ["path", "project_id"] }, category="file" ) @@ -55,46 +69,53 @@ def file_read(arguments: dict) -> dict: Args: arguments: { "path": "file.txt", + "project_id": "project-uuid", "encoding": "utf-8" } Returns: - {"content": "...", "size": 100} + {"success": true, "content": "...", "size": 100} """ try: - path = _resolve_path(arguments["path"]) + path, project_dir = _resolve_path(arguments["path"], arguments.get("project_id")) encoding = arguments.get("encoding", "utf-8") if not path.exists(): - return {"error": f"File not found: {path}"} + return {"success": False, "error": f"File not found: {path}"} if not path.is_file(): - return {"error": f"Path is not a file: {path}"} + return {"success": False, "error": f"Path is not a file: {path}"} content = path.read_text(encoding=encoding) + return { + "success": True, "content": content, "size": len(content), - "path": str(path) + "path": str(path.relative_to(project_dir)) } except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} @tool( name="file_write", - description="Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Use when you need to create or update a file.", + description="Write content to a file within the project workspace. Creates the file if it doesn't exist, overwrites if it does. Use when you need to create or update a file.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "File path to write (relative to project root or absolute)" + "description": "File path to write (relative to project root or absolute within project)" }, "content": { "type": "string", "description": "Content to write to the file" }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" + }, "encoding": { "type": "string", "description": "File encoding, default utf-8", @@ -107,7 +128,7 @@ def file_read(arguments: dict) -> dict: "default": "write" } }, - "required": ["path", "content"] + "required": ["path", "content", "project_id"] }, category="file" ) @@ -119,6 +140,7 @@ def file_write(arguments: dict) -> dict: arguments: { "path": "file.txt", "content": "Hello World", + "project_id": "project-uuid", "encoding": "utf-8", "mode": "write" } @@ -127,7 +149,7 @@ def file_write(arguments: dict) -> dict: {"success": true, "size": 11} """ try: - path = _resolve_path(arguments["path"]) + path, project_dir = _resolve_path(arguments["path"], arguments.get("project_id")) content = arguments["content"] encoding = arguments.get("encoding", "utf-8") mode = arguments.get("mode", "write") @@ -145,25 +167,29 @@ def file_write(arguments: dict) -> dict: return { "success": True, "size": len(content), - "path": str(path), + "path": str(path.relative_to(project_dir)), "mode": mode } except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} @tool( name="file_delete", - description="Delete a file. Use when you need to remove a file.", + description="Delete a file within the project workspace. Use when you need to remove a file.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "File path to delete (relative to project root or absolute)" + "description": "File path to delete (relative to project root or absolute within project)" + }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" } }, - "required": ["path"] + "required": ["path", "project_id"] }, category="file" ) @@ -173,45 +199,51 @@ def file_delete(arguments: dict) -> dict: Args: arguments: { - "path": "file.txt" + "path": "file.txt", + "project_id": "project-uuid" } Returns: {"success": true} """ try: - path = _resolve_path(arguments["path"]) + path, project_dir = _resolve_path(arguments["path"], arguments.get("project_id")) if not path.exists(): - return {"error": f"File not found: {path}"} + return {"success": False, "error": f"File not found: {path}"} if not path.is_file(): - return {"error": f"Path is not a file: {path}"} + return {"success": False, "error": f"Path is not a file: {path}"} + rel_path = str(path.relative_to(project_dir)) path.unlink() - return {"success": True, "path": str(path)} + return {"success": True, "path": rel_path} except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} @tool( name="file_list", - description="List files and directories in a directory. Use when you need to see what files exist.", + description="List files and directories in a directory within the project workspace. Use when you need to see what files exist.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "Directory path to list (relative to project root or absolute)", + "description": "Directory path to list (relative to project root or absolute within project)", "default": "." }, "pattern": { "type": "string", "description": "Glob pattern to filter files, e.g. '*.py'", "default": "*" + }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" } }, - "required": [] + "required": ["project_id"] }, category="file" ) @@ -222,21 +254,22 @@ def file_list(arguments: dict) -> dict: Args: arguments: { "path": ".", - "pattern": "*" + "pattern": "*", + "project_id": "project-uuid" } Returns: - {"files": [...], "directories": [...]} + {"success": true, "files": [...], "directories": [...]} """ try: - path = _resolve_path(arguments.get("path", ".")) + path, project_dir = _resolve_path(arguments.get("path", "."), arguments.get("project_id")) pattern = arguments.get("pattern", "*") if not path.exists(): - return {"error": f"Directory not found: {path}"} + return {"success": False, "error": f"Directory not found: {path}"} if not path.is_dir(): - return {"error": f"Path is not a directory: {path}"} + return {"success": False, "error": f"Path is not a directory: {path}"} files = [] directories = [] @@ -246,37 +279,42 @@ def file_list(arguments: dict) -> dict: files.append({ "name": item.name, "size": item.stat().st_size, - "path": str(item.relative_to(BASE_DIR)) if BASE_DIR else str(item) + "path": str(item.relative_to(project_dir)) }) elif item.is_dir(): directories.append({ "name": item.name, - "path": str(item.relative_to(BASE_DIR)) if BASE_DIR else str(item) + "path": str(item.relative_to(project_dir)) }) return { - "path": str(path), + "success": True, + "path": str(path.relative_to(project_dir)), "files": files, "directories": directories, "total_files": len(files), "total_dirs": len(directories) } except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} @tool( name="file_exists", - description="Check if a file or directory exists. Use when you need to verify file existence.", + description="Check if a file or directory exists within the project workspace. Use when you need to verify file existence.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "Path to check (relative to project root or absolute)" + "description": "Path to check (relative to project root or absolute within project)" + }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" } }, - "required": ["path"] + "required": ["path", "project_id"] }, category="file" ) @@ -286,53 +324,58 @@ def file_exists(arguments: dict) -> dict: Args: arguments: { - "path": "file.txt" + "path": "file.txt", + "project_id": "project-uuid" } Returns: {"exists": true, "type": "file"} """ try: - path = _resolve_path(arguments["path"]) + path, project_dir = _resolve_path(arguments["path"], arguments.get("project_id")) if not path.exists(): - return {"exists": False, "path": str(path)} + return {"exists": False, "path": str(path.relative_to(project_dir))} if path.is_file(): return { "exists": True, "type": "file", - "path": str(path), + "path": str(path.relative_to(project_dir)), "size": path.stat().st_size } elif path.is_dir(): return { "exists": True, "type": "directory", - "path": str(path) + "path": str(path.relative_to(project_dir)) } else: return { "exists": True, "type": "other", - "path": str(path) + "path": str(path.relative_to(project_dir)) } except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} @tool( name="file_mkdir", - description="Create a directory. Creates parent directories if needed. Use when you need to create a folder.", + description="Create a directory within the project workspace. Creates parent directories if needed. Use when you need to create a folder.", parameters={ "type": "object", "properties": { "path": { "type": "string", - "description": "Directory path to create (relative to project root or absolute)" + "description": "Directory path to create (relative to project root or absolute within project)" + }, + "project_id": { + "type": "string", + "description": "Project ID for workspace isolation (required)" } }, - "required": ["path"] + "required": ["path", "project_id"] }, category="file" ) @@ -342,19 +385,23 @@ def file_mkdir(arguments: dict) -> dict: Args: arguments: { - "path": "new/folder" + "path": "new/folder", + "project_id": "project-uuid" } Returns: {"success": true} """ try: - path = _resolve_path(arguments["path"]) + path, project_dir = _resolve_path(arguments["path"], arguments.get("project_id")) + + created = not path.exists() path.mkdir(parents=True, exist_ok=True) + return { "success": True, - "path": str(path), - "created": not path.exists() or path.is_dir() + "path": str(path.relative_to(project_dir)), + "created": created } except Exception as e: - return {"error": str(e)} + return {"success": False, "error": str(e)} diff --git a/backend/tools/executor.py b/backend/tools/executor.py index fc535c9..73b9b2c 100644 --- a/backend/tools/executor.py +++ b/backend/tools/executor.py @@ -70,7 +70,7 @@ class ToolExecutor: Args: tool_calls: Tool call list returned by LLM - context: Optional context info (user_id, etc.) + context: Optional context info (user_id, project_id, etc.) Returns: Tool response message list, can be appended to messages @@ -91,6 +91,12 @@ class ToolExecutor: )) continue + # Inject context into tool arguments + if context: + # For file operation tools, inject project_id automatically + if name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + # Check for duplicate within same batch call_key = f"{name}:{json.dumps(args, sort_keys=True)}" if call_key in seen_calls: diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index c0b9b76..8d481fc 100644 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -1,5 +1,5 @@ """Backend utilities""" -from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_glm_messages +from backend.utils.helpers import ok, err, to_dict, get_or_create_default_user, record_token_usage, build_messages __all__ = [ "ok", @@ -7,5 +7,5 @@ __all__ = [ "to_dict", "get_or_create_default_user", "record_token_usage", - "build_glm_messages", + "build_messages", ] diff --git a/backend/utils/helpers.py b/backend/utils/helpers.py index 53e8ec4..b3362e0 100644 --- a/backend/utils/helpers.py +++ b/backend/utils/helpers.py @@ -98,9 +98,16 @@ def record_token_usage(user_id, model, prompt_tokens, completion_tokens): db.session.commit() -def build_glm_messages(conv): - """Build messages list for GLM API from conversation""" +def build_messages(conv, project_id=None): + """Build messages list for LLM API from conversation + + Args: + conv: Conversation object + project_id: Project ID (used for context injection, backend enforces workspace isolation) + """ msgs = [] + + # System prompt (project_id is handled by backend for security) if conv.system_prompt: msgs.append({"role": "system", "content": conv.system_prompt}) # Query messages directly to avoid detached instance warning diff --git a/backend/utils/workspace.py b/backend/utils/workspace.py new file mode 100644 index 0000000..af44bdd --- /dev/null +++ b/backend/utils/workspace.py @@ -0,0 +1,180 @@ +"""Workspace path validation utilities""" +import os +import shutil +from pathlib import Path +from typing import Optional +from flask import current_app + +from backend import load_config + + +def get_workspace_root() -> Path: + """Get workspace root directory from config""" + cfg = load_config() + workspace_root = cfg.get("workspace_root", "./workspaces") + + # Convert to absolute path + workspace_path = Path(workspace_root) + if not workspace_path.is_absolute(): + # Relative to project root + workspace_path = Path(__file__).parent.parent.parent / workspace_root + + # Create if not exists + workspace_path.mkdir(parents=True, exist_ok=True) + + return workspace_path.resolve() + + +def get_project_path(project_id: str, project_path: str) -> Path: + """ + Get absolute path for a project + + Args: + project_id: Project ID + project_path: Relative path stored in database + + Returns: + Absolute path to project directory + """ + workspace_root = get_workspace_root() + project_dir = workspace_root / project_path + + # Create if not exists + project_dir.mkdir(parents=True, exist_ok=True) + + return project_dir.resolve() + + +def validate_path_in_project(path: str, project_dir: Path) -> Path: + """ + Validate that a path is within the project directory + + Args: + path: Path to validate (can be relative or absolute) + project_dir: Project directory path + + Returns: + Resolved absolute path + + Raises: + ValueError: If path is outside project directory + """ + p = Path(path) + + # If relative, resolve against project directory + if not p.is_absolute(): + p = project_dir / p + + # Resolve to absolute path + p = p.resolve() + + # Security check: ensure path is within project directory + try: + p.relative_to(project_dir.resolve()) + except ValueError: + raise ValueError(f"Path '{path}' is outside project directory") + + return p + + +def create_project_directory(name: str, user_id: int) -> tuple[str, Path]: + """ + Create a new project directory + + Args: + name: Project name + user_id: User ID + + Returns: + Tuple of (relative_path, absolute_path) + """ + workspace_root = get_workspace_root() + + # Create user-specific directory + user_dir = workspace_root / f"user_{user_id}" + user_dir.mkdir(parents=True, exist_ok=True) + + # Create project directory + project_dir = user_dir / name + + # Handle name conflicts + counter = 1 + original_name = name + while project_dir.exists(): + name = f"{original_name}_{counter}" + project_dir = user_dir / name + counter += 1 + + project_dir.mkdir(parents=True, exist_ok=True) + + # Return relative path (from workspace root) and absolute path + relative_path = f"user_{user_id}/{name}" + return relative_path, project_dir.resolve() + + +def delete_project_directory(project_path: str) -> bool: + """ + Delete a project directory + + Args: + project_path: Relative path from workspace root + + Returns: + True if deleted successfully + """ + workspace_root = get_workspace_root() + project_dir = workspace_root / project_path + + if project_dir.exists() and project_dir.is_dir(): + # Verify it's within workspace root (security check) + try: + project_dir.resolve().relative_to(workspace_root.resolve()) + shutil.rmtree(project_dir) + return True + except ValueError: + raise ValueError("Cannot delete directory outside workspace root") + + return False + + +def copy_folder_to_project(source_path: str, project_dir: Path, project_name: str) -> dict: + """ + Copy a folder to project directory (for folder upload) + + Args: + source_path: Source folder path + project_dir: Target project directory + project_name: Project name + + Returns: + Dict with copy statistics + """ + source = Path(source_path) + + if not source.exists(): + raise ValueError(f"Source path does not exist: {source_path}") + + if not source.is_dir(): + raise ValueError(f"Source path is not a directory: {source_path}") + + # Security check: don't copy from sensitive system directories + sensitive_dirs = ["/etc", "/usr", "/bin", "/sbin", "/root", "/home"] + for sensitive in sensitive_dirs: + if str(source.resolve()).startswith(sensitive): + raise ValueError(f"Cannot copy from system directory: {sensitive}") + + # Copy directory + if project_dir.exists(): + shutil.rmtree(project_dir) + + shutil.copytree(source, project_dir) + + # Count files + file_count = sum(1 for _ in project_dir.rglob("*") if _.is_file()) + dir_count = sum(1 for _ in project_dir.rglob("*") if _.is_dir()) + + return { + "files": file_count, + "directories": dir_count, + "size": sum(f.stat().st_size for f in project_dir.rglob("*") if f.is_file()) + } diff --git a/docs/Design.md b/docs/Design.md index 22c53b3..d9a8b59 100644 --- a/docs/Design.md +++ b/docs/Design.md @@ -16,14 +16,14 @@ graph TB end subgraph External[外部服务] - GLM[GLM API] + LLM[LLM API] WEB[Web Resources] end UI -->|REST/SSE| API API --> SVC API --> TOOLS - SVC --> GLM + SVC --> LLM TOOLS --> WEB SVC --> DB TOOLS --> DB @@ -45,6 +45,7 @@ backend/ │ ├── conversations.py # 会话 CRUD │ ├── messages.py # 消息 CRUD + 聊天 │ ├── models.py # 模型列表 +│ ├── projects.py # 项目管理 │ ├── stats.py # Token 统计 │ └── tools.py # 工具列表 │ @@ -63,12 +64,16 @@ backend/ │ ├── crawler.py # 网页搜索、抓取 │ ├── data.py # 计算器、文本、JSON │ ├── weather.py # 天气查询 -│ ├── file_ops.py # 文件操作 +│ ├── file_ops.py # 文件操作(需要 project_id) │ └── code.py # 代码执行 │ -└── utils/ # 辅助函数 - ├── __init__.py - └── helpers.py # 通用函数 +├── utils/ # 辅助函数 +│ ├── __init__.py +│ ├── helpers.py # 通用函数 +│ └── workspace.py # 工作目录工具 +│ +└── migrations/ # 数据库迁移 + └── add_project_support.py ``` --- @@ -87,11 +92,24 @@ classDiagram +String password +String phone +relationship conversations + +relationship projects + } + + class Project { + +String id + +Integer user_id + +String name + +String path + +String description + +DateTime created_at + +DateTime updated_at + +relationship conversations } class Conversation { +String id +Integer user_id + +String project_id +String title +String model +String system_prompt @@ -124,6 +142,8 @@ classDiagram } User "1" --> "*" Conversation : 拥有 + User "1" --> "*" Project : 拥有 + Project "1" --> "*" Conversation : 包含 Conversation "1" --> "*" Message : 包含 User "1" --> "*" TokenUsage : 消耗 ``` @@ -150,8 +170,11 @@ classDiagram "tool_calls": [ { "id": "call_xxx", - "name": "read_file", - "arguments": "{\"path\": \"...\"}", + "type": "function", + "function": { + "name": "file_read", + "arguments": "{\"path\": \"...\"}" + }, "result": "{\"content\": \"...\"}", "success": true, "skipped": false, @@ -171,8 +194,8 @@ classDiagram -GLMClient glm_client -ToolExecutor executor +Integer MAX_ITERATIONS - +sync_response(conv, tools_enabled) Response - +stream_response(conv, tools_enabled) Response + +sync_response(conv, tools_enabled, project_id) Response + +stream_response(conv, tools_enabled, project_id) Response -_build_tool_calls_json(calls, results) list -_message_to_dict(msg) dict -_process_tool_calls_delta(delta, list) list @@ -249,6 +272,86 @@ classDiagram --- +## 工作目录系统 + +### 概述 + +工作目录系统为文件操作工具提供安全隔离,确保所有文件操作都在项目目录内执行。 + +### 核心函数 + +```python +# backend/utils/workspace.py + +def get_workspace_root() -> Path: + """获取工作区根目录""" + +def get_project_path(project_id: str, project_path: str) -> Path: + """获取项目绝对路径""" + +def validate_path_in_project(path: str, project_dir: Path) -> Path: + """验证路径在项目目录内(核心安全函数)""" + +def create_project_directory(name: str, user_id: int) -> tuple: + """创建项目目录""" + +def delete_project_directory(project_path: str) -> bool: + """删除项目目录""" + +def copy_folder_to_project(source_path: str, project_dir: Path, project_name: str) -> dict: + """复制文件夹到项目目录""" +``` + +### 安全机制 + +`validate_path_in_project()` 是核心安全函数: + +```python +def validate_path_in_project(path: str, project_dir: Path) -> Path: + p = Path(path) + + # 相对路径转换为绝对路径 + if not p.is_absolute(): + p = project_dir / p + + p = p.resolve() + + # 安全检查:确保路径在项目目录内 + try: + p.relative_to(project_dir.resolve()) + except ValueError: + raise ValueError(f"Path '{path}' is outside project directory") + + return p +``` + +即使传入恶意路径,后端也会拒绝: +```python +"../../../etc/passwd" # 尝试跳出项目目录 -> ValueError +"/etc/passwd" # 绝对路径攻击 -> ValueError +``` + +### project_id 自动注入 + +工具执行器自动为文件工具注入 `project_id`: + +```python +# backend/tools/executor.py + +def process_tool_calls(self, tool_calls, context=None): + for call in tool_calls: + name = call["function"]["name"] + args = json.loads(call["function"]["arguments"]) + + # 自动注入 project_id + if context and name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + + result = self.registry.execute(name, args) +``` + +--- + ## API 总览 ### 会话管理 @@ -268,6 +371,19 @@ classDiagram | `GET` | `/api/conversations/:id/messages` | 获取消息列表(游标分页) | | `POST` | `/api/conversations/:id/messages` | 发送消息(支持 SSE 流式) | | `DELETE` | `/api/conversations/:id/messages/:mid` | 删除消息 | +| `POST` | `/api/conversations/:id/regenerate/:mid` | 重新生成消息 | + +### 项目管理 + +| 方法 | 路径 | 说明 | +|------|------|------| +| `GET` | `/api/projects` | 获取项目列表 | +| `POST` | `/api/projects` | 创建项目 | +| `GET` | `/api/projects/:id` | 获取项目详情 | +| `PUT` | `/api/projects/:id` | 更新项目 | +| `DELETE` | `/api/projects/:id` | 删除项目 | +| `POST` | `/api/projects/upload` | 上传文件夹作为项目 | +| `GET` | `/api/projects/:id/files` | 列出项目文件 | ### 其他 @@ -292,26 +408,6 @@ classDiagram | `error` | 错误信息 | | `done` | 回复结束,携带 message_id 和 token_count | -### 思考与工具调用交替流程 - -``` -iteration 1: - thinking_start -> 前端清空 streamThinking - thinking (增量) -> 前端累加到 streamThinking - process_step(thinking, "思考内容A") - tool_calls -> 批量通知(兼容) - process_step(tool_call, "file_read") -> 调用工具 - process_step(tool_result, {...}) -> 立即返回结果 - process_step(tool_call, "file_list") -> 下一个工具 - process_step(tool_result, {...}) -> 立即返回结果 - -iteration 2: - thinking_start -> 前端清空 streamThinking - thinking (增量) -> 前端累加到 streamThinking - process_step(thinking, "思考内容B") - done -``` - ### process_step 事件格式 ```json @@ -346,12 +442,25 @@ iteration 2: | `password` | String(255) | 密码(可为空,支持第三方登录) | | `phone` | String(20) | 手机号 | +### Project(项目) + +| 字段 | 类型 | 说明 | +|------|------|------| +| `id` | String(64) | UUID 主键 | +| `user_id` | Integer | 外键关联 User | +| `name` | String(255) | 项目名称(用户内唯一) | +| `path` | String(512) | 相对路径(如 user_1/my_project) | +| `description` | Text | 项目描述 | +| `created_at` | DateTime | 创建时间 | +| `updated_at` | DateTime | 更新时间 | + ### Conversation(会话) | 字段 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `id` | String(64) | UUID | 主键 | | `user_id` | Integer | - | 外键关联 User | +| `project_id` | String(64) | null | 外键关联 Project(可选) | | `title` | String(255) | "" | 会话标题 | | `model` | String(64) | "glm-5" | 模型名称 | | `system_prompt` | Text | "" | 系统提示词 | @@ -440,10 +549,13 @@ GET /api/conversations?limit=20&cursor=conv_abc123 backend_port: 3000 frontend_port: 4000 -# GLM API +# LLM API api_key: your-api-key api_url: https://open.bigmodel.cn/api/paas/v4/chat/completions +# 工作区根目录 +workspace_root: ./workspaces + # 数据库 db_type: mysql # mysql, sqlite, postgresql db_host: localhost @@ -452,4 +564,4 @@ db_user: root db_password: "" db_name: nano_claw db_sqlite_file: app.db # SQLite 时使用 -``` +``` \ No newline at end of file diff --git a/docs/ToolSystemDesign.md b/docs/ToolSystemDesign.md index a990449..eeb0172 100644 --- a/docs/ToolSystemDesign.md +++ b/docs/ToolSystemDesign.md @@ -2,7 +2,7 @@ ## 概述 -NanoClaw 工具调用系统采用简化的工厂模式,支持装饰器注册、缓存优化、重复调用检测等功能。 +NanoClaw 工具调用系统采用简化的工厂模式,支持装饰器注册、缓存优化、重复调用检测、工作目录隔离等功能。 --- @@ -59,858 +59,198 @@ classDiagram --- -## 二、工具定义工厂 +## 二、工具调用格式 -使用工厂函数创建工具,避免复杂的类继承: +### 统一格式 -```mermaid -classDiagram - direction LR +存储和流式传输使用统一格式: - class ToolFactory { - <> - +tool(name, description, parameters, category)$ decorator - +register_tool(name, handler, description, parameters, category)$ void - } +```json +{ + "id": "call_xxx", + "type": "function", + "function": { + "name": "web_search", + "arguments": "{\"query\": \"...\"}" + }, + "result": "{\"success\": true, ...}", + "success": true, + "skipped": false, + "execution_time": 0 +} +``` - class ToolDefinition { - +str name - +str description - +dict parameters - +Callable handler - +str category - } +前端使用 `call.function.name` 获取工具名称。 - ToolFactory ..> ToolDefinition : creates +--- + +## 三、上下文注入 + +### context 参数 + +`process_tool_calls()` 接受 `context` 参数,用于自动注入工具参数: + +```python +# backend/tools/executor.py + +def process_tool_calls( + self, + tool_calls: List[dict], + context: Optional[dict] = None +) -> List[dict]: + """ + Args: + tool_calls: LLM 返回的工具调用列表 + context: 上下文信息,支持: + - project_id: 自动注入到文件工具 + """ + for call in tool_calls: + name = call["function"]["name"] + args = json.loads(call["function"]["arguments"]) + + # 自动注入 project_id 到文件工具 + if context: + if name.startswith("file_") and "project_id" in context: + args["project_id"] = context["project_id"] + + result = self.registry.execute(name, args) +``` + +### 使用示例 + +```python +# backend/services/chat.py + +def stream_response(self, conv, tools_enabled=True, project_id=None): + # 构建上下文 + context = {"project_id": project_id} if project_id else None + + # 处理工具调用时自动注入 + tool_results = self.executor.process_tool_calls(tool_calls, context) ``` --- -## 三、核心类实现 +## 四、文件工具安全设计 -### 3.1 ToolDefinition +### project_id 必需 + +所有文件工具都需要 `project_id` 参数: ```python -from dataclasses import dataclass, field -from typing import Callable, Any - -@dataclass -class ToolDefinition: - """工具定义""" - name: str - description: str - parameters: dict # JSON Schema - handler: Callable[[dict], Any] - category: str = "general" - - def to_openai_format(self) -> dict: - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters - } - } -``` - -### 3.2 ToolResult - -```python -from dataclasses import dataclass -from typing import Any, Optional - -@dataclass -class ToolResult: - """工具执行结果""" - success: bool - data: Any = None - error: Optional[str] = None - - def to_dict(self) -> dict: - return { - "success": self.success, - "data": self.data, - "error": self.error - } - - @classmethod - def ok(cls, data: Any) -> "ToolResult": - return cls(success=True, data=data) - - @classmethod - def fail(cls, error: str) -> "ToolResult": - return cls(success=False, error=error) -``` - -### 3.3 ToolRegistry - -```python -from typing import Dict, List, Optional - -class ToolRegistry: - """工具注册表(单例)""" - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._tools: Dict[str, ToolDefinition] = {} - return cls._instance - - def register(self, tool: ToolDefinition) -> None: - """注册工具""" - self._tools[tool.name] = tool - - def get(self, name: str) -> Optional[ToolDefinition]: - """获取工具定义""" - return self._tools.get(name) - - def list_all(self) -> List[dict]: - """列出所有工具(OpenAI 格式)""" - return [t.to_openai_format() for t in self._tools.values()] - - def list_by_category(self, category: str) -> List[dict]: - """按类别列出工具""" - return [ - t.to_openai_format() - for t in self._tools.values() - if t.category == category - ] - - def execute(self, name: str, arguments: dict) -> dict: - """执行工具""" - tool = self.get(name) - if not tool: - return ToolResult.fail(f"Tool not found: {name}").to_dict() - - try: - result = tool.handler(arguments) - if isinstance(result, ToolResult): - return result.to_dict() - return ToolResult.ok(result).to_dict() - except Exception as e: - return ToolResult.fail(str(e)).to_dict() - - def remove(self, name: str) -> bool: - """移除工具""" - if name in self._tools: - del self._tools[name] - return True - return False - - def has(self, name: str) -> bool: - """检查工具是否存在""" - return name in self._tools - - -# 全局注册表 -registry = ToolRegistry() -``` - -### 3.4 ToolExecutor - -```python -import json -import time -import hashlib -from typing import List, Dict, Optional - -class ToolExecutor: - """工具执行器(支持缓存和重复检测)""" - - def __init__( - self, - registry: ToolRegistry = None, - api_url: str = None, - api_key: str = None, - enable_cache: bool = True, - cache_ttl: int = 300, # 5分钟 - ): - self.registry = registry or ToolRegistry() - self.api_url = api_url - self.api_key = api_key - self.enable_cache = enable_cache - self.cache_ttl = cache_ttl - self._cache: Dict[str, tuple] = {} # key -> (result, timestamp) - self._call_history: List[dict] = [] # 当前会话的调用历史 - - def _execute_with_retry(self, name: str, arguments: dict) -> dict: - """ - 执行工具,不自动重试。 - 成功或失败都直接返回结果,由模型决定下一步操作。 - """ - result = self.registry.execute(name, arguments) - return result - - def _make_cache_key(self, name: str, args: dict) -> str: - """生成缓存键""" - args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) - return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest() - - def _get_cached(self, key: str) -> Optional[dict]: - """获取缓存结果""" - if not self.enable_cache: - return None - if key in self._cache: - result, timestamp = self._cache[key] - if time.time() - timestamp < self.cache_ttl: - return result - del self._cache[key] - return None - - def _set_cache(self, key: str, result: dict) -> None: - """设置缓存""" - if self.enable_cache: - self._cache[key] = (result, time.time()) - - def _check_duplicate_in_history(self, name: str, args: dict) -> Optional[dict]: - """检查历史中是否有相同调用""" - args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) - for record in self._call_history: - if record["name"] == name and record["args_str"] == args_str: - return record["result"] - return None - - def clear_history(self) -> None: - """清空调用历史(新会话开始时调用)""" - self._call_history.clear() - - def process_tool_calls( - self, - tool_calls: List[dict], - context: dict = None - ) -> List[dict]: - """ - 处理工具调用,返回消息列表 - - Args: - tool_calls: LLM 返回的工具调用列表 - context: 可选上下文信息(user_id 等) - - Returns: - 工具响应消息列表,可直接追加到 messages - """ - results = [] - seen_calls = set() # 当前批次内的重复检测 - - for call in tool_calls: - name = call["function"]["name"] - args_str = call["function"]["arguments"] - call_id = call["id"] - - try: - args = json.loads(args_str) if isinstance(args_str, str) else args_str - except json.JSONDecodeError: - results.append(self._create_error_result( - call_id, name, "Invalid JSON arguments" - )) - continue - - # 检查批次内重复 - call_key = f"{name}:{json.dumps(args, sort_keys=True)}" - if call_key in seen_calls: - results.append(self._create_tool_result( - call_id, name, - {"success": True, "data": None, "cached": True, "duplicate": True} - )) - continue - seen_calls.add(call_key) - - # 检查历史重复 - history_result = self._check_duplicate_in_history(name, args) - if history_result is not None: - result = {**history_result, "cached": True} - results.append(self._create_tool_result(call_id, name, result)) - continue - - # 检查缓存 - cache_key = self._make_cache_key(name, args) - cached_result = self._get_cached(cache_key) - if cached_result is not None: - result = {**cached_result, "cached": True} - results.append(self._create_tool_result(call_id, name, result)) - continue - - # 执行工具 - result = self.registry.execute(name, args) - - # 缓存结果 - self._set_cache(cache_key, result) - - # 添加到历史 - self._call_history.append({ - "name": name, - "args_str": json.dumps(args, sort_keys=True, ensure_ascii=False), - "result": result - }) - - results.append(self._create_tool_result(call_id, name, result)) - - return results - - def _create_tool_result( - self, - call_id: str, - name: str, - result: dict, - execution_time: float = 0 - ) -> dict: - """创建工具结果消息""" - result["execution_time"] = execution_time - return { - "role": "tool", - "tool_call_id": call_id, - "name": name, - "content": json.dumps(result, ensure_ascii=False, default=str) - } - - def _create_error_result( - self, - call_id: str, - name: str, - error: str - ) -> dict: - """创建错误结果消息""" - return { - "role": "tool", - "tool_call_id": call_id, - "name": name, - "content": json.dumps({ - "success": False, - "error": error - }, ensure_ascii=False) - } - - def build_request( - self, - messages: List[dict], - model: str = "glm-5", - tools: List[dict] = None, - **kwargs - ) -> dict: - """构建 API 请求体""" - return { - "model": model, - "messages": messages, - "tools": tools or self.registry.list_all(), - "tool_choice": kwargs.get("tool_choice", "auto"), - **{k: v for k, v in kwargs.items() if k not in ["tool_choice"]} - } -``` - ---- - -## 四、工具工厂模式 - -### 4.1 装饰器注册 - -```python -# backend/tools/factory.py - -from typing import Callable -from backend.tools.core import ToolDefinition, registry - - -def tool( - name: str, - description: str, - parameters: dict, - category: str = "general" -) -> Callable: - """ - 工具注册装饰器 - - 用法: - @tool( - name="web_search", - description="搜索互联网获取信息", - parameters={"type": "object", "properties": {...}}, - category="crawler" - ) - def web_search(arguments: dict) -> dict: - ... - """ - def decorator(func: Callable) -> Callable: - tool_def = ToolDefinition( - name=name, - description=description, - parameters=parameters, - handler=func, - category=category - ) - registry.register(tool_def) - return func - return decorator - - -def register_tool( - name: str, - handler: Callable, - description: str, - parameters: dict, - category: str = "general" -) -> None: - """ - 直接注册工具(无需装饰器) - - 用法: - register_tool( - name="my_tool", - handler=my_function, - description="工具描述", - parameters={...} - ) - """ - tool_def = ToolDefinition( - name=name, - description=description, - parameters=parameters, - handler=handler, - category=category +@tool( + name="file_read", + description="Read content from a file within the project workspace.", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "File path"}, + "project_id": {"type": "string", "description": "Project ID (required)"}, + "encoding": {"type": "string", "default": "utf-8"} + }, + "required": ["path", "project_id"] + }, + category="file" +) +def file_read(arguments: dict) -> dict: + path, project_dir = _resolve_path( + arguments["path"], + arguments.get("project_id") ) - registry.register(tool_def) + # ... ``` -### 4.2 使用示例 +### 路径验证 + +`_resolve_path()` 函数强制验证路径在项目内: ```python -# backend/tools/builtin/crawler.py - -from backend.tools.factory import tool -from backend.tools.services import SearchService, FetchService - -# 网页搜索工具 -@tool( - name="web_search", - description="Search the internet for information. Use when you need to find latest news or answer questions that require web search.", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search keywords" - }, - "max_results": { - "type": "integer", - "description": "Number of results to return, default 5", - "default": 5 - } - }, - "required": ["query"] - }, - category="crawler" -) -def web_search(arguments: dict) -> dict: - """Web search tool""" - query = arguments["query"] - max_results = arguments.get("max_results", 5) - service = SearchService() - results = service.search(query, max_results) - return {"results": results} - - -# 页面抓取工具 -@tool( - name="fetch_page", - description="Fetch content from a specific webpage. Use when user needs detailed information from a webpage.", - parameters={ - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "URL of the webpage to fetch" - }, - "extract_type": { - "type": "string", - "description": "Extraction type", - "enum": ["text", "links", "structured"], - "default": "text" - } - }, - "required": ["url"] - }, - category="crawler" -) -def fetch_page(arguments: dict) -> dict: - """Page fetch tool""" - url = arguments["url"] - extract_type = arguments.get("extract_type", "text") - service = FetchService() - result = service.fetch(url, extract_type) - return result - - -# 批量抓取工具 -@tool( - name="crawl_batch", - description="Batch fetch multiple webpages. Use when you need to get content from multiple pages at once.", - parameters={ - "type": "object", - "properties": { - "urls": { - "type": "array", - "items": {"type": "string"}, - "description": "List of URLs to fetch" - }, - "extract_type": { - "type": "string", - "enum": ["text", "links", "structured"], - "default": "text" - } - }, - "required": ["urls"] - }, - category="crawler" -) -def crawl_batch(arguments: dict) -> dict: - """Batch fetch tool""" - urls = arguments["urls"] - extract_type = arguments.get("extract_type", "text") - - if len(urls) > 10: - return {"error": "Maximum 10 pages can be fetched at once"} - - service = FetchService() - results = service.fetch_batch(urls, extract_type) - return {"results": results, "total": len(results)} +def _resolve_path(path: str, project_id: str = None) -> Tuple[Path, Path]: + if not project_id: + raise ValueError("project_id is required for file operations") + + project = db.session.get(Project, project_id) + if not project: + raise ValueError(f"Project not found: {project_id}") + + project_dir = get_project_path(project.id, project.path) + + # 核心安全验证 + return validate_path_in_project(path, project_dir), project_dir ``` ---- - -## 五、辅助服务类 - -工具依赖的服务保持独立,不与工具类耦合: +### 安全隔离示例 ```python -# backend/tools/services.py +# 即使模型尝试访问敏感文件 +file_read({"path": "../../../etc/passwd", "project_id": "xxx"}) +# -> ValueError: Path is outside project directory -from typing import List, Dict -from ddgs import DDGS -import re +file_read({"path": "/etc/passwd", "project_id": "xxx"}) +# -> ValueError: Path is outside project directory - -class SearchService: - """搜索服务""" - - def __init__(self, engine: str = "duckduckgo"): - self.engine = engine - - def search( - self, - query: str, - max_results: int = 5, - region: str = "cn-zh" - ) -> List[dict]: - """执行搜索""" - if self.engine == "duckduckgo": - return self._search_duckduckgo(query, max_results, region) - else: - raise ValueError(f"Unsupported search engine: {self.engine}") - - def _search_duckduckgo( - self, - query: str, - max_results: int, - region: str - ) -> List[dict]: - """DuckDuckGo 搜索""" - with DDGS() as ddgs: - results = list(ddgs.text( - query, - max_results=max_results, - region=region - )) - - return [ - { - "title": r.get("title", ""), - "url": r.get("href", ""), - "snippet": r.get("body", "") - } - for r in results - ] - - -class FetchService: - """页面抓取服务""" - - def __init__(self, timeout: float = 30.0, user_agent: str = None): - self.timeout = timeout - self.user_agent = user_agent or ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/120.0.0.0 Safari/537.36" - ) - - def fetch(self, url: str, extract_type: str = "text") -> dict: - """抓取单个页面""" - import httpx - - try: - resp = httpx.get( - url, - timeout=self.timeout, - follow_redirects=True, - headers={"User-Agent": self.user_agent} - ) - resp.raise_for_status() - except Exception as e: - return {"error": str(e), "url": url} - - html = resp.text - extractor = ContentExtractor(html) - - if extract_type == "text": - return { - "url": url, - "text": extractor.extract_text() - } - elif extract_type == "links": - return { - "url": url, - "links": extractor.extract_links() - } - else: - return extractor.extract_structured(url) - - def fetch_batch( - self, - urls: List[str], - extract_type: str = "text", - max_concurrent: int = 5 - ) -> List[dict]: - """批量抓取页面""" - results = [] - for url in urls: - results.append(self.fetch(url, extract_type)) - return results - - -class ContentExtractor: - """内容提取器""" - - def __init__(self, html: str): - self.html = html - self._soup = None - - @property - def soup(self): - if self._soup is None: - try: - from bs4 import BeautifulSoup - self._soup = BeautifulSoup(self.html, "html.parser") - except ImportError: - raise ImportError("Please install beautifulsoup4: pip install beautifulsoup4") - return self._soup - - def extract_text(self) -> str: - """提取纯文本""" - # 移除脚本和样式 - for tag in self.soup(["script", "style", "nav", "footer", "header"]): - tag.decompose() - - text = self.soup.get_text(separator="\n", strip=True) - # 清理多余空白 - text = re.sub(r"\n{3,}", "\n\n", text) - return text - - def extract_links(self) -> List[dict]: - """提取链接""" - links = [] - for a in self.soup.find_all("a", href=True): - text = a.get_text(strip=True) - href = a["href"] - if text and href and not href.startswith(("#", "javascript:")): - links.append({"text": text, "href": href}) - return links[:50] # 限制数量 - - def extract_structured(self, url: str = "") -> dict: - """提取结构化内容""" - soup = self.soup - - # 提取标题 - title = "" - if soup.title: - title = soup.title.string or "" - - # 提取 meta 描述 - description = "" - meta_desc = soup.find("meta", attrs={"name": "description"}) - if meta_desc: - description = meta_desc.get("content", "") - - return { - "url": url, - "title": title.strip(), - "description": description.strip(), - "text": self.extract_text()[:5000], # 限制长度 - "links": self.extract_links()[:20] - } - - -class CalculatorService: - """安全计算服务""" - - ALLOWED_OPS = { - "add", "sub", "mul", "truediv", "floordiv", - "mod", "pow", "neg", "abs" - } - - def evaluate(self, expression: str) -> dict: - """安全计算数学表达式""" - import ast - import operator - - ops = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - ast.USub: operator.neg, - ast.UAdd: operator.pos, - } - - try: - # 解析表达式 - node = ast.parse(expression, mode="eval") - - # 验证节点类型 - for child in ast.walk(node): - if isinstance(child, ast.Call): - return {"error": "Function calls not allowed"} - if isinstance(child, ast.Name): - return {"error": "Variable names not allowed"} - - # 安全执行 - result = eval( - compile(node, "", "eval"), - {"__builtins__": {}}, - {} - ) - - return {"result": result} - - except Exception as e: - return {"error": f"Calculation error: {str(e)}"} +# 只有项目内路径才被允许 +file_read({"path": "src/main.py", "project_id": "xxx"}) +# -> 成功读取 ``` --- -## 六、工具初始化 +## 五、工具清单 -```python -# backend/tools/__init__.py +### 5.1 爬虫工具 (crawler) -""" -NanoClaw Tool System +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `web_search` | 搜索互联网获取信息 | `query`: 搜索关键词
`max_results`: 结果数量(默认 5) | +| `fetch_page` | 抓取单个网页内容 | `url`: 网页 URL
`extract_type`: 提取类型(text/links/structured) | +| `crawl_batch` | 批量抓取多个网页(最多 10 个) | `urls`: URL 列表
`extract_type`: 提取类型 | -Usage: - from backend.tools import registry, ToolExecutor, tool - from backend.tools import init_tools +### 5.2 数据处理工具 (data) - # 初始化内置工具 - init_tools() +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `calculator` | 执行数学计算 | `expression`: 数学表达式 | +| `text_process` | 文本处理 | `text`: 文本内容
`operation`: 操作类型 | +| `json_process` | JSON 处理 | `json_string`: JSON 字符串
`operation`: 操作类型 | - # 列出所有工具 - tools = registry.list_all() +### 5.3 代码执行 (code) - # 执行工具 - result = registry.execute("web_search", {"query": "Python"}) -""" +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `execute_python` | 在沙箱环境中执行 Python 代码 | `code`: Python 代码 | -from backend.tools.core import ToolDefinition, ToolResult, ToolRegistry, registry -from backend.tools.factory import tool, register_tool -from backend.tools.executor import ToolExecutor +安全措施: +- 白名单模块限制 +- 危险内置函数禁止 +- 10 秒超时限制 +- 无文件系统访问 +- 无网络访问 +### 5.4 文件操作工具 (file) -def init_tools() -> None: - """ - 初始化所有内置工具 +**所有文件工具需要 `project_id` 参数** - 导入 builtin 模块会自动注册所有装饰器定义的工具 - """ - from backend.tools.builtin import crawler, data, weather, file_ops # noqa: F401 +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `file_read` | 读取文件内容 | `path`, `project_id`, `encoding` | +| `file_write` | 写入文件 | `path`, `content`, `project_id`, `mode` | +| `file_delete` | 删除文件 | `path`, `project_id` | +| `file_list` | 列出目录内容 | `path`, `pattern`, `project_id` | +| `file_exists` | 检查文件是否存在 | `path`, `project_id` | +| `file_mkdir` | 创建目录 | `path`, `project_id` | +### 5.5 天气工具 (weather) -# 公开 API 导出 -__all__ = [ - # 核心类 - "ToolDefinition", - "ToolResult", - "ToolRegistry", - "ToolExecutor", - # 实例 - "registry", - # 工厂函数 - "tool", - "register_tool", - # 初始化 - "init_tools", -] -``` +| 工具名称 | 描述 | 参数 | +|---------|------|------| +| `get_weather` | 查询天气信息(模拟) | `city`: 城市名称 | --- -## 七、工具清单 +## 六、核心特性 -### 7.1 爬虫工具 (crawler) - -| 工具名称 | 描述 | 参数 | -| --------------- | --------------------------- | --------------------------------------- | -| `web_search` | 搜索互联网获取信息 | `query`: 搜索关键词
`max_results`: 结果数量(默认 5) | -| `fetch_page` | 抓取单个网页内容 | `url`: 网页 URL
`extract_type`: 提取类型(text/links/structured) | -| `crawl_batch` | 批量抓取多个网页(最多 10 个) | `urls`: URL 列表
`extract_type`: 提取类型 | - -### 7.2 数据处理工具 (data) - -| 工具名称 | 描述 | 参数 | -| --------------- | --------------------------- | --------------------------------------- | -| `calculator` | 执行数学计算(支持加减乘除、幂、模等) | `expression`: 数学表达式 | -| `text_process` | 文本处理(计数、格式转换等) | `text`: 文本内容
`operation`: 操作类型(count/lines/words/upper/lower/reverse) | -| `json_process` | JSON 处理(解析、格式化、提取、验证) | `json_string`: JSON 字符串
`operation`: 操作类型(parse/format/keys/validate) | - -### 7.3 天气工具 (weather) - -| 工具名称 | 描述 | 参数 | -| --------------- | --------------------------- | --------------------------------------- | -| `get_weather` | 查询指定城市的天气信息(模拟数据) | `city`: 城市名称(如:北京、上海、广州) | - -### 7.4 文件操作工具 (file) - -| 工具名称 | 描述 | 参数 | -| --------------- | --------------------------- | --------------------------------------- | -| `file_read` | 读取文件内容 | `path`: 文件路径
`encoding`: 编码(默认 utf-8) | -| `file_write` | 写入文件(支持覆盖和追加) | `path`: 文件路径
`content`: 内容
`mode`: 写入模式(write/append) | -| `file_delete` | 删除文件 | `path`: 文件路径 | -| `file_list` | 列出目录内容 | `path`: 目录路径(默认 .)
`pattern`: 文件模式(默认 *) | -| `file_exists` | 检查文件或目录是否存在 | `path`: 路径 | -| `file_mkdir` | 创建目录(自动创建父目录) | `path`: 目录路径 | - -**安全说明**:文件操作工具限制在项目根目录内,防止越权访问。 - ---- - -## 八、与旧设计对比 - -| 方面 | 旧设计 | 新设计 | -| --------- | ----------------- | ----------------- | -| 类数量 | 30+ | ~10 | -| 工具定义 | 继承 BaseTool | 装饰器 + 函数 | -| 中间抽象层 | 5个(CrawlerTool 等) | 无 | -| 扩展方式 | 创建子类 | 写函数 + 装饰器 | -| 缓存机制 | 无 | 支持结果缓存(TTL 可配置) | -| 重复检测 | 无 | 支持会话内重复调用检测 | -| 代码量 | 多 | 少 | - ---- - -## 九、核心特性 - -### 9.1 装饰器注册 +### 6.1 装饰器注册 简化工具定义,只需一个装饰器: @@ -918,48 +258,100 @@ __all__ = [ @tool( name="my_tool", description="工具描述", - parameters={...}, + parameters={ + "type": "object", + "properties": { + "param1": {"type": "string", "description": "参数1"} + }, + "required": ["param1"] + }, category="custom" ) def my_tool(arguments: dict) -> dict: - # 工具实现 return {"result": "ok"} ``` -### 9.2 智能缓存 +### 6.2 智能缓存 - **结果缓存**:相同参数的工具调用结果会被缓存(默认 5 分钟) - **可配置 TTL**:通过 `cache_ttl` 参数设置缓存过期时间 - **可禁用**:通过 `enable_cache=False` 关闭缓存 -### 9.3 重复检测 +### 6.3 重复检测 - **批次内去重**:同一批次中相同工具+参数的调用会被跳过 - **历史去重**:同一会话内已调用过的工具会直接返回缓存结果 - **自动清理**:新会话开始时调用 `clear_history()` 清理历史 -### 9.4 无自动重试 +### 6.4 无自动重试 - **直接返回结果**:工具执行成功或失败都直接返回,不自动重试 -- **模型决策**:失败时返回错误信息,由模型决定是否重试或尝试其他工具 -- **灵活性**:模型可以根据错误类型选择不同的解决策略 +- **模型决策**:失败时返回错误信息,由模型决定是否重试 -### 9.5 安全设计 +### 6.5 安全设计 -- **计算器安全**:禁止函数调用和变量名,只支持数学运算 -- **文件沙箱**:文件操作限制在项目根目录内,防止越权访问 -- **错误处理**:所有工具执行都有 try-catch,不会因工具错误导致系统崩溃 +- **文件沙箱**:所有文件操作限制在项目目录内 +- **代码沙箱**:Python 执行限制模块和函数 +- **错误处理**:所有工具执行都有 try-catch --- -## 十、总结 +## 七、工具初始化 -简化后的设计特点: +```python +# backend/tools/__init__.py -1. **核心类**:`ToolDefinition`、`ToolRegistry`、`ToolExecutor`、`ToolResult` -2. **工厂模式**:使用 `@tool` 装饰器注册工具 -3. **服务分离**:工具依赖的服务独立,不与工具类耦合 -4. **性能优化**:支持缓存和重复检测,减少重复计算和网络请求 -5. **智能决策**:工具执行失败时不自动重试,由模型决定下一步操作 -6. **易于扩展**:新增工具只需写一个函数并加装饰器 -7. **安全可靠**:文件沙箱、安全计算、完善的错误处理 +def init_tools() -> None: + """初始化所有内置工具""" + from backend.tools.builtin import ( + code, crawler, data, weather, file_ops + ) +``` + +--- + +## 八、扩展新工具 + +### 添加新工具 + +1. 在 `backend/tools/builtin/` 下创建或编辑文件 +2. 使用 `@tool` 装饰器定义工具 +3. 在 `backend/tools/builtin/__init__.py` 中导入 + +### 示例:添加数据库查询工具 + +```python +# backend/tools/builtin/database.py + +from backend.tools.factory import tool + +@tool( + name="db_query", + description="Execute a read-only database query", + parameters={ + "type": "object", + "properties": { + "sql": { + "type": "string", + "description": "SELECT query (read-only)" + }, + "project_id": { + "type": "string", + "description": "Project ID for isolation" + } + }, + "required": ["sql", "project_id"] + }, + category="database" +) +def db_query(arguments: dict) -> dict: + sql = arguments["sql"] + project_id = arguments["project_id"] + + # 安全检查:只允许 SELECT + if not sql.strip().upper().startswith("SELECT"): + return {"success": False, "error": "Only SELECT queries allowed"} + + # 执行查询... + return {"success": True, "rows": [...]} +``` diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 4e1e6d9..35f21f8 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -5,10 +5,12 @@ :current-id="currentConvId" :loading="loadingConvs" :has-more="hasMoreConvs" + :current-project="currentProject" @select="selectConversation" @create="createConversation" @delete="deleteConversation" @load-more="loadMoreConversations" + @select-project="selectProject" /> conversations.value.find(c => c.id === currentConvId.value) || null @@ -211,7 +214,7 @@ async function sendMessage(data) { streamToolCalls.value = [] streamProcessSteps.value = [] - currentStreamPromise = messageApi.send(convId, { text, attachments }, { + currentStreamPromise = messageApi.send(convId, { text, attachments, projectId: currentProject.value?.id }, { stream: true, toolsEnabled: toolsEnabled.value, onThinkingStart() { @@ -383,6 +386,7 @@ async function regenerateMessage(msgId) { currentStreamPromise = messageApi.regenerate(convId, msgId, { toolsEnabled: toolsEnabled.value, + projectId: currentProject.value?.id, onThinkingStart() { if (currentConvId.value === convId) { streamThinking.value = '' @@ -493,6 +497,11 @@ function updateToolsEnabled(val) { localStorage.setItem('tools_enabled', String(val)) } +// -- Select project -- +function selectProject(project) { + currentProject.value = project +} + // -- Init -- onMounted(() => { loadConversations() diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js index 7c0fad5..7e7eee4 100644 --- a/frontend/src/api/index.js +++ b/frontend/src/api/index.js @@ -99,7 +99,7 @@ export const messageApi = { if (!stream) { return request(`/conversations/${convId}/messages`, { method: 'POST', - body: { text: data.text, attachments: data.attachments, stream: false, tools_enabled: toolsEnabled }, + body: { text: data.text, attachments: data.attachments, stream: false, tools_enabled: toolsEnabled, project_id: data.projectId }, }) } @@ -110,7 +110,7 @@ export const messageApi = { const res = await fetch(`${BASE}/conversations/${convId}/messages`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ text: data.text, attachments: data.attachments, stream: true, tools_enabled: toolsEnabled }), + body: JSON.stringify({ text: data.text, attachments: data.attachments, stream: true, tools_enabled: toolsEnabled, project_id: data.projectId }), signal: controller.signal, }) @@ -173,7 +173,7 @@ export const messageApi = { return request(`/conversations/${convId}/messages/${msgId}`, { method: 'DELETE' }) }, - regenerate(convId, msgId, { toolsEnabled = true, onThinkingStart, onThinking, onMessage, onToolCalls, onToolResult, onProcessStep, onDone, onError } = {}) { + regenerate(convId, msgId, { toolsEnabled = true, projectId, onThinkingStart, onThinking, onMessage, onToolCalls, onToolResult, onProcessStep, onDone, onError } = {}) { const controller = new AbortController() const promise = (async () => { @@ -181,7 +181,7 @@ export const messageApi = { const res = await fetch(`${BASE}/conversations/${convId}/regenerate/${msgId}`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ tools_enabled: toolsEnabled }), + body: JSON.stringify({ tools_enabled: toolsEnabled, project_id: projectId }), signal: controller.signal, }) @@ -240,3 +240,43 @@ export const messageApi = { return promise }, } + +export const projectApi = { + list(userId) { + return request(`/projects?user_id=${userId}`) + }, + + create(data) { + return request('/projects', { + method: 'POST', + body: data, + }) + }, + + get(projectId) { + return request(`/projects/${projectId}`) + }, + + update(projectId, data) { + return request(`/projects/${projectId}`, { + method: 'PUT', + body: data, + }) + }, + + delete(projectId) { + return request(`/projects/${projectId}`, { method: 'DELETE' }) + }, + + uploadFolder(data) { + return request('/projects/upload', { + method: 'POST', + body: data, + }) + }, + + listFiles(projectId, path = '') { + const params = path ? `?path=${encodeURIComponent(path)}` : '' + return request(`/projects/${projectId}/files${params}`) + }, +} diff --git a/frontend/src/components/ProcessBlock.vue b/frontend/src/components/ProcessBlock.vue index d3943bb..31d0d74 100644 --- a/frontend/src/components/ProcessBlock.vue +++ b/frontend/src/components/ProcessBlock.vue @@ -175,10 +175,12 @@ const processItems = computed(() => { if (props.toolCalls && props.toolCalls.length > 0) { props.toolCalls.forEach((call, i) => { + const toolName = call.function?.name || '未知工具' + items.push({ type: 'tool_call', - label: `调用工具: ${call.function?.name || '未知工具'}`, - toolName: call.function?.name || '未知工具', + label: `调用工具: ${toolName}`, + toolName: toolName, arguments: formatArgs(call.function?.arguments), id: call.id, index: idx, @@ -191,7 +193,7 @@ const processItems = computed(() => { const resultSummary = getResultSummary(call.result) items.push({ type: 'tool_result', - label: `工具返回: ${call.function?.name || '未知工具'}`, + label: `工具返回: ${toolName}`, content: formatResult(call.result), summary: resultSummary.text, isSuccess: resultSummary.success, @@ -204,7 +206,7 @@ const processItems = computed(() => { } else if (props.streaming) { // 工具正在执行中 items[items.length - 1].loading = true - items[items.length - 1].label = `执行工具: ${call.function?.name || '未知工具'}` + items[items.length - 1].label = `执行工具: ${toolName}` } }) } diff --git a/frontend/src/components/ProjectManager.vue b/frontend/src/components/ProjectManager.vue new file mode 100644 index 0000000..bba4bc9 --- /dev/null +++ b/frontend/src/components/ProjectManager.vue @@ -0,0 +1,518 @@ + + + + + diff --git a/frontend/src/components/Sidebar.vue b/frontend/src/components/Sidebar.vue index 8235f23..de8fdf4 100644 --- a/frontend/src/components/Sidebar.vue +++ b/frontend/src/components/Sidebar.vue @@ -1,5 +1,39 @@