refactor: 修改chat 主流程逻辑

This commit is contained in:
ViperEkura 2026-03-28 00:42:59 +08:00
parent 7bd19a7529
commit 6aea98554f
4 changed files with 432 additions and 215 deletions

View File

@ -33,3 +33,9 @@ if MODELS and not DEFAULT_MODEL:
# Max agentic loop iterations (tool call rounds) # Max agentic loop iterations (tool call rounds)
MAX_ITERATIONS = _cfg.get("max_iterations", 5) MAX_ITERATIONS = _cfg.get("max_iterations", 5)
# Max parallel workers for tool execution (ThreadPoolExecutor)
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
# Max character length for a single tool result content (truncated if exceeded)
TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096)

View File

@ -1,8 +1,11 @@
"""Chat completion service""" """Chat completion service"""
import json import json
import logging
import uuid import uuid
from flask import current_app, g, Response, request as flask_request from typing import Optional, Union
from flask import current_app, Response, request as flask_request
from werkzeug.exceptions import ClientDisconnected from werkzeug.exceptions import ClientDisconnected
import requests
from backend import db from backend import db
from backend.models import Conversation, Message from backend.models import Conversation, Message
from backend.tools import registry, ToolExecutor from backend.tools import registry, ToolExecutor
@ -11,14 +14,15 @@ from backend.utils.helpers import (
build_messages, build_messages,
) )
from backend.services.llm_client import LLMClient from backend.services.llm_client import LLMClient
from backend.config import MAX_ITERATIONS from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH
logger = logging.getLogger(__name__)
def _client_disconnected(): def _client_disconnected():
"""Check if the client has disconnected.""" """Check if the client has disconnected."""
try: try:
stream = flask_request.input_stream stream = flask_request.input_stream
# If input_stream is unavailable, assume still connected
if stream is None: if stream is None:
return False return False
return stream.closed return stream.closed
@ -26,15 +30,25 @@ def _client_disconnected():
return False return False
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class ChatService: class ChatService:
"""Chat completion service with tool support""" """Chat completion service with tool support"""
def __init__(self, llm: LLMClient): def __init__(self, llm: LLMClient):
self.llm = llm self.llm = llm
def stream_response(
def stream_response(self, conv: Conversation, tools_enabled: bool = True, project_id: str = None): self,
"""Stream response with tool call support conv: Conversation,
tools_enabled: bool = True,
project_id: str = None,
tool_choice: Optional[Union[str, dict]] = None,
):
"""Stream response with tool call support.
Uses 'process_step' events to send thinking and tool calls in order, Uses 'process_step' events to send thinking and tool calls in order,
allowing them to be interleaved properly in the frontend. allowing them to be interleaved properly in the frontend.
@ -43,6 +57,7 @@ class ChatService:
conv: Conversation object conv: Conversation object
tools_enabled: Whether to enable tools tools_enabled: Whether to enable tools
project_id: Project ID for workspace isolation project_id: Project ID for workspace isolation
tool_choice: Optional tool_choice override (e.g. "auto", "required", or dict)
""" """
conv_id = conv.id conv_id = conv.id
conv_model = conv.model conv_model = conv.model
@ -50,12 +65,8 @@ class ChatService:
tools = registry.list_all() if tools_enabled else None tools = registry.list_all() if tools_enabled else None
initial_messages = build_messages(conv, project_id) initial_messages = build_messages(conv, project_id)
# Create per-request executor for thread-safe isolation.
# Each request gets its own _call_history and _cache, eliminating
# race conditions when multiple conversations stream concurrently.
executor = ToolExecutor(registry=registry) executor = ToolExecutor(registry=registry)
# Build context for tool execution
context = {"model": conv_model} context = {"model": conv_model}
if project_id: if project_id:
context["project_id"] = project_id context["project_id"] = project_id
@ -66,28 +77,181 @@ class ChatService:
messages = list(initial_messages) messages = list(initial_messages)
all_tool_calls = [] all_tool_calls = []
all_tool_results = [] all_tool_results = []
all_steps = [] # Collect all ordered steps for DB storage (thinking/text/tool_call/tool_result) all_steps = []
step_index = 0 # Track global step index for ordering step_index = 0
total_completion_tokens = 0 # Accumulated across all iterations total_completion_tokens = 0
prompt_tokens = 0 # Not accumulated — last iteration's value is sufficient total_prompt_tokens = 0
# (each iteration re-sends the full context, so earlier
# prompts are strict subsets of the final one)
for iteration in range(MAX_ITERATIONS): for iteration in range(MAX_ITERATIONS):
try:
stream_result = self._stream_llm_response(
app, conv_id, messages, tools, tool_choice, step_index
)
except requests.exceptions.HTTPError as e:
resp = e.response
if resp is not None and resp.status_code >= 500:
yield _sse_event("error", {"content": f"LLM service unavailable ({resp.status_code})"})
elif resp is not None and resp.status_code == 429:
yield _sse_event("error", {"content": "Rate limit exceeded, please try again later"})
else:
yield _sse_event("error", {"content": f"LLM request failed: {e}"})
return
except requests.exceptions.ConnectionError:
yield _sse_event("error", {"content": "Unable to connect to LLM service"})
return
except requests.exceptions.Timeout:
yield _sse_event("error", {"content": "LLM request timed out"})
return
except Exception as e:
logger.exception("Unexpected error during LLM streaming")
yield _sse_event("error", {"content": f"Internal error: {e}"})
return
if stream_result is None:
return # Client disconnected
full_content, full_thinking, tool_calls_list, \
thinking_step_id, thinking_step_idx, \
text_step_id, text_step_idx, \
completion_tokens, prompt_tokens, \
sse_chunks = stream_result
total_prompt_tokens += prompt_tokens
total_completion_tokens += completion_tokens
# Yield accumulated SSE chunks to frontend
for chunk in sse_chunks:
yield chunk
# Save thinking/text steps to all_steps for DB storage
if thinking_step_id is not None:
all_steps.append({
"id": thinking_step_id, "index": thinking_step_idx,
"type": "thinking", "content": full_thinking,
})
step_index += 1
if text_step_id is not None:
all_steps.append({
"id": text_step_id, "index": text_step_idx,
"type": "text", "content": full_content,
})
step_index += 1
# --- Branch: tool calls vs final ---
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
# Emit tool_call steps (before execution)
for tc in tool_calls_list:
call_step = {
"id": f"step-{step_index}",
"index": step_index,
"type": "tool_call",
"id_ref": tc["id"],
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
all_steps.append(call_step)
yield _sse_event("process_step", call_step)
step_index += 1
# Execute tools with error wrapping
tool_results = self._execute_tools_safe(
app, executor, tool_calls_list, context
)
# Emit tool_result steps
for tr in tool_results:
skipped = False
try:
result_content = json.loads(tr["content"])
skipped = result_content.get("skipped", False)
except Exception:
skipped = False
result_step = {
"id": f"step-{step_index}",
"index": step_index,
"type": "tool_result",
"id_ref": tr["tool_call_id"],
"name": tr["name"],
"content": tr["content"],
"skipped": skipped,
}
all_steps.append(result_step)
yield _sse_event("process_step", result_step)
step_index += 1
# Append assistant message + tool results for the next iteration
messages.append({
"role": "assistant",
"content": full_content or None,
"tool_calls": tool_calls_list,
})
messages.extend(tool_results)
all_tool_results.extend(tool_results)
continue
# --- No tool calls: final iteration — save message to DB ---
msg_id = str(uuid.uuid4())
suggested_title = self._save_message(
app, conv_id, conv_model, msg_id,
full_content, all_tool_calls, all_tool_results,
all_steps, total_prompt_tokens, total_completion_tokens,
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": total_completion_tokens,
"suggested_title": suggested_title,
})
return
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
def safe_generate():
"""Wrapper that catches client disconnection during yield."""
try:
yield from generate()
except (ClientDisconnected, BrokenPipeError, ConnectionResetError):
pass
return Response(
safe_generate(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache, no-store, must-revalidate",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
)
# ------------------------------------------------------------------
# Private helpers — extracted for testability and readability
# ------------------------------------------------------------------
def _stream_llm_response(
self, app, conv_id, messages, tools, tool_choice, step_index
):
"""Call LLM streaming API and parse the response.
Returns a tuple of parsed results, or None if the client disconnected.
Raises HTTPError / ConnectionError / Timeout for the caller to handle.
"""
full_content = "" full_content = ""
full_thinking = "" full_thinking = ""
token_count = 0 token_count = 0
msg_id = str(uuid.uuid4()) prompt_tokens = 0
tool_calls_list = [] tool_calls_list = []
# Streaming step tracking — step ID is assigned on first chunk arrival.
# thinking always precedes text in GLM's streaming order, so text gets step_index+1.
thinking_step_id = None thinking_step_id = None
thinking_step_idx = None thinking_step_idx = None
text_step_id = None text_step_id = None
text_step_idx = None text_step_idx = None
try: sse_chunks = [] # Collect SSE events to yield later
with app.app_context(): with app.app_context():
active_conv = db.session.get(Conversation, conv_id) active_conv = db.session.get(Conversation, conv_id)
resp = self.llm.call( resp = self.llm.call(
@ -97,16 +261,15 @@ class ChatService:
temperature=active_conv.temperature, temperature=active_conv.temperature,
thinking_enabled=active_conv.thinking_enabled, thinking_enabled=active_conv.thinking_enabled,
tools=tools, tools=tools,
tool_choice=tool_choice,
stream=True, stream=True,
) )
resp.raise_for_status() resp.raise_for_status()
# Stream LLM response chunk by chunk
for line in resp.iter_lines(): for line in resp.iter_lines():
# Early exit if client has disconnected
if _client_disconnected(): if _client_disconnected():
resp.close() resp.close()
return return None
if not line: if not line:
continue continue
@ -121,7 +284,6 @@ class ChatService:
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
# Extract usage first (present in last chunk when stream_options is set)
usage = chunk.get("usage", {}) usage = chunk.get("usage", {})
if usage: if usage:
token_count = usage.get("completion_tokens", 0) token_count = usage.get("completion_tokens", 0)
@ -133,121 +295,110 @@ class ChatService:
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
# Accumulate thinking content for this iteration
reasoning = delta.get("reasoning_content", "") reasoning = delta.get("reasoning_content", "")
if reasoning: if reasoning:
full_thinking += reasoning full_thinking += reasoning
if thinking_step_id is None: if thinking_step_id is None:
thinking_step_id = f'step-{step_index}' thinking_step_id = f"step-{step_index}"
thinking_step_idx = step_index thinking_step_idx = step_index
yield f"event: process_step\ndata: {json.dumps({'id': thinking_step_id, 'index': thinking_step_idx, 'type': 'thinking', 'content': full_thinking}, ensure_ascii=False)}\n\n" sse_chunks.append(_sse_event("process_step", {
"id": thinking_step_id, "index": thinking_step_idx,
"type": "thinking", "content": full_thinking,
}))
# Accumulate text content for this iteration
text = delta.get("content", "") text = delta.get("content", "")
if text: if text:
full_content += text full_content += text
if text_step_id is None: if text_step_id is None:
text_step_idx = step_index + (1 if thinking_step_id is not None else 0) text_step_idx = step_index + (1 if thinking_step_id is not None else 0)
text_step_id = f'step-{text_step_idx}' text_step_id = f"step-{text_step_idx}"
yield f"event: process_step\ndata: {json.dumps({'id': text_step_id, 'index': text_step_idx, 'type': 'text', 'content': full_content}, ensure_ascii=False)}\n\n" sse_chunks.append(_sse_event("process_step", {
"id": text_step_id, "index": text_step_idx,
"type": "text", "content": full_content,
}))
# Accumulate tool calls from streaming deltas
tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list) tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)
except Exception as e: return (
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n" full_content, full_thinking, tool_calls_list,
return thinking_step_id, thinking_step_idx,
text_step_id, text_step_idx,
token_count, prompt_tokens,
sse_chunks,
)
# --- Finalize: save thinking/text steps to all_steps for DB storage --- def _execute_tools_safe(self, app, executor, tool_calls_list, context):
# No need to yield to frontend — incremental process_step events already sent. """Execute tool calls with top-level error wrapping.
if thinking_step_id is not None:
all_steps.append({
'id': thinking_step_id, 'index': thinking_step_idx,
'type': 'thinking', 'content': full_thinking,
})
step_index += 1
if text_step_id is not None: If an unexpected exception occurs during tool execution, it is
all_steps.append({ converted into error tool results instead of crashing the stream.
'id': text_step_id, 'index': text_step_idx, """
'type': 'text', 'content': full_content, try:
})
step_index += 1
# --- Branch: tool calls vs final ---
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
# Phase 1: emit all tool_call steps (before execution)
for tc in tool_calls_list:
call_step = {
'id': f'step-{step_index}',
'index': step_index,
'type': 'tool_call',
'id_ref': tc['id'],
'name': tc['function']['name'],
'arguments': tc['function']['arguments'],
}
all_steps.append(call_step)
yield f"event: process_step\ndata: {json.dumps(call_step, ensure_ascii=False)}\n\n"
step_index += 1
# Phase 2: execute tools — parallel when multiple, sequential when single
if len(tool_calls_list) > 1: if len(tool_calls_list) > 1:
with app.app_context(): with app.app_context():
tool_results = executor.process_tool_calls_parallel( tool_results = executor.process_tool_calls_parallel(
tool_calls_list, context, max_workers=4 tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
) )
else: else:
with app.app_context(): with app.app_context():
tool_results = executor.process_tool_calls( tool_results = executor.process_tool_calls(
tool_calls_list, context tool_calls_list, context
) )
except Exception as e:
logger.exception("Error during tool execution")
tool_results = [
{
"role": "tool",
"tool_call_id": tc["id"],
"name": tc["function"]["name"],
"content": json.dumps({
"success": False,
"error": f"Tool execution failed: {e}",
}, ensure_ascii=False),
}
for tc in tool_calls_list
]
# Phase 3: emit all tool_result steps (after execution, same order) # Truncate oversized tool result content
for tr in tool_results: for tr in tool_results:
if len(tr["content"]) > TOOL_RESULT_MAX_LENGTH:
try: try:
result_content = json.loads(tr["content"]) result_data = json.loads(tr["content"])
skipped = result_content.get("skipped", False) original = result_data
except Exception: except (json.JSONDecodeError, TypeError):
skipped = False original = None
result_step = {
'id': f'step-{step_index}',
'index': step_index,
'type': 'tool_result',
'id_ref': tr['tool_call_id'],
'name': tr['name'],
'content': tr['content'],
'skipped': skipped,
}
all_steps.append(result_step)
yield f"event: process_step\ndata: {json.dumps(result_step, ensure_ascii=False)}\n\n"
step_index += 1
# Append assistant message + tool results for the next iteration tr["content"] = json.dumps(
messages.append({ {"success": False, "error": "Tool result too large, truncated"},
"role": "assistant", ensure_ascii=False,
"content": full_content or None, ) if not original else json.dumps(
"tool_calls": tool_calls_list {
}) **original,
messages.extend(tool_results) "truncated": True,
all_tool_results.extend(tool_results) "_note": f"Content truncated, original length {len(tr['content'])} chars",
total_completion_tokens += token_count },
continue ensure_ascii=False,
default=str,
)[:TOOL_RESULT_MAX_LENGTH]
# --- No tool calls: final iteration — save message to DB --- return tool_results
def _save_message(
self, app, conv_id, conv_model, msg_id,
full_content, all_tool_calls, all_tool_results,
all_steps, total_prompt_tokens, total_completion_tokens,
):
"""Save the final assistant message and auto-generate title if needed.
Returns the suggested_title or None.
"""
suggested_title = None suggested_title = None
# prompt_tokens already holds the last iteration's value (set during streaming)
total_completion_tokens += token_count
with app.app_context(): with app.app_context():
# Build content JSON with ordered steps array for DB storage. content_json = {"text": full_content}
# 'steps' is the single source of truth for rendering order.
content_json = {
"text": full_content,
}
if all_tool_calls: if all_tool_calls:
content_json["tool_calls"] = self._build_tool_calls_json(all_tool_calls, all_tool_results) content_json["tool_calls"] = self._build_tool_calls_json(
# Store ordered steps — the single source of truth for rendering order all_tool_calls, all_tool_results
)
content_json["steps"] = all_steps content_json["steps"] = all_steps
msg = Message( msg = Message(
@ -260,13 +411,13 @@ class ChatService:
db.session.add(msg) db.session.add(msg)
db.session.commit() db.session.commit()
# Auto-generate title from first user message if needed
conv = db.session.get(Conversation, conv_id) conv = db.session.get(Conversation, conv_id)
# Record token usage (get user_id from conv, not g —
# app.app_context() creates a new context where g.current_user is lost)
if conv: if conv:
record_token_usage(conv.user_id, conv_model, prompt_tokens, total_completion_tokens) record_token_usage(
conv.user_id, conv_model,
total_prompt_tokens, total_completion_tokens,
)
if conv and (not conv.title or conv.title == "新对话"): if conv and (not conv.title or conv.title == "新对话"):
user_msg = Message.query.filter_by( user_msg = Message.query.filter_by(
@ -278,46 +429,19 @@ class ChatService:
title_text = content_data.get("text", "")[:30] title_text = content_data.get("text", "")[:30]
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
title_text = user_msg.content.strip()[:30] title_text = user_msg.content.strip()[:30]
if title_text: suggested_title = title_text or "新对话"
suggested_title = title_text
else:
suggested_title = "新对话"
db.session.refresh(conv) db.session.refresh(conv)
conv.title = suggested_title conv.title = suggested_title
db.session.commit() db.session.commit()
else:
suggested_title = None
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_completion_tokens, 'suggested_title': suggested_title}, ensure_ascii=False)}\n\n" return suggested_title
return
yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n"
def safe_generate():
"""Wrapper that catches client disconnection during yield."""
try:
yield from generate()
except (ClientDisconnected, BrokenPipeError, ConnectionResetError):
pass # Client aborted, silently stop
return Response(
safe_generate(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache, no-store, must-revalidate",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
)
def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list: def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list:
"""Build tool calls JSON structure - matches streaming format""" """Build tool calls JSON structure - matches streaming format."""
result = [] result = []
for i, tc in enumerate(tool_calls): for i, tc in enumerate(tool_calls):
result_content = tool_results[i]["content"] if i < len(tool_results) else None result_content = tool_results[i]["content"] if i < len(tool_results) else None
# Parse result to extract success/skipped status
success = True success = True
skipped = False skipped = False
execution_time = 0 execution_time = 0
@ -327,10 +451,9 @@ class ChatService:
success = result_data.get("success", True) success = result_data.get("success", True)
skipped = result_data.get("skipped", False) skipped = result_data.get("skipped", False)
execution_time = result_data.get("execution_time", 0) execution_time = result_data.get("execution_time", 0)
except: except (json.JSONDecodeError, TypeError):
pass pass
# Keep same structure as streaming format
result.append({ result.append({
"id": tc.get("id", ""), "id": tc.get("id", ""),
"type": tc.get("type", "function"), "type": tc.get("type", "function"),
@ -345,9 +468,8 @@ class ChatService:
}) })
return result return result
def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list: def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
"""Process tool calls from streaming delta""" """Process tool calls from streaming delta."""
tool_calls_delta = delta.get("tool_calls", []) tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta: for tc in tool_calls_delta:
idx = tc.get("index", 0) idx = tc.get("index", 0)
@ -355,7 +477,7 @@ class ChatService:
tool_calls_list.append({ tool_calls_list.append({
"id": tc.get("id", ""), "id": tc.get("id", ""),
"type": tc.get("type", "function"), "type": tc.get("type", "function"),
"function": {"name": "", "arguments": ""} "function": {"name": "", "arguments": ""},
}) })
if tc.get("id"): if tc.get("id"):
tool_calls_list[idx]["id"] = tc["id"] tool_calls_list[idx]["id"] = tc["id"]

View File

@ -9,7 +9,7 @@ import os
import re import re
import time import time
import requests import requests
from typing import Optional, List from typing import Optional, List, Union
def _resolve_env_vars(value: str) -> str: def _resolve_env_vars(value: str) -> str:
@ -59,7 +59,8 @@ class LLMClient:
raise ValueError(f"Model '{model}' has no api_key configured") raise ValueError(f"Model '{model}' has no api_key configured")
return api_url, api_key return api_url, api_key
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled, tools, stream, api_url): def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
tools, tool_choice, stream, api_url):
"""Build request body with provider-specific parameter adaptation.""" """Build request body with provider-specific parameter adaptation."""
provider = _detect_provider(api_url) provider = _detect_provider(api_url)
@ -79,23 +80,17 @@ class LLMClient:
# --- Provider-specific: thinking --- # --- Provider-specific: thinking ---
if thinking_enabled: if thinking_enabled:
if provider == "glm": if provider == "glm" or provider == "deepseek":
body["thinking"] = {"type": "enabled"} body["thinking"] = {"type": "enabled"}
elif provider == "deepseek": else:
pass # deepseek-reasoner has built-in reasoning, no extra param raise NotImplementedError(f"Thinking not supported for provider '{provider}'")
# --- Provider-specific: tools ---
if tools: if tools:
body["tools"] = tools body["tools"] = tools
body["tool_choice"] = "auto" body["tool_choice"] = tool_choice if tool_choice is not None else "auto"
# --- Provider-specific: stream ---
if stream: if stream:
body["stream"] = True body["stream"] = True
if provider == "glm":
body["stream_options"] = {"include_usage": True}
elif provider == "deepseek":
pass # DeepSeek does not support stream_options
return body return body
@ -107,15 +102,16 @@ class LLMClient:
temperature: float = 1.0, temperature: float = 1.0,
thinking_enabled: bool = False, thinking_enabled: bool = False,
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
tool_choice: Optional[Union[str, dict]] = None,
stream: bool = False, stream: bool = False,
timeout: int = 120, timeout: int = 200,
max_retries: int = 3, max_retries: int = 3,
): ):
"""Call LLM API with retry on rate limit (429)""" """Call LLM API with retry on rate limit (429)"""
api_url, api_key = self._get_credentials(model) api_url, api_key = self._get_credentials(model)
body = self._build_body( body = self._build_body(
model, messages, max_tokens, temperature, model, messages, max_tokens, temperature,
thinking_enabled, tools, stream, api_url, thinking_enabled, tools, tool_choice, stream, api_url,
) )
for attempt in range(max_retries + 1): for attempt in range(max_retries + 1):

View File

@ -638,6 +638,99 @@ buffer 拼接: "event: process_step\ndata: {\"id\":\"step-0\",...}\n\n"
--- ---
## Token 用量计算
### 术语定义
| 术语 | 说明 |
| --- | --- |
| `prompt_tokens` | 发给模型的输入 token 数量(包括 system prompt、历史消息、工具定义、工具结果等全部上下文 |
| `completion_tokens` | 模型生成的输出 token 数量(包括 thinking 内容、正文回复、工具调用 JSON |
| `total_tokens` | `prompt_tokens + completion_tokens` |
### 计算流程
一次完整的对话可能经历多轮工具调用迭代,每轮都会向 LLM 发送请求并收到响应。Token 用量计算分为三个阶段:
```mermaid
flowchart LR
A[LLM SSE Stream] -->|usage 字段| B["_stream_llm_response()"]
B -->|每轮累加| C["generate() 循环"]
C -->|最终值| D["_save_message()"]
D --> E["record_token_usage()"]
E --> F["TokenUsage 表"]
```
#### 1. 流式解析 — 从 SSE chunks 中提取
LLM API 在流的最后一个 chunk 中返回 `usage` 字段(需要在请求中设置 `stream_options` 才有,否则为空):
```python
# chat.py: _stream_llm_response()
usage = chunk.get("usage", {})
if usage:
token_count = usage.get("completion_tokens", 0) # 本轮输出 token
prompt_tokens = usage.get("prompt_tokens", 0) # 本轮输入 token
```
#### 2. 迭代累加 — generate() 循环
每轮迭代结束后,将本轮的 prompt 和 completion token 累加到总计:
```python
# chat.py: generate()
total_prompt_tokens += prompt_tokens # 累加每轮 prompt
total_completion_tokens += completion_tokens # 累加每轮 completion
```
#### 3. 记录到数据库
最终调用 `record_token_usage()` 写入 TokenUsage 表,同时 Message 表也记录 completion token
```python
# chat.py: _save_message()
msg = Message(token_count=total_completion_tokens) # Message 表仅记录 completion
record_token_usage(user_id, model, total_prompt_tokens, total_completion_tokens)
```
### 多轮迭代示例
一次涉及工具调用的对话(如:用户提问 → LLM 调用搜索 → LLM 生成回复):
```
迭代 1: prompt=800, completion=150 (LLM 决定调用 web_search)
迭代 2: prompt=1500, completion=300 (LLM 根据搜索结果生成最终回复)
─────────────────────────────────────────
累加结果:
total_prompt_tokens = 800 + 1500 = 2300
total_completion_tokens = 150 + 300 = 450
─────────────────────────────────────────
```
> **注意**`prompt_tokens` 的累加意味着存在重复计算 — 第 2 轮的 prompt 包含了第 1 轮的上下文,累加后 `total_prompt_tokens` 大于本次对话的真实输入 token 总量(历史部分被多次计算)。这是因为每轮请求是独立的 API 调用,各自计费。如果需要精确的单次对话输入 token可以只取最后一轮的 `prompt_tokens`
### 存储位置
| 位置 | 存什么 | 粒度 |
| --- | --- | --- |
| `Message.token_count` | `total_completion_tokens`(仅输出) | 单条消息 |
| `TokenUsage` 表 | `prompt_tokens` + `completion_tokens` + `total_tokens` | 按 user + 日期 + model 聚合 |
`TokenUsage`**user_id + 日期 + model** 维度聚合,同一天同一模型的多次对话会累加到同一条记录:
```python
# helpers.py: record_token_usage()
if existing:
existing.prompt_tokens += prompt_tokens
existing.completion_tokens += completion_tokens
existing.total_tokens += prompt_tokens + completion_tokens
else:
create new TokenUsage record
```
---
## 分页机制 ## 分页机制
所有列表接口使用**游标分页** 所有列表接口使用**游标分页**