From 6aea98554fcd17ea6712f44b381a0f9d94c8f08a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 28 Mar 2026 00:42:59 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=20=E4=BF=AE=E6=94=B9chat=20?= =?UTF-8?q?=E4=B8=BB=E6=B5=81=E7=A8=8B=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 6 + backend/services/chat.py | 524 ++++++++++++++++++++------------- backend/services/llm_client.py | 24 +- docs/Design.md | 93 ++++++ 4 files changed, 432 insertions(+), 215 deletions(-) diff --git a/backend/config.py b/backend/config.py index 937071c..232c1ee 100644 --- a/backend/config.py +++ b/backend/config.py @@ -33,3 +33,9 @@ if MODELS and not DEFAULT_MODEL: # Max agentic loop iterations (tool call rounds) MAX_ITERATIONS = _cfg.get("max_iterations", 5) + +# Max parallel workers for tool execution (ThreadPoolExecutor) +TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4) + +# Max character length for a single tool result content (truncated if exceeded) +TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096) diff --git a/backend/services/chat.py b/backend/services/chat.py index c3cd311..3d85987 100644 --- a/backend/services/chat.py +++ b/backend/services/chat.py @@ -1,8 +1,11 @@ """Chat completion service""" import json +import logging import uuid -from flask import current_app, g, Response, request as flask_request +from typing import Optional, Union +from flask import current_app, Response, request as flask_request from werkzeug.exceptions import ClientDisconnected +import requests from backend import db from backend.models import Conversation, Message from backend.tools import registry, ToolExecutor @@ -11,14 +14,15 @@ from backend.utils.helpers import ( build_messages, ) from backend.services.llm_client import LLMClient -from backend.config import MAX_ITERATIONS +from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH + +logger = logging.getLogger(__name__) def _client_disconnected(): """Check if the client has disconnected.""" try: stream = flask_request.input_stream - # If input_stream is unavailable, assume still connected if stream is None: return False return stream.closed @@ -26,151 +30,111 @@ def _client_disconnected(): return False +def _sse_event(event: str, data: dict) -> str: + """Format a Server-Sent Event string.""" + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + class ChatService: """Chat completion service with tool support""" def __init__(self, llm: LLMClient): self.llm = llm + def stream_response( + self, + conv: Conversation, + tools_enabled: bool = True, + project_id: str = None, + tool_choice: Optional[Union[str, dict]] = None, + ): + """Stream response with tool call support. - def stream_response(self, conv: Conversation, tools_enabled: bool = True, project_id: str = None): - """Stream response with tool call support - Uses 'process_step' events to send thinking and tool calls in order, allowing them to be interleaved properly in the frontend. - + Args: conv: Conversation object tools_enabled: Whether to enable tools project_id: Project ID for workspace isolation + tool_choice: Optional tool_choice override (e.g. "auto", "required", or dict) """ conv_id = conv.id conv_model = conv.model app = current_app._get_current_object() tools = registry.list_all() if tools_enabled else None initial_messages = build_messages(conv, project_id) - - # Create per-request executor for thread-safe isolation. - # Each request gets its own _call_history and _cache, eliminating - # race conditions when multiple conversations stream concurrently. + executor = ToolExecutor(registry=registry) - - # Build context for tool execution + context = {"model": conv_model} if project_id: context["project_id"] = project_id elif conv.project_id: context["project_id"] = conv.project_id - + def generate(): messages = list(initial_messages) all_tool_calls = [] all_tool_results = [] - all_steps = [] # Collect all ordered steps for DB storage (thinking/text/tool_call/tool_result) - step_index = 0 # Track global step index for ordering - total_completion_tokens = 0 # Accumulated across all iterations - prompt_tokens = 0 # Not accumulated — last iteration's value is sufficient - # (each iteration re-sends the full context, so earlier - # prompts are strict subsets of the final one) + all_steps = [] + step_index = 0 + total_completion_tokens = 0 + total_prompt_tokens = 0 for iteration in range(MAX_ITERATIONS): - full_content = "" - full_thinking = "" - token_count = 0 - msg_id = str(uuid.uuid4()) - tool_calls_list = [] - - # Streaming step tracking — step ID is assigned on first chunk arrival. - # thinking always precedes text in GLM's streaming order, so text gets step_index+1. - thinking_step_id = None - thinking_step_idx = None - text_step_id = None - text_step_idx = None - try: - with app.app_context(): - active_conv = db.session.get(Conversation, conv_id) - resp = self.llm.call( - model=active_conv.model, - messages=messages, - max_tokens=active_conv.max_tokens, - temperature=active_conv.temperature, - thinking_enabled=active_conv.thinking_enabled, - tools=tools, - stream=True, - ) - resp.raise_for_status() - - # Stream LLM response chunk by chunk - for line in resp.iter_lines(): - # Early exit if client has disconnected - if _client_disconnected(): - resp.close() - return - - if not line: - continue - line = line.decode("utf-8") - if not line.startswith("data: "): - continue - data_str = line[6:] - if data_str == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - - # Extract usage first (present in last chunk when stream_options is set) - usage = chunk.get("usage", {}) - if usage: - token_count = usage.get("completion_tokens", 0) - prompt_tokens = usage.get("prompt_tokens", 0) - - choices = chunk.get("choices", []) - if not choices: - continue - - delta = choices[0].get("delta", {}) - - # Accumulate thinking content for this iteration - reasoning = delta.get("reasoning_content", "") - if reasoning: - full_thinking += reasoning - if thinking_step_id is None: - thinking_step_id = f'step-{step_index}' - thinking_step_idx = step_index - yield f"event: process_step\ndata: {json.dumps({'id': thinking_step_id, 'index': thinking_step_idx, 'type': 'thinking', 'content': full_thinking}, ensure_ascii=False)}\n\n" - - # Accumulate text content for this iteration - text = delta.get("content", "") - if text: - full_content += text - if text_step_id is None: - text_step_idx = step_index + (1 if thinking_step_id is not None else 0) - text_step_id = f'step-{text_step_idx}' - yield f"event: process_step\ndata: {json.dumps({'id': text_step_id, 'index': text_step_idx, 'type': 'text', 'content': full_content}, ensure_ascii=False)}\n\n" - - # Accumulate tool calls from streaming deltas - tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list) - + stream_result = self._stream_llm_response( + app, conv_id, messages, tools, tool_choice, step_index + ) + except requests.exceptions.HTTPError as e: + resp = e.response + if resp is not None and resp.status_code >= 500: + yield _sse_event("error", {"content": f"LLM service unavailable ({resp.status_code})"}) + elif resp is not None and resp.status_code == 429: + yield _sse_event("error", {"content": "Rate limit exceeded, please try again later"}) + else: + yield _sse_event("error", {"content": f"LLM request failed: {e}"}) + return + except requests.exceptions.ConnectionError: + yield _sse_event("error", {"content": "Unable to connect to LLM service"}) + return + except requests.exceptions.Timeout: + yield _sse_event("error", {"content": "LLM request timed out"}) + return except Exception as e: - yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n" + logger.exception("Unexpected error during LLM streaming") + yield _sse_event("error", {"content": f"Internal error: {e}"}) return - # --- Finalize: save thinking/text steps to all_steps for DB storage --- - # No need to yield to frontend — incremental process_step events already sent. + if stream_result is None: + return # Client disconnected + + full_content, full_thinking, tool_calls_list, \ + thinking_step_id, thinking_step_idx, \ + text_step_id, text_step_idx, \ + completion_tokens, prompt_tokens, \ + sse_chunks = stream_result + + total_prompt_tokens += prompt_tokens + total_completion_tokens += completion_tokens + + # Yield accumulated SSE chunks to frontend + for chunk in sse_chunks: + yield chunk + + # Save thinking/text steps to all_steps for DB storage if thinking_step_id is not None: all_steps.append({ - 'id': thinking_step_id, 'index': thinking_step_idx, - 'type': 'thinking', 'content': full_thinking, + "id": thinking_step_id, "index": thinking_step_idx, + "type": "thinking", "content": full_thinking, }) step_index += 1 if text_step_id is not None: all_steps.append({ - 'id': text_step_id, 'index': text_step_idx, - 'type': 'text', 'content': full_content, + "id": text_step_id, "index": text_step_idx, + "type": "text", "content": full_content, }) step_index += 1 @@ -178,127 +142,79 @@ class ChatService: if tool_calls_list: all_tool_calls.extend(tool_calls_list) - # Phase 1: emit all tool_call steps (before execution) + # Emit tool_call steps (before execution) for tc in tool_calls_list: call_step = { - 'id': f'step-{step_index}', - 'index': step_index, - 'type': 'tool_call', - 'id_ref': tc['id'], - 'name': tc['function']['name'], - 'arguments': tc['function']['arguments'], + "id": f"step-{step_index}", + "index": step_index, + "type": "tool_call", + "id_ref": tc["id"], + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], } all_steps.append(call_step) - yield f"event: process_step\ndata: {json.dumps(call_step, ensure_ascii=False)}\n\n" + yield _sse_event("process_step", call_step) step_index += 1 - # Phase 2: execute tools — parallel when multiple, sequential when single - if len(tool_calls_list) > 1: - with app.app_context(): - tool_results = executor.process_tool_calls_parallel( - tool_calls_list, context, max_workers=4 - ) - else: - with app.app_context(): - tool_results = executor.process_tool_calls( - tool_calls_list, context - ) + # Execute tools with error wrapping + tool_results = self._execute_tools_safe( + app, executor, tool_calls_list, context + ) - # Phase 3: emit all tool_result steps (after execution, same order) + # Emit tool_result steps for tr in tool_results: + skipped = False try: result_content = json.loads(tr["content"]) skipped = result_content.get("skipped", False) except Exception: skipped = False result_step = { - 'id': f'step-{step_index}', - 'index': step_index, - 'type': 'tool_result', - 'id_ref': tr['tool_call_id'], - 'name': tr['name'], - 'content': tr['content'], - 'skipped': skipped, + "id": f"step-{step_index}", + "index": step_index, + "type": "tool_result", + "id_ref": tr["tool_call_id"], + "name": tr["name"], + "content": tr["content"], + "skipped": skipped, } all_steps.append(result_step) - yield f"event: process_step\ndata: {json.dumps(result_step, ensure_ascii=False)}\n\n" + yield _sse_event("process_step", result_step) step_index += 1 # Append assistant message + tool results for the next iteration messages.append({ "role": "assistant", "content": full_content or None, - "tool_calls": tool_calls_list + "tool_calls": tool_calls_list, }) messages.extend(tool_results) all_tool_results.extend(tool_results) - total_completion_tokens += token_count continue # --- No tool calls: final iteration — save message to DB --- - suggested_title = None - # prompt_tokens already holds the last iteration's value (set during streaming) - total_completion_tokens += token_count - with app.app_context(): - # Build content JSON with ordered steps array for DB storage. - # 'steps' is the single source of truth for rendering order. - content_json = { - "text": full_content, - } - if all_tool_calls: - content_json["tool_calls"] = self._build_tool_calls_json(all_tool_calls, all_tool_results) - # Store ordered steps — the single source of truth for rendering order - content_json["steps"] = all_steps + msg_id = str(uuid.uuid4()) + suggested_title = self._save_message( + app, conv_id, conv_model, msg_id, + full_content, all_tool_calls, all_tool_results, + all_steps, total_prompt_tokens, total_completion_tokens, + ) - msg = Message( - id=msg_id, - conversation_id=conv_id, - role="assistant", - content=json.dumps(content_json, ensure_ascii=False), - token_count=total_completion_tokens, - ) - db.session.add(msg) - db.session.commit() - - # Auto-generate title from first user message if needed - conv = db.session.get(Conversation, conv_id) - - # Record token usage (get user_id from conv, not g — - # app.app_context() creates a new context where g.current_user is lost) - if conv: - record_token_usage(conv.user_id, conv_model, prompt_tokens, total_completion_tokens) - - if conv and (not conv.title or conv.title == "新对话"): - user_msg = Message.query.filter_by( - conversation_id=conv_id, role="user" - ).order_by(Message.created_at.asc()).first() - if user_msg and user_msg.content: - try: - content_data = json.loads(user_msg.content) - title_text = content_data.get("text", "")[:30] - except (json.JSONDecodeError, TypeError): - title_text = user_msg.content.strip()[:30] - if title_text: - suggested_title = title_text - else: - suggested_title = "新对话" - db.session.refresh(conv) - conv.title = suggested_title - db.session.commit() - else: - suggested_title = None - - yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_completion_tokens, 'suggested_title': suggested_title}, ensure_ascii=False)}\n\n" + yield _sse_event("done", { + "message_id": msg_id, + "token_count": total_completion_tokens, + "suggested_title": suggested_title, + }) return - - yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n" - + + yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) + def safe_generate(): """Wrapper that catches client disconnection during yield.""" try: yield from generate() except (ClientDisconnected, BrokenPipeError, ConnectionResetError): - pass # Client aborted, silently stop + pass return Response( safe_generate(), @@ -308,16 +224,224 @@ class ChatService: "X-Accel-Buffering": "no", "Connection": "keep-alive", "Transfer-Encoding": "chunked", - } + }, ) - + + # ------------------------------------------------------------------ + # Private helpers — extracted for testability and readability + # ------------------------------------------------------------------ + + def _stream_llm_response( + self, app, conv_id, messages, tools, tool_choice, step_index + ): + """Call LLM streaming API and parse the response. + + Returns a tuple of parsed results, or None if the client disconnected. + Raises HTTPError / ConnectionError / Timeout for the caller to handle. + """ + full_content = "" + full_thinking = "" + token_count = 0 + prompt_tokens = 0 + tool_calls_list = [] + + thinking_step_id = None + thinking_step_idx = None + text_step_id = None + text_step_idx = None + + sse_chunks = [] # Collect SSE events to yield later + + with app.app_context(): + active_conv = db.session.get(Conversation, conv_id) + resp = self.llm.call( + model=active_conv.model, + messages=messages, + max_tokens=active_conv.max_tokens, + temperature=active_conv.temperature, + thinking_enabled=active_conv.thinking_enabled, + tools=tools, + tool_choice=tool_choice, + stream=True, + ) + resp.raise_for_status() + + for line in resp.iter_lines(): + if _client_disconnected(): + resp.close() + return None + + if not line: + continue + line = line.decode("utf-8") + if not line.startswith("data: "): + continue + data_str = line[6:] + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + usage = chunk.get("usage", {}) + if usage: + token_count = usage.get("completion_tokens", 0) + prompt_tokens = usage.get("prompt_tokens", 0) + + choices = chunk.get("choices", []) + if not choices: + continue + + delta = choices[0].get("delta", {}) + + reasoning = delta.get("reasoning_content", "") + if reasoning: + full_thinking += reasoning + if thinking_step_id is None: + thinking_step_id = f"step-{step_index}" + thinking_step_idx = step_index + sse_chunks.append(_sse_event("process_step", { + "id": thinking_step_id, "index": thinking_step_idx, + "type": "thinking", "content": full_thinking, + })) + + text = delta.get("content", "") + if text: + full_content += text + if text_step_id is None: + text_step_idx = step_index + (1 if thinking_step_id is not None else 0) + text_step_id = f"step-{text_step_idx}" + sse_chunks.append(_sse_event("process_step", { + "id": text_step_id, "index": text_step_idx, + "type": "text", "content": full_content, + })) + + tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list) + + return ( + full_content, full_thinking, tool_calls_list, + thinking_step_id, thinking_step_idx, + text_step_id, text_step_idx, + token_count, prompt_tokens, + sse_chunks, + ) + + def _execute_tools_safe(self, app, executor, tool_calls_list, context): + """Execute tool calls with top-level error wrapping. + + If an unexpected exception occurs during tool execution, it is + converted into error tool results instead of crashing the stream. + """ + try: + if len(tool_calls_list) > 1: + with app.app_context(): + tool_results = executor.process_tool_calls_parallel( + tool_calls_list, context, max_workers=TOOL_MAX_WORKERS + ) + else: + with app.app_context(): + tool_results = executor.process_tool_calls( + tool_calls_list, context + ) + except Exception as e: + logger.exception("Error during tool execution") + tool_results = [ + { + "role": "tool", + "tool_call_id": tc["id"], + "name": tc["function"]["name"], + "content": json.dumps({ + "success": False, + "error": f"Tool execution failed: {e}", + }, ensure_ascii=False), + } + for tc in tool_calls_list + ] + + # Truncate oversized tool result content + for tr in tool_results: + if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH: + try: + result_data = json.loads(tr["content"]) + original = result_data + except (json.JSONDecodeError, TypeError): + original = None + + tr["content"] = json.dumps( + {"success": False, "error": "Tool result too large, truncated"}, + ensure_ascii=False, + ) if not original else json.dumps( + { + **original, + "truncated": True, + "_note": f"Content truncated, original length {len(tr['content'])} chars", + }, + ensure_ascii=False, + default=str, + )[:TOOL_RESULT_MAX_LENGTH] + + return tool_results + + def _save_message( + self, app, conv_id, conv_model, msg_id, + full_content, all_tool_calls, all_tool_results, + all_steps, total_prompt_tokens, total_completion_tokens, + ): + """Save the final assistant message and auto-generate title if needed. + + Returns the suggested_title or None. + """ + suggested_title = None + with app.app_context(): + content_json = {"text": full_content} + if all_tool_calls: + content_json["tool_calls"] = self._build_tool_calls_json( + all_tool_calls, all_tool_results + ) + content_json["steps"] = all_steps + + msg = Message( + id=msg_id, + conversation_id=conv_id, + role="assistant", + content=json.dumps(content_json, ensure_ascii=False), + token_count=total_completion_tokens, + ) + db.session.add(msg) + db.session.commit() + + conv = db.session.get(Conversation, conv_id) + + if conv: + record_token_usage( + conv.user_id, conv_model, + total_prompt_tokens, total_completion_tokens, + ) + + if conv and (not conv.title or conv.title == "新对话"): + user_msg = Message.query.filter_by( + conversation_id=conv_id, role="user" + ).order_by(Message.created_at.asc()).first() + if user_msg and user_msg.content: + try: + content_data = json.loads(user_msg.content) + title_text = content_data.get("text", "")[:30] + except (json.JSONDecodeError, TypeError): + title_text = user_msg.content.strip()[:30] + suggested_title = title_text or "新对话" + db.session.refresh(conv) + conv.title = suggested_title + db.session.commit() + + return suggested_title + def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list: - """Build tool calls JSON structure - matches streaming format""" + """Build tool calls JSON structure - matches streaming format.""" result = [] for i, tc in enumerate(tool_calls): result_content = tool_results[i]["content"] if i < len(tool_results) else None - # Parse result to extract success/skipped status success = True skipped = False execution_time = 0 @@ -327,10 +451,9 @@ class ChatService: success = result_data.get("success", True) skipped = result_data.get("skipped", False) execution_time = result_data.get("execution_time", 0) - except: + except (json.JSONDecodeError, TypeError): pass - # Keep same structure as streaming format result.append({ "id": tc.get("id", ""), "type": tc.get("type", "function"), @@ -345,9 +468,8 @@ class ChatService: }) return result - def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list: - """Process tool calls from streaming delta""" + """Process tool calls from streaming delta.""" tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) @@ -355,7 +477,7 @@ class ChatService: tool_calls_list.append({ "id": tc.get("id", ""), "type": tc.get("type", "function"), - "function": {"name": "", "arguments": ""} + "function": {"name": "", "arguments": ""}, }) if tc.get("id"): tool_calls_list[idx]["id"] = tc["id"] diff --git a/backend/services/llm_client.py b/backend/services/llm_client.py index fbbaa20..267a488 100644 --- a/backend/services/llm_client.py +++ b/backend/services/llm_client.py @@ -9,7 +9,7 @@ import os import re import time import requests -from typing import Optional, List +from typing import Optional, List, Union def _resolve_env_vars(value: str) -> str: @@ -59,7 +59,8 @@ class LLMClient: raise ValueError(f"Model '{model}' has no api_key configured") return api_url, api_key - def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled, tools, stream, api_url): + def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled, + tools, tool_choice, stream, api_url): """Build request body with provider-specific parameter adaptation.""" provider = _detect_provider(api_url) @@ -79,23 +80,17 @@ class LLMClient: # --- Provider-specific: thinking --- if thinking_enabled: - if provider == "glm": + if provider == "glm" or provider == "deepseek": body["thinking"] = {"type": "enabled"} - elif provider == "deepseek": - pass # deepseek-reasoner has built-in reasoning, no extra param + else: + raise NotImplementedError(f"Thinking not supported for provider '{provider}'") - # --- Provider-specific: tools --- if tools: body["tools"] = tools - body["tool_choice"] = "auto" + body["tool_choice"] = tool_choice if tool_choice is not None else "auto" - # --- Provider-specific: stream --- if stream: body["stream"] = True - if provider == "glm": - body["stream_options"] = {"include_usage": True} - elif provider == "deepseek": - pass # DeepSeek does not support stream_options return body @@ -107,15 +102,16 @@ class LLMClient: temperature: float = 1.0, thinking_enabled: bool = False, tools: Optional[List[dict]] = None, + tool_choice: Optional[Union[str, dict]] = None, stream: bool = False, - timeout: int = 120, + timeout: int = 200, max_retries: int = 3, ): """Call LLM API with retry on rate limit (429)""" api_url, api_key = self._get_credentials(model) body = self._build_body( model, messages, max_tokens, temperature, - thinking_enabled, tools, stream, api_url, + thinking_enabled, tools, tool_choice, stream, api_url, ) for attempt in range(max_retries + 1): diff --git a/docs/Design.md b/docs/Design.md index c740a16..f2e8ca8 100644 --- a/docs/Design.md +++ b/docs/Design.md @@ -638,6 +638,99 @@ buffer 拼接: "event: process_step\ndata: {\"id\":\"step-0\",...}\n\n" --- +## Token 用量计算 + +### 术语定义 + +| 术语 | 说明 | +| --- | --- | +| `prompt_tokens` | 发给模型的输入 token 数量(包括 system prompt、历史消息、工具定义、工具结果等全部上下文) | +| `completion_tokens` | 模型生成的输出 token 数量(包括 thinking 内容、正文回复、工具调用 JSON) | +| `total_tokens` | `prompt_tokens + completion_tokens` | + +### 计算流程 + +一次完整的对话可能经历多轮工具调用迭代,每轮都会向 LLM 发送请求并收到响应。Token 用量计算分为三个阶段: + +```mermaid +flowchart LR + A[LLM SSE Stream] -->|usage 字段| B["_stream_llm_response()"] + B -->|每轮累加| C["generate() 循环"] + C -->|最终值| D["_save_message()"] + D --> E["record_token_usage()"] + E --> F["TokenUsage 表"] +``` + +#### 1. 流式解析 — 从 SSE chunks 中提取 + +LLM API 在流的最后一个 chunk 中返回 `usage` 字段(需要在请求中设置 `stream_options` 才有,否则为空): + +```python +# chat.py: _stream_llm_response() +usage = chunk.get("usage", {}) +if usage: + token_count = usage.get("completion_tokens", 0) # 本轮输出 token + prompt_tokens = usage.get("prompt_tokens", 0) # 本轮输入 token +``` + +#### 2. 迭代累加 — generate() 循环 + +每轮迭代结束后,将本轮的 prompt 和 completion token 累加到总计: + +```python +# chat.py: generate() +total_prompt_tokens += prompt_tokens # 累加每轮 prompt +total_completion_tokens += completion_tokens # 累加每轮 completion +``` + +#### 3. 记录到数据库 + +最终调用 `record_token_usage()` 写入 TokenUsage 表,同时 Message 表也记录 completion token: + +```python +# chat.py: _save_message() +msg = Message(token_count=total_completion_tokens) # Message 表仅记录 completion +record_token_usage(user_id, model, total_prompt_tokens, total_completion_tokens) +``` + +### 多轮迭代示例 + +一次涉及工具调用的对话(如:用户提问 → LLM 调用搜索 → LLM 生成回复): + +``` +迭代 1: prompt=800, completion=150 (LLM 决定调用 web_search) +迭代 2: prompt=1500, completion=300 (LLM 根据搜索结果生成最终回复) + +───────────────────────────────────────── +累加结果: + total_prompt_tokens = 800 + 1500 = 2300 + total_completion_tokens = 150 + 300 = 450 +───────────────────────────────────────── +``` + +> **注意**:`prompt_tokens` 的累加意味着存在重复计算 — 第 2 轮的 prompt 包含了第 1 轮的上下文,累加后 `total_prompt_tokens` 大于本次对话的真实输入 token 总量(历史部分被多次计算)。这是因为每轮请求是独立的 API 调用,各自计费。如果需要精确的单次对话输入 token,可以只取最后一轮的 `prompt_tokens`。 + +### 存储位置 + +| 位置 | 存什么 | 粒度 | +| --- | --- | --- | +| `Message.token_count` | `total_completion_tokens`(仅输出) | 单条消息 | +| `TokenUsage` 表 | `prompt_tokens` + `completion_tokens` + `total_tokens` | 按 user + 日期 + model 聚合 | + +`TokenUsage` 按 **user_id + 日期 + model** 维度聚合,同一天同一模型的多次对话会累加到同一条记录: + +```python +# helpers.py: record_token_usage() +if existing: + existing.prompt_tokens += prompt_tokens + existing.completion_tokens += completion_tokens + existing.total_tokens += prompt_tokens + completion_tokens +else: + create new TokenUsage record +``` + +--- + ## 分页机制 所有列表接口使用**游标分页**: