refactor: 修改chat 主流程逻辑

This commit is contained in:
ViperEkura 2026-03-28 00:42:59 +08:00
parent 7bd19a7529
commit 6aea98554f
4 changed files with 432 additions and 215 deletions

View File

@ -33,3 +33,9 @@ if MODELS and not DEFAULT_MODEL:
# Max agentic loop iterations (tool call rounds) # Max agentic loop iterations (tool call rounds)
MAX_ITERATIONS = _cfg.get("max_iterations", 5) MAX_ITERATIONS = _cfg.get("max_iterations", 5)
# Max parallel workers for tool execution (ThreadPoolExecutor)
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
# Max character length for a single tool result content (truncated if exceeded)
TOOL_RESULT_MAX_LENGTH = _cfg.get("tool_result_max_length", 4096)

View File

@ -1,8 +1,11 @@
"""Chat completion service""" """Chat completion service"""
import json import json
import logging
import uuid import uuid
from flask import current_app, g, Response, request as flask_request from typing import Optional, Union
from flask import current_app, Response, request as flask_request
from werkzeug.exceptions import ClientDisconnected from werkzeug.exceptions import ClientDisconnected
import requests
from backend import db from backend import db
from backend.models import Conversation, Message from backend.models import Conversation, Message
from backend.tools import registry, ToolExecutor from backend.tools import registry, ToolExecutor
@ -11,14 +14,15 @@ from backend.utils.helpers import (
build_messages, build_messages,
) )
from backend.services.llm_client import LLMClient from backend.services.llm_client import LLMClient
from backend.config import MAX_ITERATIONS from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS, TOOL_RESULT_MAX_LENGTH
logger = logging.getLogger(__name__)
def _client_disconnected(): def _client_disconnected():
"""Check if the client has disconnected.""" """Check if the client has disconnected."""
try: try:
stream = flask_request.input_stream stream = flask_request.input_stream
# If input_stream is unavailable, assume still connected
if stream is None: if stream is None:
return False return False
return stream.closed return stream.closed
@ -26,151 +30,111 @@ def _client_disconnected():
return False return False
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class ChatService: class ChatService:
"""Chat completion service with tool support""" """Chat completion service with tool support"""
def __init__(self, llm: LLMClient): def __init__(self, llm: LLMClient):
self.llm = llm self.llm = llm
def stream_response(
self,
conv: Conversation,
tools_enabled: bool = True,
project_id: str = None,
tool_choice: Optional[Union[str, dict]] = None,
):
"""Stream response with tool call support.
def stream_response(self, conv: Conversation, tools_enabled: bool = True, project_id: str = None):
"""Stream response with tool call support
Uses 'process_step' events to send thinking and tool calls in order, Uses 'process_step' events to send thinking and tool calls in order,
allowing them to be interleaved properly in the frontend. allowing them to be interleaved properly in the frontend.
Args: Args:
conv: Conversation object conv: Conversation object
tools_enabled: Whether to enable tools tools_enabled: Whether to enable tools
project_id: Project ID for workspace isolation project_id: Project ID for workspace isolation
tool_choice: Optional tool_choice override (e.g. "auto", "required", or dict)
""" """
conv_id = conv.id conv_id = conv.id
conv_model = conv.model conv_model = conv.model
app = current_app._get_current_object() app = current_app._get_current_object()
tools = registry.list_all() if tools_enabled else None tools = registry.list_all() if tools_enabled else None
initial_messages = build_messages(conv, project_id) initial_messages = build_messages(conv, project_id)
# Create per-request executor for thread-safe isolation.
# Each request gets its own _call_history and _cache, eliminating
# race conditions when multiple conversations stream concurrently.
executor = ToolExecutor(registry=registry) executor = ToolExecutor(registry=registry)
# Build context for tool execution
context = {"model": conv_model} context = {"model": conv_model}
if project_id: if project_id:
context["project_id"] = project_id context["project_id"] = project_id
elif conv.project_id: elif conv.project_id:
context["project_id"] = conv.project_id context["project_id"] = conv.project_id
def generate(): def generate():
messages = list(initial_messages) messages = list(initial_messages)
all_tool_calls = [] all_tool_calls = []
all_tool_results = [] all_tool_results = []
all_steps = [] # Collect all ordered steps for DB storage (thinking/text/tool_call/tool_result) all_steps = []
step_index = 0 # Track global step index for ordering step_index = 0
total_completion_tokens = 0 # Accumulated across all iterations total_completion_tokens = 0
prompt_tokens = 0 # Not accumulated — last iteration's value is sufficient total_prompt_tokens = 0
# (each iteration re-sends the full context, so earlier
# prompts are strict subsets of the final one)
for iteration in range(MAX_ITERATIONS): for iteration in range(MAX_ITERATIONS):
full_content = ""
full_thinking = ""
token_count = 0
msg_id = str(uuid.uuid4())
tool_calls_list = []
# Streaming step tracking — step ID is assigned on first chunk arrival.
# thinking always precedes text in GLM's streaming order, so text gets step_index+1.
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
try: try:
with app.app_context(): stream_result = self._stream_llm_response(
active_conv = db.session.get(Conversation, conv_id) app, conv_id, messages, tools, tool_choice, step_index
resp = self.llm.call( )
model=active_conv.model, except requests.exceptions.HTTPError as e:
messages=messages, resp = e.response
max_tokens=active_conv.max_tokens, if resp is not None and resp.status_code >= 500:
temperature=active_conv.temperature, yield _sse_event("error", {"content": f"LLM service unavailable ({resp.status_code})"})
thinking_enabled=active_conv.thinking_enabled, elif resp is not None and resp.status_code == 429:
tools=tools, yield _sse_event("error", {"content": "Rate limit exceeded, please try again later"})
stream=True, else:
) yield _sse_event("error", {"content": f"LLM request failed: {e}"})
resp.raise_for_status() return
except requests.exceptions.ConnectionError:
# Stream LLM response chunk by chunk yield _sse_event("error", {"content": "Unable to connect to LLM service"})
for line in resp.iter_lines(): return
# Early exit if client has disconnected except requests.exceptions.Timeout:
if _client_disconnected(): yield _sse_event("error", {"content": "LLM request timed out"})
resp.close() return
return
if not line:
continue
line = line.decode("utf-8")
if not line.startswith("data: "):
continue
data_str = line[6:]
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
# Extract usage first (present in last chunk when stream_options is set)
usage = chunk.get("usage", {})
if usage:
token_count = usage.get("completion_tokens", 0)
prompt_tokens = usage.get("prompt_tokens", 0)
choices = chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
# Accumulate thinking content for this iteration
reasoning = delta.get("reasoning_content", "")
if reasoning:
full_thinking += reasoning
if thinking_step_id is None:
thinking_step_id = f'step-{step_index}'
thinking_step_idx = step_index
yield f"event: process_step\ndata: {json.dumps({'id': thinking_step_id, 'index': thinking_step_idx, 'type': 'thinking', 'content': full_thinking}, ensure_ascii=False)}\n\n"
# Accumulate text content for this iteration
text = delta.get("content", "")
if text:
full_content += text
if text_step_id is None:
text_step_idx = step_index + (1 if thinking_step_id is not None else 0)
text_step_id = f'step-{text_step_idx}'
yield f"event: process_step\ndata: {json.dumps({'id': text_step_id, 'index': text_step_idx, 'type': 'text', 'content': full_content}, ensure_ascii=False)}\n\n"
# Accumulate tool calls from streaming deltas
tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)
except Exception as e: except Exception as e:
yield f"event: error\ndata: {json.dumps({'content': str(e)}, ensure_ascii=False)}\n\n" logger.exception("Unexpected error during LLM streaming")
yield _sse_event("error", {"content": f"Internal error: {e}"})
return return
# --- Finalize: save thinking/text steps to all_steps for DB storage --- if stream_result is None:
# No need to yield to frontend — incremental process_step events already sent. return # Client disconnected
full_content, full_thinking, tool_calls_list, \
thinking_step_id, thinking_step_idx, \
text_step_id, text_step_idx, \
completion_tokens, prompt_tokens, \
sse_chunks = stream_result
total_prompt_tokens += prompt_tokens
total_completion_tokens += completion_tokens
# Yield accumulated SSE chunks to frontend
for chunk in sse_chunks:
yield chunk
# Save thinking/text steps to all_steps for DB storage
if thinking_step_id is not None: if thinking_step_id is not None:
all_steps.append({ all_steps.append({
'id': thinking_step_id, 'index': thinking_step_idx, "id": thinking_step_id, "index": thinking_step_idx,
'type': 'thinking', 'content': full_thinking, "type": "thinking", "content": full_thinking,
}) })
step_index += 1 step_index += 1
if text_step_id is not None: if text_step_id is not None:
all_steps.append({ all_steps.append({
'id': text_step_id, 'index': text_step_idx, "id": text_step_id, "index": text_step_idx,
'type': 'text', 'content': full_content, "type": "text", "content": full_content,
}) })
step_index += 1 step_index += 1
@ -178,127 +142,79 @@ class ChatService:
if tool_calls_list: if tool_calls_list:
all_tool_calls.extend(tool_calls_list) all_tool_calls.extend(tool_calls_list)
# Phase 1: emit all tool_call steps (before execution) # Emit tool_call steps (before execution)
for tc in tool_calls_list: for tc in tool_calls_list:
call_step = { call_step = {
'id': f'step-{step_index}', "id": f"step-{step_index}",
'index': step_index, "index": step_index,
'type': 'tool_call', "type": "tool_call",
'id_ref': tc['id'], "id_ref": tc["id"],
'name': tc['function']['name'], "name": tc["function"]["name"],
'arguments': tc['function']['arguments'], "arguments": tc["function"]["arguments"],
} }
all_steps.append(call_step) all_steps.append(call_step)
yield f"event: process_step\ndata: {json.dumps(call_step, ensure_ascii=False)}\n\n" yield _sse_event("process_step", call_step)
step_index += 1 step_index += 1
# Phase 2: execute tools — parallel when multiple, sequential when single # Execute tools with error wrapping
if len(tool_calls_list) > 1: tool_results = self._execute_tools_safe(
with app.app_context(): app, executor, tool_calls_list, context
tool_results = executor.process_tool_calls_parallel( )
tool_calls_list, context, max_workers=4
)
else:
with app.app_context():
tool_results = executor.process_tool_calls(
tool_calls_list, context
)
# Phase 3: emit all tool_result steps (after execution, same order) # Emit tool_result steps
for tr in tool_results: for tr in tool_results:
skipped = False
try: try:
result_content = json.loads(tr["content"]) result_content = json.loads(tr["content"])
skipped = result_content.get("skipped", False) skipped = result_content.get("skipped", False)
except Exception: except Exception:
skipped = False skipped = False
result_step = { result_step = {
'id': f'step-{step_index}', "id": f"step-{step_index}",
'index': step_index, "index": step_index,
'type': 'tool_result', "type": "tool_result",
'id_ref': tr['tool_call_id'], "id_ref": tr["tool_call_id"],
'name': tr['name'], "name": tr["name"],
'content': tr['content'], "content": tr["content"],
'skipped': skipped, "skipped": skipped,
} }
all_steps.append(result_step) all_steps.append(result_step)
yield f"event: process_step\ndata: {json.dumps(result_step, ensure_ascii=False)}\n\n" yield _sse_event("process_step", result_step)
step_index += 1 step_index += 1
# Append assistant message + tool results for the next iteration # Append assistant message + tool results for the next iteration
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": full_content or None, "content": full_content or None,
"tool_calls": tool_calls_list "tool_calls": tool_calls_list,
}) })
messages.extend(tool_results) messages.extend(tool_results)
all_tool_results.extend(tool_results) all_tool_results.extend(tool_results)
total_completion_tokens += token_count
continue continue
# --- No tool calls: final iteration — save message to DB --- # --- No tool calls: final iteration — save message to DB ---
suggested_title = None msg_id = str(uuid.uuid4())
# prompt_tokens already holds the last iteration's value (set during streaming) suggested_title = self._save_message(
total_completion_tokens += token_count app, conv_id, conv_model, msg_id,
with app.app_context(): full_content, all_tool_calls, all_tool_results,
# Build content JSON with ordered steps array for DB storage. all_steps, total_prompt_tokens, total_completion_tokens,
# 'steps' is the single source of truth for rendering order. )
content_json = {
"text": full_content,
}
if all_tool_calls:
content_json["tool_calls"] = self._build_tool_calls_json(all_tool_calls, all_tool_results)
# Store ordered steps — the single source of truth for rendering order
content_json["steps"] = all_steps
msg = Message( yield _sse_event("done", {
id=msg_id, "message_id": msg_id,
conversation_id=conv_id, "token_count": total_completion_tokens,
role="assistant", "suggested_title": suggested_title,
content=json.dumps(content_json, ensure_ascii=False), })
token_count=total_completion_tokens,
)
db.session.add(msg)
db.session.commit()
# Auto-generate title from first user message if needed
conv = db.session.get(Conversation, conv_id)
# Record token usage (get user_id from conv, not g —
# app.app_context() creates a new context where g.current_user is lost)
if conv:
record_token_usage(conv.user_id, conv_model, prompt_tokens, total_completion_tokens)
if conv and (not conv.title or conv.title == "新对话"):
user_msg = Message.query.filter_by(
conversation_id=conv_id, role="user"
).order_by(Message.created_at.asc()).first()
if user_msg and user_msg.content:
try:
content_data = json.loads(user_msg.content)
title_text = content_data.get("text", "")[:30]
except (json.JSONDecodeError, TypeError):
title_text = user_msg.content.strip()[:30]
if title_text:
suggested_title = title_text
else:
suggested_title = "新对话"
db.session.refresh(conv)
conv.title = suggested_title
db.session.commit()
else:
suggested_title = None
yield f"event: done\ndata: {json.dumps({'message_id': msg_id, 'token_count': total_completion_tokens, 'suggested_title': suggested_title}, ensure_ascii=False)}\n\n"
return return
yield f"event: error\ndata: {json.dumps({'content': 'exceeded maximum tool call iterations'}, ensure_ascii=False)}\n\n" yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
def safe_generate(): def safe_generate():
"""Wrapper that catches client disconnection during yield.""" """Wrapper that catches client disconnection during yield."""
try: try:
yield from generate() yield from generate()
except (ClientDisconnected, BrokenPipeError, ConnectionResetError): except (ClientDisconnected, BrokenPipeError, ConnectionResetError):
pass # Client aborted, silently stop pass
return Response( return Response(
safe_generate(), safe_generate(),
@ -308,16 +224,224 @@ class ChatService:
"X-Accel-Buffering": "no", "X-Accel-Buffering": "no",
"Connection": "keep-alive", "Connection": "keep-alive",
"Transfer-Encoding": "chunked", "Transfer-Encoding": "chunked",
} },
) )
# ------------------------------------------------------------------
# Private helpers — extracted for testability and readability
# ------------------------------------------------------------------
def _stream_llm_response(
    self, app, conv_id, messages, tools, tool_choice, step_index
):
    """Call the LLM streaming API and parse its SSE response.

    Args:
        app: Flask application object; its app context is used for the
            conversation lookup and the LLM request.
        conv_id: ID of the conversation whose model settings are used.
        messages: Chat messages to send to the LLM.
        tools: Tool definitions to send, or None.
        tool_choice: Optional tool_choice override forwarded to the LLM client.
        step_index: Index assigned to the first step produced in this call.

    Returns:
        None if the client disconnected while streaming, otherwise the tuple
        ``(full_content, full_thinking, tool_calls_list, thinking_step_id,
        thinking_step_idx, text_step_id, text_step_idx, completion_tokens,
        prompt_tokens, sse_chunks)``.

    Raises:
        LookupError: If the conversation no longer exists.
        Exception: Anything raised by ``self.llm.call()`` or
            ``resp.raise_for_status()`` propagates to the caller.

    NOTE(review): the ``process_step`` events are collected in ``sse_chunks``
    and handed back only after the LLM stream has ended, so the caller cannot
    forward them while the model is still generating.
    """
    full_content = ""
    full_thinking = ""
    token_count = 0
    prompt_tokens = 0
    tool_calls_list = []
    thinking_step_id = None
    thinking_step_idx = None
    text_step_id = None
    text_step_idx = None
    sse_chunks = []  # Collected SSE events, returned to the caller

    with app.app_context():
        active_conv = db.session.get(Conversation, conv_id)
        if active_conv is None:
            # Fail with a clear message instead of an AttributeError below.
            raise LookupError(f"Conversation {conv_id} not found")
        resp = self.llm.call(
            model=active_conv.model,
            messages=messages,
            max_tokens=active_conv.max_tokens,
            temperature=active_conv.temperature,
            thinking_enabled=active_conv.thinking_enabled,
            tools=tools,
            tool_choice=tool_choice,
            stream=True,
        )

    try:
        resp.raise_for_status()

        for line in resp.iter_lines():
            if _client_disconnected():
                return None  # the finally clause closes the response

            if not line:
                continue
            line = line.decode("utf-8")
            if not line.startswith("data: "):
                continue
            data_str = line[6:]
            if data_str == "[DONE]":
                break
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                continue

            usage = chunk.get("usage", {})
            if usage:
                # "or 0" guards against explicit null counters.
                token_count = usage.get("completion_tokens", 0) or 0
                prompt_tokens = usage.get("prompt_tokens", 0) or 0

            choices = chunk.get("choices", [])
            if not choices:
                continue

            # A chunk may carry "delta": null; treat it as empty.
            delta = choices[0].get("delta") or {}

            reasoning = delta.get("reasoning_content", "")
            if reasoning:
                full_thinking += reasoning
                if thinking_step_id is None:
                    thinking_step_id = f"step-{step_index}"
                    thinking_step_idx = step_index
                sse_chunks.append(_sse_event("process_step", {
                    "id": thinking_step_id, "index": thinking_step_idx,
                    "type": "thinking", "content": full_thinking,
                }))

            text = delta.get("content", "")
            if text:
                full_content += text
                if text_step_id is None:
                    # The text step follows the thinking step when one exists.
                    text_step_idx = step_index + (1 if thinking_step_id is not None else 0)
                    text_step_id = f"step-{text_step_idx}"
                sse_chunks.append(_sse_event("process_step", {
                    "id": text_step_id, "index": text_step_idx,
                    "type": "text", "content": full_content,
                }))

            tool_calls_list = self._process_tool_calls_delta(delta, tool_calls_list)
    finally:
        # Always release the streamed response: on success, on client
        # disconnect, and when an exception escapes.
        resp.close()

    return (
        full_content, full_thinking, tool_calls_list,
        thinking_step_id, thinking_step_idx,
        text_step_id, text_step_idx,
        token_count, prompt_tokens,
        sse_chunks,
    )
def _execute_tools_safe(self, app, executor, tool_calls_list, context):
    """Execute tool calls with top-level error wrapping and size limiting.

    Args:
        app: Flask application object whose app context wraps the execution.
        executor: Tool executor providing ``process_tool_calls`` and
            ``process_tool_calls_parallel``.
        tool_calls_list: Tool calls to run; each has ``id`` and
            ``function.name``.
        context: Context dict passed through to the executor.

    Returns:
        The list of tool result messages. If execution raised, one error
        result per tool call is returned instead of propagating. Any result
        whose ``content`` exceeds ``TOOL_RESULT_MAX_LENGTH`` characters is
        replaced by a shorter, still valid JSON document.
    """
    try:
        with app.app_context():
            if len(tool_calls_list) > 1:
                tool_results = executor.process_tool_calls_parallel(
                    tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
                )
            else:
                tool_results = executor.process_tool_calls(
                    tool_calls_list, context
                )
    except Exception as e:
        logger.exception("Error during tool execution")
        tool_results = [
            {
                "role": "tool",
                "tool_call_id": tc["id"],
                "name": tc["function"]["name"],
                "content": json.dumps({
                    "success": False,
                    "error": f"Tool execution failed: {e}",
                }, ensure_ascii=False),
            }
            for tc in tool_calls_list
        ]

    # Replace oversized tool result content with a bounded summary.
    for tr in tool_results:
        content = tr["content"]
        if len(content) > TOOL_RESULT_MAX_LENGTH:
            tr["content"] = self._truncate_tool_content(
                content, TOOL_RESULT_MAX_LENGTH
            )

    return tool_results

