chore: 优化chat部分结构

This commit is contained in:
ViperEkura 2026-04-17 23:01:48 +08:00
parent 22a4b8a4bb
commit f10909bec3
1 changed file with 201 additions and 145 deletions

View File

@ -45,6 +45,157 @@ def get_llm_client(conversation: Conversation = None):
return client, max_tokens
class StreamContext:
    """State container for a streaming chat response.

    Tracks the sequential ``step-{index}`` identifiers assigned to each
    streamed step (thinking, text, tool_call, tool_result) and accumulates
    content, tool calls, and tool results across LLM iterations.
    """

    def __init__(
        self,
        step_index: int = 0,
        current_step_id: Optional[str] = None,
        current_step_idx: Optional[int] = None,
        current_stream_type: Optional[str] = None,
        full_content: str = "",
        full_thinking: str = ""
    ):
        # Monotonically increasing counter used to mint "step-{n}" ids.
        self.step_index = step_index
        # Id / index / type of the step currently being streamed (None when idle).
        self.current_step_id = current_step_id
        self.current_step_idx = current_step_idx
        self.current_stream_type = current_stream_type
        # Text and reasoning accumulated for the current iteration.
        self.full_content = full_content
        self.full_thinking = full_thinking
        # Persist across iterations: every finalized step, tool call, result.
        self.all_steps = []
        self.all_tool_calls = []
        self.all_tool_results = []
        # Tool calls assembled from streaming deltas in the current iteration.
        self.tool_calls_list = []

    def reset_iteration(self):
        """Reset per-iteration streaming state (step_index keeps growing)."""
        self.current_step_id = None
        self.current_step_idx = None
        self.current_stream_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []

    def start_stream_step(self, step_type: str) -> str:
        """Begin a new streaming step of *step_type* and return its step_id."""
        self.current_step_idx = self.step_index
        self.current_step_id = f"step-{self.step_index}"
        self.current_stream_type = step_type
        self.step_index += 1
        return self.current_step_id

    def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
        """Build a ``process_step`` SSE event for the current step."""
        return _sse_event("process_step", {
            "step": {
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": step_type,
                "content": content
            }
        })

    def save_streaming_step(self):
        """Append the in-flight thinking/text step (if any) to ``all_steps``."""
        if self.current_step_id is None:
            return
        if self.current_stream_type == "thinking":
            self.all_steps.append({
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": "thinking",
                "content": self.full_thinking
            })
        elif self.current_stream_type == "text":
            self.all_steps.append({
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": "text",
                "content": self.full_content
            })

    def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
        """Accumulate a reasoning delta.

        Returns the SSE event to yield, or ``None`` when the delta carries
        no reasoning content.
        """
        reasoning = delta.get("reasoning_content", "")
        if not reasoning:
            return None
        prev_len = len(self.full_thinking)
        self.full_thinking += reasoning
        if prev_len == 0:  # first reasoning chunk opens a new step
            self.start_stream_step("thinking")
        return self.yield_stream_step("thinking", self.full_thinking)

    def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
        """Accumulate a content delta.

        Returns the SSE event to yield, or ``None`` when the delta carries
        no content.
        """
        content = delta.get("content", "")
        if not content:
            return None
        prev_len = len(self.full_content)
        self.full_content += content
        if prev_len == 0:  # first content chunk opens a new step
            self.start_stream_step("text")
        return self.yield_stream_step("text", self.full_content)

    def handle_tool_call(self) -> tuple:
        """Mint a step for every accumulated tool call.

        Returns:
            (tool_call_step_ids, tool_call_steps, yield_objs) — the minted
            step ids, the step dicts, and the matching ``process_step``
            SSE events, in call order.
        """
        tool_call_step_ids = []
        tool_call_steps = []
        yield_objs = []
        for tc in self.tool_calls_list:
            call_step_idx = self.step_index
            call_step_id = f"step-{self.step_index}"
            tool_call_step_ids.append(call_step_id)
            self.step_index += 1
            call_step = {
                "id": call_step_id,
                "index": call_step_idx,
                "type": "tool_call",
                "id_ref": tc.get("id", ""),
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"]
            }
            tool_call_steps.append(call_step)
            yield_objs.append(_sse_event("process_step", {"step": call_step}))
        return tool_call_step_ids, tool_call_steps, yield_objs

    def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
        """Build a tool_result step referencing *tool_call_step_id*.

        The result content is parsed as JSON to extract a ``success`` flag
        when possible; non-JSON content defaults to success.

        Returns:
            (result_step, sse_event)
        """
        result_step_idx = self.step_index
        result_step_id = f"step-{self.step_index}"
        self.step_index += 1
        content = tool_result.get("content", "")
        success = True
        try:
            content_obj = json.loads(content)
            if isinstance(content_obj, dict):
                success = content_obj.get("success", True)
        except (json.JSONDecodeError, TypeError):
            # Narrowed from a bare `except:`; only parse failures are
            # best-effort ignored, unrelated errors still propagate.
            pass
        result_step = {
            "id": result_step_id,
            "index": result_step_idx,
            "type": "tool_result",
            "id_ref": tool_call_step_id,
            "name": tool_result.get("name", ""),
            "content": content,
            "success": success
        }
        return result_step, _sse_event("process_step", {"step": result_step})
