From f10909bec3a8e1c5e51dc9216abc5169254a334e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 17 Apr 2026 23:01:48 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96chat=E9=83=A8?= =?UTF-8?q?=E5=88=86=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- luxx/services/chat.py | 346 ++++++++++++++++++++++++------------------ 1 file changed, 201 insertions(+), 145 deletions(-) diff --git a/luxx/services/chat.py b/luxx/services/chat.py index 718bc05..ccc9f5a 100644 --- a/luxx/services/chat.py +++ b/luxx/services/chat.py @@ -45,6 +45,157 @@ def get_llm_client(conversation: Conversation = None): return client, max_tokens +class StreamContext: + """Context for streaming response state management.""" + + def __init__( + self, + step_index: int = 0, + current_step_id: str = None, + current_step_idx: int = None, + current_stream_type: str = None, + full_content: str = "", + full_thinking: str = "" + ): + self.step_index = step_index + self.current_step_id = current_step_id + self.current_step_idx = current_step_idx + self.current_stream_type = current_stream_type + self.full_content = full_content + self.full_thinking = full_thinking + self.all_steps = [] + self.all_tool_calls = [] + self.all_tool_results = [] + self.tool_calls_list = [] + + def reset_iteration(self): + """Reset streaming step tracker for new iteration.""" + self.current_step_id = None + self.current_step_idx = None + self.current_stream_type = None + self.full_content = "" + self.full_thinking = "" + self.tool_calls_list = [] + + def start_stream_step(self, step_type: str) -> str: + """Start a new streaming step. Returns the step_id.""" + self.current_step_idx = self.step_index + self.current_step_id = f"step-{self.step_index}" + self.current_stream_type = step_type + self.step_index += 1 + return self.current_step_id + + def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]: + """Yield a streaming step event.""" + return _sse_event("process_step", { + "step": { + "id": self.current_step_id, + "index": self.current_step_idx, + "type": step_type, + "content": content + } + }) + + def save_streaming_step(self): + """Save the current streaming step to all_steps.""" + if self.current_step_id is None: + return + + if self.current_stream_type == "thinking": + self.all_steps.append({ + "id": self.current_step_id, + "index": self.current_step_idx, + "type": "thinking", + "content": self.full_thinking + }) + elif self.current_stream_type == "text": + self.all_steps.append({ + "id": self.current_step_id, + "index": self.current_step_idx, + "type": "text", + "content": self.full_content + }) + + def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]: + """Handle reasoning/thinking delta. Returns yield_obj if step was yielded.""" + reasoning = delta.get("reasoning_content", "") + if not reasoning: + return None + + prev_len = len(self.full_thinking) + self.full_thinking += reasoning + + if prev_len == 0: # New thinking stream started + self.start_stream_step("thinking") + + return self.yield_stream_step("thinking", self.full_thinking) + + def handle_text_stream(self, delta: Dict) -> Optional[Dict]: + """Handle content delta. Returns yield_obj if step was yielded.""" + content = delta.get("content", "") + if not content: + return None + + prev_len = len(self.full_content) + self.full_content += content + + if prev_len == 0: # New text stream started + self.start_stream_step("text") + + return self.yield_stream_step("text", self.full_content) + + def handle_tool_call(self) -> tuple: + """Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs).""" + tool_call_step_ids = [] + tool_call_steps = [] + yield_objs = [] + + for tc in self.tool_calls_list: + call_step_idx = self.step_index + call_step_id = f"step-{self.step_index}" + tool_call_step_ids.append(call_step_id) + self.step_index += 1 + + call_step = { + "id": call_step_id, + "index": call_step_idx, + "type": "tool_call", + "id_ref": tc.get("id", ""), + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"] + } + tool_call_steps.append(call_step) + yield_objs.append(_sse_event("process_step", {"step": call_step})) + + return tool_call_step_ids, tool_call_steps, yield_objs + + def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple: + """Handle single tool result. Returns (result_step, yield_obj).""" + result_step_idx = self.step_index + result_step_id = f"step-{self.step_index}" + self.step_index += 1 + + content = tool_result.get("content", "") + success = True + try: + content_obj = json.loads(content) + if isinstance(content_obj, dict): + success = content_obj.get("success", True) + except: + pass + + result_step = { + "id": result_step_id, + "index": result_step_idx, + "type": "tool_result", + "id_ref": tool_call_step_id, + "name": tool_result.get("name", ""), + "content": content, + "success": success + } + return result_step, _sse_event("process_step", {"step": result_step}) + + class ChatService: """Chat service with tool support""" @@ -129,12 +280,6 @@ class ChatService: # 直接使用 provider 的 max_tokens max_tokens = provider_max_tokens - # State tracking - all_steps = [] - all_tool_calls = [] - all_tool_results = [] - step_index = 0 - # Token usage tracking total_usage = { "prompt_tokens": 0, @@ -142,23 +287,12 @@ class ChatService: "total_tokens": 0 } - # Global step IDs for thinking and text (persist across iterations) - thinking_step_id = None - thinking_step_idx = None - text_step_id = None - text_step_idx = None + # Streaming context for state management + ctx = StreamContext() for iteration in range(MAX_ITERATIONS): - # Stream from LLM - full_content = "" - full_thinking = "" - tool_calls_list = [] - - # Step tracking - use unified step-{index} format - thinking_step_id = None - thinking_step_idx = None - text_step_id = None - text_step_idx = None + # Reset streaming context for this iteration + ctx.reset_iteration() async for sse_line in llm.stream_call( model=model, @@ -218,19 +352,16 @@ class ChatService: if chunk.get("content") or chunk.get("message"): content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: - # BUG FIX: Update full_content so it gets saved to database - prev_content_len = len(full_content) - full_content += content - if prev_content_len == 0: # New text stream started - text_step_idx = step_index - text_step_id = f"step-{step_index}" - step_index += 1 + prev_len = len(ctx.full_content) + ctx.full_content += content + if prev_len == 0: # New text stream started + ctx.start_stream_step("text") yield _sse_event("process_step", { "step": { - "id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}", - "index": text_step_idx if prev_content_len == 0 else step_index - 1, + "id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}", + "index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1, "type": "text", - "content": full_content # Always send accumulated content + "content": ctx.full_content } }) continue @@ -238,96 +369,43 @@ class ChatService: delta = choices[0].get("delta", {}) # Handle reasoning (thinking) - reasoning = delta.get("reasoning_content", "") - if reasoning: - prev_thinking_len = len(full_thinking) - full_thinking += reasoning - if prev_thinking_len == 0: # New thinking stream started - thinking_step_idx = step_index - thinking_step_id = f"step-{step_index}" - step_index += 1 - yield _sse_event("process_step", { - "step": { - "id": thinking_step_id, - "index": thinking_step_idx, - "type": "thinking", - "content": full_thinking - } - }) + yield_obj = ctx.handle_thinking_stream(delta) + if yield_obj: + yield yield_obj # Handle content - content = delta.get("content", "") - if content: - prev_content_len = len(full_content) - full_content += content - if prev_content_len == 0: # New text stream started - text_step_idx = step_index - text_step_id = f"step-{step_index}" - step_index += 1 - yield _sse_event("process_step", { - "step": { - "id": text_step_id, - "index": text_step_idx, - "type": "text", - "content": full_content - } - }) + yield_obj = ctx.handle_text_stream(delta) + if yield_obj: + yield yield_obj # Accumulate tool calls tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) - if idx >= len(tool_calls_list): - tool_calls_list.append({ + if idx >= len(ctx.tool_calls_list): + ctx.tool_calls_list.append({ "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""} }) func = tc.get("function", {}) if func.get("name"): - tool_calls_list[idx]["function"]["name"] += func["name"] + ctx.tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): - tool_calls_list[idx]["function"]["arguments"] += func["arguments"] + ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"] - # Save thinking step - if thinking_step_id is not None: - all_steps.append({ - "id": thinking_step_id, - "index": thinking_step_idx, - "type": "thinking", - "content": full_thinking - }) - - # Save text step - if text_step_id is not None: - all_steps.append({ - "id": text_step_id, - "index": text_step_idx, - "type": "text", - "content": full_content - }) + # Save streaming step (thinking or text) + ctx.save_streaming_step() # Handle tool calls - if tool_calls_list: - all_tool_calls.extend(tool_calls_list) + if ctx.tool_calls_list: + ctx.all_tool_calls.extend(ctx.tool_calls_list) - # Yield tool_call steps - use unified step-{index} format - tool_call_step_ids = [] # Track step IDs for tool calls - for tc in tool_calls_list: - call_step_idx = step_index - call_step_id = f"step-{step_index}" - tool_call_step_ids.append(call_step_id) - step_index += 1 - call_step = { - "id": call_step_id, - "index": call_step_idx, - "type": "tool_call", - "id_ref": tc.get("id", ""), - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"] - } - all_steps.append(call_step) - yield _sse_event("process_step", {"step": call_step}) + # Handle tool_call steps + tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call() + ctx.all_steps.extend(tool_call_steps) + for yield_obj in yield_objs: + yield yield_obj # Execute tools tool_context = { @@ -337,39 +415,17 @@ class ChatService: "user_permission_level": user_permission_level } tool_results = self.tool_executor.process_tool_calls_parallel( - tool_calls_list, tool_context + ctx.tool_calls_list, tool_context ) - # Yield tool_result steps - use unified step-{index} format + # Handle tool_result steps for i, tr in enumerate(tool_results): tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" - result_step_idx = step_index - result_step_id = f"step-{step_index}" - step_index += 1 + result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id) + ctx.all_steps.append(result_step) + yield yield_obj - # 解析 content 中的 success 状态 - content = tr.get("content", "") - success = True - try: - content_obj = json.loads(content) - if isinstance(content_obj, dict): - success = content_obj.get("success", True) - except: - pass - - result_step = { - "id": result_step_id, - "index": result_step_idx, - "type": "tool_result", - "id_ref": tool_call_step_id, # Reference to the tool_call step - "name": tr.get("name", ""), - "content": content, - "success": success - } - all_steps.append(result_step) - yield _sse_event("process_step", {"step": result_step}) - - all_tool_results.append({ + ctx.all_tool_results.append({ "role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": tr.get("content", "") @@ -378,27 +434,27 @@ class ChatService: # Add assistant message with tool calls for next iteration messages.append({ "role": "assistant", - "content": full_content or "", - "tool_calls": tool_calls_list + "content": ctx.full_content or "", + "tool_calls": ctx.tool_calls_list }) - messages.extend(all_tool_results[-len(tool_results):]) - all_tool_results = [] + messages.extend(ctx.all_tool_results[-len(tool_results):]) + ctx.all_tool_results = [] continue # No tool calls - final iteration, save message msg_id = str(uuid.uuid4()) # 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值 - actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 + actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4 logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") self._save_message( conversation.id, msg_id, - full_content, - all_tool_calls, - all_tool_results, - all_steps, + ctx.full_content, + ctx.all_tool_calls, + ctx.all_tool_results, + ctx.all_steps, actual_token_count, total_usage ) @@ -411,15 +467,15 @@ class ChatService: return # Max iterations exceeded - save message before error - if full_content or all_tool_calls: + if ctx.full_content or ctx.all_tool_calls: msg_id = str(uuid.uuid4()) self._save_message( conversation.id, msg_id, - full_content, - all_tool_calls, - all_tool_results, - all_steps, + ctx.full_content, + ctx.all_tool_calls, + ctx.all_tool_results, + ctx.all_steps, actual_token_count, total_usage )