def _truncate_tool_content(self, content, max_length):
    """Build a valid JSON replacement for an oversized tool result.

    Unparsable (or JSON-falsy) content becomes a fixed error object. For
    parsable content the ``success``, ``skipped`` and ``execution_time``
    fields of a JSON object are carried over, ``truncated`` and ``_note``
    are added, and a ``preview`` field holds as long a prefix of the raw
    content as fits within ``max_length`` characters. If ``max_length`` is
    too small even for the metadata, the metadata alone is returned.
    """
    try:
        parsed = json.loads(content)
    except (json.JSONDecodeError, TypeError):
        parsed = None

    if not parsed:
        return json.dumps(
            {"success": False, "error": "Tool result too large, truncated"},
            ensure_ascii=False,
        )

    envelope = {}
    if isinstance(parsed, dict):
        # Keep the status fields that are read back from tool results.
        for key in ("success", "skipped", "execution_time"):
            if key in parsed:
                envelope[key] = parsed[key]
    envelope["truncated"] = True
    envelope["_note"] = f"Content truncated, original length {len(content)} chars"

    bare = json.dumps(envelope, ensure_ascii=False)
    preview = content[: max(0, max_length - len(bare))]
    while preview:
        encoded = json.dumps({**envelope, "preview": preview}, ensure_ascii=False)
        overshoot = len(encoded) - max_length
        if overshoot <= 0:
            return encoded
        # JSON escaping can lengthen the preview; dropping `overshoot`
        # characters shortens the encoding by at least that much.
        preview = preview[: max(0, len(preview) - overshoot)]
    return bare