class ChatService:
"""Chat service with tool support"""
@ -129,12 +280,6 @@ class ChatService:
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# State tracking
all_steps = []
all_tool_calls = []
all_tool_results = []
step_index = 0
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
@ -142,23 +287,12 @@ class ChatService:
"total_tokens": 0
}
# Global step IDs for thinking and text (persist across iterations)
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
# Streaming context for state management
ctx = StreamContext()
for iteration in range(MAX_ITERATIONS):
# Stream from LLM
full_content = ""
full_thinking = ""
tool_calls_list = []
# Step tracking - use unified step-{index} format
thinking_step_id = None
thinking_step_idx = None
text_step_id = None
text_step_idx = None
# Reset streaming context for this iteration
ctx.reset_iteration()
async for sse_line in llm.stream_call(
model=model,
@ -218,19 +352,16 @@ class ChatService:
if chunk.get("content") or chunk.get("message"):
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
# BUG FIX: Update full_content so it gets saved to database
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
prev_len = len(ctx.full_content)
ctx.full_content += content
if prev_len == 0: # New text stream started
ctx.start_stream_step("text")
yield _sse_event("process_step", {
"step": {
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}",
"index": text_step_idx if prev_content_len == 0 else step_index - 1,
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
"type": "text",
"content": full_content # Always send accumulated content
"content": ctx.full_content
}
})
continue
@ -238,96 +369,43 @@ class ChatService:
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
prev_thinking_len = len(full_thinking)
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
yield_obj = ctx.handle_thinking_stream(delta)
if yield_obj:
yield yield_obj
# Handle content
content = delta.get("content", "")
if content:
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
}
})
yield_obj = ctx.handle_text_stream(delta)
if yield_obj:
yield yield_obj
# Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(tool_calls_list):
tool_calls_list.append({
if idx >= len(ctx.tool_calls_list):
ctx.tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
tool_calls_list[idx]["function"]["name"] += func["name"]
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save thinking step
if thinking_step_id is not None:
all_steps.append({
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
})
# Save text step
if text_step_id is not None:
all_steps.append({
"id": text_step_id,
"index": text_step_idx,
"type": "text",
"content": full_content
})
# Save streaming step (thinking or text)
ctx.save_streaming_step()
# Handle tool calls
if tool_calls_list:
all_tool_calls.extend(tool_calls_list)
if ctx.tool_calls_list:
ctx.all_tool_calls.extend(ctx.tool_calls_list)
# Yield tool_call steps - use unified step-{index} format
tool_call_step_ids = [] # Track step IDs for tool calls
for tc in tool_calls_list:
call_step_idx = step_index
call_step_id = f"step-{step_index}"
tool_call_step_ids.append(call_step_id)
step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Handle tool_call steps
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
ctx.all_steps.extend(tool_call_steps)
for yield_obj in yield_objs:
yield yield_obj
# Execute tools
tool_context = {
@ -337,39 +415,17 @@ class ChatService:
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context
ctx.tool_calls_list, tool_context
)
# Yield tool_result steps - use unified step-{index} format
# Handle tool_result steps
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index
result_step_id = f"step-{step_index}"
step_index += 1
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
ctx.all_steps.append(result_step)
yield yield_obj
# 解析 content 中的 success 状态
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
ctx.all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
@ -378,27 +434,27 @@ class ChatService:
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": full_content or "",
"tool_calls": tool_calls_list
"content": ctx.full_content or "",
"tool_calls": ctx.tool_calls_list
})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = []
messages.extend(ctx.all_tool_results[-len(tool_results):])
ctx.all_tool_results = []
continue
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
@ -411,15 +467,15 @@ class ChatService:
return
# Max iterations exceeded - save message before error
if full_content or all_tool_calls:
if ctx.full_content or ctx.all_tool_calls:
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)