diff --git a/backend/__init__.py b/backend/__init__.py index 87bf0a5..f603d89 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -2,6 +2,7 @@ import os import yaml from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_cors import CORS from pathlib import Path # Initialize db BEFORE importing models/routes that depend on it @@ -25,6 +26,9 @@ def create_app(): ) app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False + # Enable CORS for all routes + CORS(app) + db.init_app(app) # Import after db is initialized diff --git a/backend/models.py b/backend/models.py index f94c8cc..179f45f 100644 --- a/backend/models.py +++ b/backend/models.py @@ -39,10 +39,16 @@ class Message(db.Model): id = db.Column(db.String(64), primary_key=True) conversation_id = db.Column(db.String(64), db.ForeignKey("conversations.id"), nullable=False) - role = db.Column(db.String(16), nullable=False) + role = db.Column(db.String(16), nullable=False) # user, assistant, system, tool content = db.Column(db.Text, default="") token_count = db.Column(db.Integer, default=0) thinking_content = db.Column(db.Text, default="") + + # Tool call support + tool_calls = db.Column(db.Text) # JSON string: tool call requests (assistant messages) + tool_call_id = db.Column(db.String(64)) # Tool call ID (tool messages) + name = db.Column(db.String(64)) # Tool name (tool messages) + created_at = db.Column(db.DateTime, default=datetime.utcnow) diff --git a/backend/routes.py b/backend/routes.py index 596ef2d..08f32fd 100644 --- a/backend/routes.py +++ b/backend/routes.py @@ -47,6 +47,17 @@ def to_dict(inst, **extra): for k in ("created_at", "updated_at"): if k in d and hasattr(d[k], "strftime"): d[k] = d[k].strftime("%Y-%m-%dT%H:%M:%SZ") + + # Parse tool_calls JSON if present + if "tool_calls" in d and d["tool_calls"]: + try: + d["tool_calls"] = json.loads(d["tool_calls"]) + except: + pass + + # Filter out None values for cleaner API response + d = {k: v for k, v in d.items() if v is not 
None} + d.update(extra) return d @@ -296,10 +307,12 @@ def message_list(conv_id): db.session.add(user_msg) db.session.commit() - if d.get("stream", False): - return _stream_response(conv) + tools_enabled = d.get("tools_enabled", True) - return _sync_response(conv) + if d.get("stream", False): + return _stream_response(conv, tools_enabled) + + return _sync_response(conv, tools_enabled) @bp.route("/api/conversations//messages/", methods=["DELETE"]) @@ -339,16 +352,20 @@ def _call_glm(conv, stream=False, tools=None, messages=None): ) -def _sync_response(conv): +def _sync_response(conv, tools_enabled=True): """Sync response with tool call support""" executor = ToolExecutor(registry=registry) - tools = registry.list_all() + tools = registry.list_all() if tools_enabled else None messages = build_glm_messages(conv) max_iterations = 5 # Max tool call iterations + + # Collect all tool calls and results + all_tool_calls = [] + all_tool_results = [] for _ in range(max_iterations): try: - resp = _call_glm(conv, tools=tools if tools else None, messages=messages) + resp = _call_glm(conv, tools=tools, messages=messages) resp.raise_for_status() result = resp.json() except Exception as e: @@ -363,11 +380,23 @@ def _sync_response(conv): prompt_tokens = usage.get("prompt_tokens", 0) completion_tokens = usage.get("completion_tokens", 0) + # Merge tool results into tool_calls + merged_tool_calls = [] + for i, tc in enumerate(all_tool_calls): + merged_tc = dict(tc) + if i < len(all_tool_results): + merged_tc["result"] = all_tool_results[i]["content"] + merged_tool_calls.append(merged_tc) + + # Save assistant message with all tool calls (including results) msg = Message( - id=str(uuid.uuid4()), conversation_id=conv.id, role="assistant", + id=str(uuid.uuid4()), + conversation_id=conv.id, + role="assistant", content=message.get("content", ""), token_count=completion_tokens, thinking_content=message.get("reasoning_content", ""), + tool_calls=json.dumps(merged_tool_calls) if 
merged_tool_calls else None ) db.session.add(msg) db.session.commit() @@ -386,39 +415,38 @@ def _sync_response(conv): # Process tool calls tool_calls = message["tool_calls"] + all_tool_calls.extend(tool_calls) messages.append(message) # Execute tools and add results tool_results = executor.process_tool_calls(tool_calls) + all_tool_results.extend(tool_results) messages.extend(tool_results) - # Save tool call records to database - for i, call in enumerate(tool_calls): - tool_msg = Message( - id=str(uuid.uuid4()), - conversation_id=conv.id, - role="tool", - content=tool_results[i]["content"] - ) - db.session.add(tool_msg) - db.session.commit() - return err(500, "exceeded maximum tool call iterations") -def _stream_response(conv): +def _stream_response(conv, tools_enabled=True): """Stream response with tool call support""" conv_id = conv.id conv_model = conv.model app = current_app._get_current_object() executor = ToolExecutor(registry=registry) - tools = registry.list_all() + tools = registry.list_all() if tools_enabled else None # Build messages BEFORE entering generator (in request context) initial_messages = build_glm_messages(conv) def generate(): messages = list(initial_messages) # Copy to avoid mutation max_iterations = 5 + + # Collect all tool calls and results + all_tool_calls = [] + all_tool_results = [] + total_content = "" + total_thinking = "" + total_tokens = 0 + total_prompt_tokens = 0 for iteration in range(max_iterations): full_content = "" @@ -432,7 +460,7 @@ def _stream_response(conv): try: with app.app_context(): active_conv = db.session.get(Conversation, conv_id) - resp = _call_glm(active_conv, stream=True, tools=tools if tools else None, messages=messages) + resp = _call_glm(active_conv, stream=True, tools=tools, messages=messages) resp.raise_for_status() for line in resp.iter_lines(): @@ -492,6 +520,9 @@ def _stream_response(conv): # If tool calls exist, execute and continue loop if tool_calls_list: + # Collect tool calls + 
all_tool_calls.extend(tool_calls_list) + # Send tool call info yield f"event: tool_calls\ndata: {json.dumps({'calls': tool_calls_list}, ensure_ascii=False)}\n\n" @@ -503,6 +534,9 @@ def _stream_response(conv): "tool_calls": tool_calls_list }) messages.extend(tool_results) + + # Collect tool results + all_tool_results.extend(tool_results) # Send tool results for tr in tool_results: @@ -510,19 +544,38 @@ def _stream_response(conv): continue - # No tool calls, finish + # No tool calls, finish - save everything + total_content = full_content + total_thinking = full_thinking + total_tokens = token_count + total_prompt_tokens = prompt_tokens + + # Merge tool results into tool_calls + merged_tool_calls = [] + for i, tc in enumerate(all_tool_calls): + merged_tc = dict(tc) + if i < len(all_tool_results): + merged_tc["result"] = all_tool_results[i]["content"] + merged_tool_calls.append(merged_tc) + with app.app_context(): + # Save assistant message with all tool calls (including results) msg = Message( - id=msg_id, conversation_id=conv_id, role="assistant", - content=full_content, token_count=token_count, thinking_content=full_thinking, + id=msg_id, + conversation_id=conv_id, + role="assistant", + content=total_content, + token_count=total_tokens, + thinking_content=total_thinking, + tool_calls=json.dumps(merged_tool_calls) if merged_tool_calls else None ) db.session.add(msg) db.session.commit() user = get_or_create_default_user() - record_token_usage(user.id, conv_model, prompt_tokens, token_count) + record_token_usage(user.id, conv_model, total_prompt_tokens, total_tokens) - yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': token_count})}\n\n" + yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_tokens})}\n\n" return yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n" diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py index 
a49b612..bc3a71b 100644 --- a/backend/tools/__init__.py +++ b/backend/tools/__init__.py @@ -26,7 +26,7 @@ def init_tools() -> None: Importing builtin module automatically registers all decorator-defined tools """ - from .builtin import crawler, data # noqa: F401 + from .builtin import crawler, data, weather # noqa: F401 # Public API exports diff --git a/backend/tools/builtin/weather.py b/backend/tools/builtin/weather.py new file mode 100644 index 0000000..f70ad14 --- /dev/null +++ b/backend/tools/builtin/weather.py @@ -0,0 +1,57 @@ +"""Weather related tools""" +from ..factory import tool + + +@tool( + name="get_weather", + description="Get weather information for a specified city. Use when user asks about weather.", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name, e.g.: 北京, 上海, 广州" + } + }, + "required": ["city"] + }, + category="weather" +) +def get_weather(arguments: dict) -> dict: + """ + Weather query tool (simulated) + + Args: + arguments: { + "city": "北京" + } + + Returns: + { + "city": "北京", + "temperature": 25, + "humidity": 60, + "description": "晴天" + } + """ + city = arguments["city"] + + # 模拟天气数据 + weather_data = { + "北京": {"temperature": 25, "humidity": 60, "description": "晴天"}, + "上海": {"temperature": 28, "humidity": 75, "description": "多云"}, + "广州": {"temperature": 32, "humidity": 85, "description": "雷阵雨"}, + "深圳": {"temperature": 30, "humidity": 80, "description": "阴天"}, + } + + data = weather_data.get(city, { + "temperature": 22, + "humidity": 65, + "description": "晴天" + }) + + return { + "city": city, + **data, + "query_time": "2026-03-24 12:00:00" + } diff --git a/docs/Design.md b/docs/Design.md index 86c1c70..ff6dac7 100644 --- a/docs/Design.md +++ b/docs/Design.md @@ -20,6 +20,19 @@ | `POST` | `/api/conversations/:id/messages` | 发送消息(对话补全,支持 `stream` 流式) | | `DELETE` | `/api/conversations/:id/messages/:message_id` | 删除消息 | +### 模型与工具 + +| 方法 | 路径 | 说明 | +| ------ | ------------- | 
-------- | +| `GET` | `/api/models` | 获取模型列表 | +| `GET` | `/api/tools` | 获取工具列表 | + +### 统计信息 + +| 方法 | 路径 | 说明 | +| ------ | -------------------- | ---------------- | +| `GET` | `/api/stats/tokens` | 获取 Token 使用统计 | + --- ## API 接口 @@ -219,6 +232,8 @@ POST /api/conversations/:id/messages **流式响应 (stream=true):** +**普通回复示例:** + ``` HTTP/1.1 200 OK Content-Type: text/event-stream @@ -239,6 +254,37 @@ event: done data: {"message_id": "msg_003", "token_count": 200} ``` +**工具调用示例:** + +``` +HTTP/1.1 200 OK +Content-Type: text/event-stream + +event: thinking +data: {"content": "用户想知道北京天气..."} + +event: tool_calls +data: {"calls": [{"id": "call_001", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\": \"北京\"}"}}]} + +event: tool_result +data: {"name": "get_weather", "content": "{\"temperature\": 25, \"humidity\": 60, \"description\": \"晴天\"}"} + +event: message +data: {"content": "北京"} + +event: message +data: {"content": "今天天气晴朗,"} + +event: message +data: {"content": "温度25°C,"} + +event: message +data: {"content": "湿度60%"} + +event: done +data: {"message_id": "msg_003", "token_count": 150} +``` + **非流式响应 (stream=false):** ```json @@ -280,18 +326,162 @@ DELETE /api/conversations/:id/messages/:message_id --- -### 3. SSE 事件说明 +### 3. 模型与工具 -| 事件 | 说明 | -| ---------- | ------------------------------- | -| `thinking` | 思维链增量内容(启用时) | -| `message` | 回复内容的增量片段 | -| `error` | 错误信息 | -| `done` | 回复结束,携带完整 message_id 和 token 统计 | +#### 获取模型列表 + +``` +GET /api/models +``` + +**响应:** + +```json +{ + "code": 0, + "data": ["glm-5", "glm-4", "glm-3-turbo"] +} +``` + +#### 获取工具列表 + +``` +GET /api/tools +``` + +**响应:** + +```json +{ + "code": 0, + "data": { + "tools": [ + { + "name": "get_weather", + "description": "获取指定城市的天气信息", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称" + } + }, + "required": ["city"] + } + } + ], + "total": 1 + } +} +``` --- -### 4. 错误码 +### 4. 
统计信息 + +#### 获取 Token 使用统计 + +``` +GET /api/stats/tokens?period=daily +``` + +**参数:** + +| 参数 | 类型 | 说明 | +| -------- | ------ | ------------------------------------- | +| `period` | string | 统计周期:`daily`(今日)、`weekly`(近7天)、`monthly`(近30天) | + +**响应(daily):** + +```json +{ + "code": 0, + "data": { + "period": "daily", + "date": "2026-03-24", + "prompt_tokens": 1000, + "completion_tokens": 2000, + "total_tokens": 3000, + "by_model": { + "glm-5": { + "prompt": 500, + "completion": 1000, + "total": 1500 + }, + "glm-4": { + "prompt": 500, + "completion": 1000, + "total": 1500 + } + } + } +} +``` + +**响应(weekly/monthly):** + +```json +{ + "code": 0, + "data": { + "period": "weekly", + "start_date": "2026-03-18", + "end_date": "2026-03-24", + "prompt_tokens": 7000, + "completion_tokens": 14000, + "total_tokens": 21000, + "daily": { + "2026-03-18": {"prompt": 1000, "completion": 2000, "total": 3000}, + "2026-03-19": {"prompt": 1000, "completion": 2000, "total": 3000}, + ... + } + } +} +``` + +--- + +### 5. SSE 事件说明 + +| 事件 | 说明 | +| ------------- | ---------------------------------------- | +| `thinking` | 思维链增量内容(启用时) | +| `message` | 回复内容的增量片段 | +| `tool_calls` | 工具调用信息,包含工具名称和参数 | +| `tool_result` | 工具执行结果,包含工具名称和返回内容 | +| `error` | 错误信息 | +| `done` | 回复结束,携带完整 message_id 和 token 统计 | + +**tool_calls 事件数据格式:** + +```json +{ + "calls": [ + { + "id": "call_001", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"city\": \"北京\"}" + } + } + ] +} +``` + +**tool_result 事件数据格式:** + +```json +{ + "name": "get_weather", + "content": "{\"temperature\": 25, \"humidity\": 60}" +} +``` + +--- + +### 6. 
错误码 | code | 说明 | | ----- | -------- | @@ -344,8 +534,81 @@ User 1 ── * Conversation 1 ── * Message | ------------------ | ------------- | ------------------------------- | | `id` | string (UUID) | 消息 ID | | `conversation_id` | string | 所属会话 ID | -| `role` | enum | `user` / `assistant` / `system` | +| `role` | enum | `user` / `assistant` / `system` / `tool` | | `content` | string | 消息内容 | | `token_count` | integer | token 消耗数 | | `thinking_content` | string | 思维链内容(启用时) | +| `tool_calls` | string (JSON) | 工具调用请求(assistant 消息) | +| `tool_call_id` | string | 工具调用 ID(tool 消息) | +| `name` | string | 工具名称(tool 消息) | | `created_at` | datetime | 创建时间 | + +#### 消息类型说明 + +**1. 用户消息 (role=user)** +```json +{ + "id": "msg_001", + "role": "user", + "content": "北京今天天气怎么样?", + "created_at": "2026-03-24T10:00:00Z" +} +``` + +**2. 助手消息 - 普通回复 (role=assistant)** +```json +{ + "id": "msg_002", + "role": "assistant", + "content": "北京今天天气晴朗...", + "token_count": 50, + "thinking_content": "用户想了解天气...", + "created_at": "2026-03-24T10:00:01Z" +} +``` + +**3. 助手消息 - 工具调用 (role=assistant, with tool_calls)** +```json +{ + "id": "msg_003", + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"city\": \"北京\"}" + } + } + ], + "created_at": "2026-03-24T10:00:01Z" +} +``` + +**4. 工具返回消息 (role=tool)** +```json +{ + "id": "msg_004", + "role": "tool", + "content": "{\"temperature\": 25, \"humidity\": 60, \"description\": \"晴天\"}", + "tool_call_id": "call_abc123", + "name": "get_weather", + "created_at": "2026-03-24T10:00:02Z" +} +``` + +#### 工具调用流程示例 + +``` +用户: "北京今天天气怎么样?" + ↓ +[msg_001] role=user, content="北京今天天气怎么样?" + ↓ +[msg_002] role=assistant, tool_calls=[{get_weather, args:{"city":"北京"}}] + ↓ +[msg_003] role=tool, name=get_weather, content="{...weather data...}" + ↓ +[msg_004] role=assistant, content="北京今天天气晴朗,温度25°C..." 
+``` diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 6607e53..6486b95 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -18,12 +18,15 @@ :streaming="streaming" :streaming-content="streamContent" :streaming-thinking="streamThinking" + :streaming-tool-calls="streamToolCalls" :has-more-messages="hasMoreMessages" :loading-more="loadingMessages" + :tools-enabled="toolsEnabled" @send-message="sendMessage" @delete-message="deleteMessage" @toggle-settings="showSettings = true" @load-more-messages="loadMoreMessages" + @toggle-tools="updateToolsEnabled" /> conversations.value.find(c => c.id === currentConvId.value) || null @@ -122,9 +127,10 @@ async function loadMessages(reset = true) { try { const res = await messageApi.list(currentConvId.value, reset ? null : nextMsgCursor.value) if (reset) { - messages.value = res.data.items + // Filter out tool messages (they're merged into assistant messages) + messages.value = res.data.items.filter(m => m.role !== 'tool') } else { - messages.value = [...res.data.items, ...messages.value] + messages.value = [...res.data.items.filter(m => m.role !== 'tool'), ...messages.value] } nextMsgCursor.value = res.data.next_cursor hasMoreMessages.value = res.data.has_more @@ -158,15 +164,34 @@ async function sendMessage(content) { streaming.value = true streamContent.value = '' streamThinking.value = '' + streamToolCalls.value = [] await messageApi.send(currentConvId.value, content, { stream: true, + toolsEnabled: toolsEnabled.value, onThinking(text) { streamThinking.value += text }, onMessage(text) { streamContent.value += text }, + onToolCalls(calls) { + console.log('🔧 Tool calls received:', calls) + streamToolCalls.value = calls + }, + onToolResult(result) { + console.log('✅ Tool result received:', result) + // 更新工具调用结果 + const call = streamToolCalls.value.find(c => c.function?.name === result.name) + if (call) { + call.result = result.content + } else { + // 如果找不到,添加到第一个调用(兜底处理) + if (streamToolCalls.value.length > 
0) { + streamToolCalls.value[0].result = result.content + } + } + }, async onDone(data) { streaming.value = false // Replace temp message and add assistant message from server @@ -178,6 +203,7 @@ async function sendMessage(content) { content: streamContent.value, token_count: data.token_count, thinking_content: streamThinking.value || null, + tool_calls: streamToolCalls.value.length > 0 ? streamToolCalls.value : null, created_at: new Date().toISOString(), }) streamContent.value = '' @@ -251,6 +277,12 @@ async function saveSettings(data) { } } +// -- Update tools enabled -- +function updateToolsEnabled(val) { + toolsEnabled.value = val + localStorage.setItem('tools_enabled', String(val)) +} + // -- Init -- onMounted(() => { loadConversations() diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js index cee4a4a..9a6da0d 100644 --- a/frontend/src/api/index.js +++ b/frontend/src/api/index.js @@ -64,11 +64,11 @@ export const messageApi = { return request(`/conversations/${convId}/messages?${params}`) }, - send(convId, content, { stream = true, onThinking, onMessage, onDone, onError } = {}) { + send(convId, content, { stream = true, toolsEnabled = true, onThinking, onMessage, onToolCalls, onToolResult, onDone, onError } = {}) { if (!stream) { return request(`/conversations/${convId}/messages`, { method: 'POST', - body: { content, stream: false }, + body: { content, stream: false, tools_enabled: toolsEnabled }, }) } @@ -79,7 +79,7 @@ export const messageApi = { const res = await fetch(`${BASE}/conversations/${convId}/messages`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ content, stream: true }), + body: JSON.stringify({ content, stream: true, tools_enabled: toolsEnabled }), signal: controller.signal, }) @@ -110,6 +110,10 @@ export const messageApi = { onThinking(data.content) } else if (currentEvent === 'message' && onMessage) { onMessage(data.content) + } else if (currentEvent === 'tool_calls' && onToolCalls) 
{ + onToolCalls(data.calls) + } else if (currentEvent === 'tool_result' && onToolResult) { + onToolResult(data) } else if (currentEvent === 'done' && onDone) { onDone(data) } else if (currentEvent === 'error' && onError) { diff --git a/frontend/src/components/ChatView.vue b/frontend/src/components/ChatView.vue index e10a5e9..c00e7b7 100644 --- a/frontend/src/components/ChatView.vue +++ b/frontend/src/components/ChatView.vue @@ -37,6 +37,8 @@ :role="msg.role" :content="msg.content" :thinking-content="msg.thinking_content" + :tool-calls="msg.tool_calls" + :tool-name="msg.name" :token-count="msg.token_count" :created-at="msg.created_at" :deletable="msg.role === 'user'" @@ -46,9 +48,11 @@
claw
-          <div v-if="streamingThinking" class="streaming-thinking">
-            {{ streamingThinking }}
-          </div>
+          <ProcessBlock
+            :thinking="streamingThinking"
+            :tool-calls="streamingToolCalls"
+            :streaming="true"
+          />
@@ -58,7 +62,9 @@ @@ -68,6 +74,7 @@ import { ref, computed, watch, nextTick } from 'vue' import MessageBubble from './MessageBubble.vue' import MessageInput from './MessageInput.vue' +import ProcessBlock from './ProcessBlock.vue' import { renderMarkdown } from '../utils/markdown' const props = defineProps({ @@ -76,11 +83,13 @@ const props = defineProps({ streaming: { type: Boolean, default: false }, streamingContent: { type: String, default: '' }, streamingThinking: { type: String, default: '' }, + streamingToolCalls: { type: Array, default: () => [] }, hasMoreMessages: { type: Boolean, default: false }, loadingMore: { type: Boolean, default: false }, + toolsEnabled: { type: Boolean, default: true }, }) -defineEmits(['sendMessage', 'deleteMessage', 'toggleSettings', 'loadMoreMessages']) +defineEmits(['sendMessage', 'deleteMessage', 'toggleSettings', 'loadMoreMessages', 'toggleTools']) const scrollContainer = ref(null) const inputRef = ref(null) @@ -296,18 +305,6 @@ defineExpose({ scrollToBottom }) color: white; } -.streaming-thinking { - font-size: 13px; - color: var(--text-secondary); - line-height: 1.6; - white-space: pre-wrap; - padding: 12px; - background: var(--bg-thinking); - border-radius: 8px; - border: 1px solid var(--border-light); - margin-bottom: 8px; -} - .streaming-content { font-size: 15px; line-height: 1.7; diff --git a/frontend/src/components/MessageBubble.vue b/frontend/src/components/MessageBubble.vue index 4ebeb44..a235a79 100644 --- a/frontend/src/components/MessageBubble.vue +++ b/frontend/src/components/MessageBubble.vue @@ -3,19 +3,16 @@
user
claw
-
- -
{{ thinkingContent }}
+ +
+
工具返回结果: {{ toolName }}
+
{{ content }}
-
+