def _save_message(
    self, app, conv_id, conv_model, msg_id,
    full_content, all_tool_calls, all_tool_results,
    all_steps, total_prompt_tokens, total_completion_tokens,
):
    """Save the final assistant message and auto-generate a title if needed.

    Inside an app context this stores the assistant message (text, optional
    tool calls, ordered steps), records token usage for the conversation's
    user, and, when the conversation has no title or still has the default
    one, derives a title from the first user message.

    Args:
        app: Flask application object providing the app context.
        conv_id: Conversation the message belongs to.
        conv_model: Model name passed to ``record_token_usage``.
        msg_id: ID for the new assistant message.
        full_content: Final assistant text.
        all_tool_calls: Tool calls collected across iterations.
        all_tool_results: Tool results collected across iterations.
        all_steps: Ordered step dicts stored under ``steps``.
        total_prompt_tokens: Prompt tokens summed over all iterations.
        total_completion_tokens: Completion tokens summed over all iterations;
            also stored as the message's ``token_count``.

    Returns:
        The suggested title, or None when no title was generated.
    """
    suggested_title = None
    with app.app_context():
        content_json = {"text": full_content}
        if all_tool_calls:
            content_json["tool_calls"] = self._build_tool_calls_json(
                all_tool_calls, all_tool_results
            )
        content_json["steps"] = all_steps

        msg = Message(
            id=msg_id,
            conversation_id=conv_id,
            role="assistant",
            content=json.dumps(content_json, ensure_ascii=False),
            token_count=total_completion_tokens,
        )
        db.session.add(msg)
        db.session.commit()

        conv = db.session.get(Conversation, conv_id)
        if conv:
            record_token_usage(
                conv.user_id, conv_model,
                total_prompt_tokens, total_completion_tokens,
            )

        if conv and (not conv.title or conv.title == "新对话"):
            user_msg = Message.query.filter_by(
                conversation_id=conv_id, role="user"
            ).order_by(Message.created_at.asc()).first()
            if user_msg and user_msg.content:
                try:
                    content_data = json.loads(user_msg.content)
                except (json.JSONDecodeError, TypeError):
                    content_data = None

                # Plain-text messages such as "42" or "[1, 2]" are valid JSON
                # but not objects; only an object can carry a "text" field.
                title_text = None
                if isinstance(content_data, dict):
                    text = content_data.get("text", "")
                    if isinstance(text, str):
                        title_text = text[:30]
                if title_text is None:
                    title_text = user_msg.content.strip()[:30]

                suggested_title = title_text or "新对话"
                db.session.refresh(conv)
                conv.title = suggested_title
                db.session.commit()

    return suggested_title
def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list: def _build_tool_calls_json(self, tool_calls: list, tool_results: list) -> list:
"""Build tool calls JSON structure - matches streaming format""" """Build tool calls JSON structure - matches streaming format."""
result = [] result = []
for i, tc in enumerate(tool_calls): for i, tc in enumerate(tool_calls):
result_content = tool_results[i]["content"] if i < len(tool_results) else None result_content = tool_results[i]["content"] if i < len(tool_results) else None
# Parse result to extract success/skipped status
success = True success = True
skipped = False skipped = False
execution_time = 0 execution_time = 0
@ -327,10 +451,9 @@ class ChatService:
success = result_data.get("success", True) success = result_data.get("success", True)
skipped = result_data.get("skipped", False) skipped = result_data.get("skipped", False)
execution_time = result_data.get("execution_time", 0) execution_time = result_data.get("execution_time", 0)
except: except (json.JSONDecodeError, TypeError):
pass pass
# Keep same structure as streaming format
result.append({ result.append({
"id": tc.get("id", ""), "id": tc.get("id", ""),
"type": tc.get("type", "function"), "type": tc.get("type", "function"),
@ -345,9 +468,8 @@ class ChatService:
}) })
return result return result
def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list: def _process_tool_calls_delta(self, delta: dict, tool_calls_list: list) -> list:
"""Process tool calls from streaming delta""" """Process tool calls from streaming delta."""
tool_calls_delta = delta.get("tool_calls", []) tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta: for tc in tool_calls_delta:
idx = tc.get("index", 0) idx = tc.get("index", 0)
@ -355,7 +477,7 @@ class ChatService:
tool_calls_list.append({ tool_calls_list.append({
"id": tc.get("id", ""), "id": tc.get("id", ""),
"type": tc.get("type", "function"), "type": tc.get("type", "function"),
"function": {"name": "", "arguments": ""} "function": {"name": "", "arguments": ""},
}) })
if tc.get("id"): if tc.get("id"):
tool_calls_list[idx]["id"] = tc["id"] tool_calls_list[idx]["id"] = tc["id"]

View File

@ -9,7 +9,7 @@ import os
import re import re
import time import time
import requests import requests
from typing import Optional, List from typing import Optional, List, Union
def _resolve_env_vars(value: str) -> str: def _resolve_env_vars(value: str) -> str:
@ -59,7 +59,8 @@ class LLMClient:
raise ValueError(f"Model '{model}' has no api_key configured") raise ValueError(f"Model '{model}' has no api_key configured")
return api_url, api_key return api_url, api_key
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled, tools, stream, api_url): def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
tools, tool_choice, stream, api_url):
"""Build request body with provider-specific parameter adaptation.""" """Build request body with provider-specific parameter adaptation."""
provider = _detect_provider(api_url) provider = _detect_provider(api_url)
@ -79,23 +80,17 @@ class LLMClient:
# --- Provider-specific: thinking --- # --- Provider-specific: thinking ---
if thinking_enabled: if thinking_enabled:
if provider == "glm": if provider == "glm" or provider == "deepseek":
body["thinking"] = {"type": "enabled"} body["thinking"] = {"type": "enabled"}
elif provider == "deepseek": else:
pass # deepseek-reasoner has built-in reasoning, no extra param raise NotImplementedError(f"Thinking not supported for provider '{provider}'")
# --- Provider-specific: tools ---
if tools: if tools:
body["tools"] = tools body["tools"] = tools
body["tool_choice"] = "auto" body["tool_choice"] = tool_choice if tool_choice is not None else "auto"
# --- Provider-specific: stream ---
if stream: if stream:
body["stream"] = True body["stream"] = True
if provider == "glm":
body["stream_options"] = {"include_usage": True}
elif provider == "deepseek":
pass # DeepSeek does not support stream_options
return body return body
@ -107,15 +102,16 @@ class LLMClient:
temperature: float = 1.0, temperature: float = 1.0,
thinking_enabled: bool = False, thinking_enabled: bool = False,
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
tool_choice: Optional[Union[str, dict]] = None,
stream: bool = False, stream: bool = False,
timeout: int = 120, timeout: int = 200,
max_retries: int = 3, max_retries: int = 3,
): ):
"""Call LLM API with retry on rate limit (429)""" """Call LLM API with retry on rate limit (429)"""
api_url, api_key = self._get_credentials(model) api_url, api_key = self._get_credentials(model)
body = self._build_body( body = self._build_body(
model, messages, max_tokens, temperature, model, messages, max_tokens, temperature,
thinking_enabled, tools, stream, api_url, thinking_enabled, tools, tool_choice, stream, api_url,
) )
for attempt in range(max_retries + 1): for attempt in range(max_retries + 1):

View File

@ -638,6 +638,99 @@ buffer 拼接: "event: process_step\ndata: {\"id\":\"step-0\",...}\n\n"
--- ---
## Token 用量计算
### 术语定义
| 术语 | 说明 |
| --- | --- |
| `prompt_tokens` | 发给模型的输入 token 数量(包括 system prompt、历史消息、工具定义、工具结果等全部上下文) |
| `completion_tokens` | 模型生成的输出 token 数量(包括 thinking 内容、正文回复、工具调用 JSON) |
| `total_tokens` | `prompt_tokens + completion_tokens` |
### 计算流程
一次完整的对话可能经历多轮工具调用迭代,每轮都会向 LLM 发送请求并收到响应。Token 用量计算分为三个阶段:
```mermaid
flowchart LR
A[LLM SSE Stream] -->|usage 字段| B["_stream_llm_response()"]
B -->|每轮累加| C["generate() 循环"]
C -->|最终值| D["_save_message()"]
D --> E["record_token_usage()"]
E --> F["TokenUsage 表"]
```
#### 1. 流式解析 — 从 SSE chunks 中提取
LLM API 在流的最后一个 chunk 中返回 `usage` 字段(需要在请求中设置 `stream_options` 才有,否则为空):
```python
# chat.py: _stream_llm_response()
usage = chunk.get("usage", {})
if usage:
token_count = usage.get("completion_tokens", 0) # 本轮输出 token
prompt_tokens = usage.get("prompt_tokens", 0) # 本轮输入 token
```
#### 2. 迭代累加 — generate() 循环
每轮迭代结束后,将本轮的 prompt 和 completion token 累加到总计:
```python
# chat.py: generate()
total_prompt_tokens += prompt_tokens # 累加每轮 prompt
total_completion_tokens += completion_tokens # 累加每轮 completion
```
#### 3. 记录到数据库
最终调用 `record_token_usage()` 写入 TokenUsage 表,同时 Message 表也记录 completion token:
```python
# chat.py: _save_message()
msg = Message(token_count=total_completion_tokens) # Message 表仅记录 completion
record_token_usage(user_id, model, total_prompt_tokens, total_completion_tokens)
```
### 多轮迭代示例
一次涉及工具调用的对话(如:用户提问 → LLM 调用搜索 → LLM 生成回复):
```
迭代 1: prompt=800, completion=150 (LLM 决定调用 web_search)
迭代 2: prompt=1500, completion=300 (LLM 根据搜索结果生成最终回复)
─────────────────────────────────────────
累加结果:
total_prompt_tokens = 800 + 1500 = 2300
total_completion_tokens = 150 + 300 = 450
─────────────────────────────────────────
```
> **注意**:`prompt_tokens` 的累加意味着存在重复计算 — 第 2 轮的 prompt 包含了第 1 轮的上下文,累加后 `total_prompt_tokens` 大于本次对话的真实输入 token 总量(历史部分被多次计算)。这是因为每轮请求是独立的 API 调用,各自计费。如果需要精确的单次对话输入 token,可以只取最后一轮的 `prompt_tokens`。
### 存储位置
| 位置 | 存什么 | 粒度 |
| --- | --- | --- |
| `Message.token_count` | `total_completion_tokens`(仅输出) | 单条消息 |
| `TokenUsage` 表 | `prompt_tokens` + `completion_tokens` + `total_tokens` | 按 user + 日期 + model 聚合 |
`TokenUsage` 按 **user_id + 日期 + model** 维度聚合,同一天同一模型的多次对话会累加到同一条记录:
```python
# helpers.py: record_token_usage()
if existing:
existing.prompt_tokens += prompt_tokens
existing.completion_tokens += completion_tokens
existing.total_tokens += prompt_tokens + completion_tokens
else:
create new TokenUsage record
```
---
## 分页机制 ## 分页机制
所有列表接口使用**游标分页** 所有列表接口使用**游标分页**