"""Chat service module""" import json import uuid import logging from typing import List, Dict, Any, AsyncGenerator, Optional from luxx.models import Conversation, Message from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry from luxx.services.llm_client import LLMClient from luxx.config import config logger = logging.getLogger(__name__) # Maximum iterations to prevent infinite loops MAX_ITERATIONS = 10 def _sse_event(event: str, data: dict) -> str: """Format a Server-Sent Event string.""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def get_llm_client(conversation: Conversation = None): """Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)""" max_tokens = None if conversation and conversation.provider_id: from luxx.models import LLMProvider from luxx.database import SessionLocal db = SessionLocal() try: provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() if provider: max_tokens = provider.max_tokens client = LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=provider.default_model ) return client, max_tokens finally: db.close() # Fallback to global config client = LLMClient() return client, max_tokens class StreamContext: """Context for streaming response state management.""" def __init__( self, step_index: int = 0, current_step_id: str = None, current_step_idx: int = None, current_stream_type: str = None, full_content: str = "", full_thinking: str = "" ): self.step_index = step_index self.current_step_id = current_step_id self.current_step_idx = current_step_idx self.current_stream_type = current_stream_type self.full_content = full_content self.full_thinking = full_thinking self.all_steps = [] self.all_tool_calls = [] self.all_tool_results = [] self.tool_calls_list = [] def reset_iteration(self): """Reset streaming step tracker for new iteration.""" self.current_step_id = None self.current_step_idx = None self.current_stream_type = None self.full_content = "" self.full_thinking = "" self.tool_calls_list = [] def start_stream_step(self, step_type: str) -> str: """Start a new streaming step. Returns the step_id.""" self.current_step_idx = self.step_index self.current_step_id = f"step-{self.step_index}" self.current_stream_type = step_type self.step_index += 1 return self.current_step_id def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]: """Yield a streaming step event.""" return _sse_event("process_step", { "step": { "id": self.current_step_id, "index": self.current_step_idx, "type": step_type, "content": content } }) def save_streaming_step(self): """Save the current streaming step to all_steps.""" if self.current_step_id is None: return if self.current_stream_type == "thinking": self.all_steps.append({ "id": self.current_step_id, "index": self.current_step_idx, "type": "thinking", "content": self.full_thinking }) elif self.current_stream_type == "text": self.all_steps.append({ "id": self.current_step_id, "index": self.current_step_idx, "type": "text", "content": self.full_content }) def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]: """Handle reasoning/thinking delta. Returns yield_obj if step was yielded.""" reasoning = delta.get("reasoning_content", "") if not reasoning: return None prev_len = len(self.full_thinking) self.full_thinking += reasoning if prev_len == 0: # New thinking stream started self.start_stream_step("thinking") return self.yield_stream_step("thinking", self.full_thinking) def handle_text_stream(self, delta: Dict) -> Optional[Dict]: """Handle content delta. Returns yield_obj if step was yielded.""" content = delta.get("content", "") if not content: return None prev_len = len(self.full_content) self.full_content += content if prev_len == 0: # New text stream started self.start_stream_step("text") return self.yield_stream_step("text", self.full_content) def handle_tool_call(self) -> tuple: """Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs).""" tool_call_step_ids = [] tool_call_steps = [] yield_objs = [] for tc in self.tool_calls_list: call_step_idx = self.step_index call_step_id = f"step-{self.step_index}" tool_call_step_ids.append(call_step_id) self.step_index += 1 call_step = { "id": call_step_id, "index": call_step_idx, "type": "tool_call", "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] } tool_call_steps.append(call_step) yield_objs.append(_sse_event("process_step", {"step": call_step})) return tool_call_step_ids, tool_call_steps, yield_objs def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple: """Handle single tool result. Returns (result_step, yield_obj).""" result_step_idx = self.step_index result_step_id = f"step-{self.step_index}" self.step_index += 1 content = tool_result.get("content", "") success = True try: content_obj = json.loads(content) if isinstance(content_obj, dict): success = content_obj.get("success", True) except: pass result_step = { "id": result_step_id, "index": result_step_idx, "type": "tool_result", "id_ref": tool_call_step_id, "name": tool_result.get("name", ""), "content": content, "success": success } return result_step, _sse_event("process_step", {"step": result_step}) class ChatService: """Chat service with tool support""" def __init__(self): self.tool_executor = ToolExecutor() def build_messages( self, conversation: Conversation, include_system: bool = True ) -> List[Dict[str, str]]: """Build message list""" from luxx.database import SessionLocal from luxx.models import Message messages = [] if include_system and conversation.system_prompt: messages.append({ "role": "system", "content": conversation.system_prompt }) db = SessionLocal() try: db_messages = db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at).all() for msg in db_messages: # Parse JSON content if possible try: content_obj = json.loads(msg.content) if msg.content else {} if isinstance(content_obj, dict): content = content_obj.get("text", msg.content) else: content = msg.content except (json.JSONDecodeError, TypeError): content = msg.content messages.append({ "role": msg.role, "content": content }) finally: db.close() return messages async def stream_response( self, conversation: Conversation, user_message: str, thinking_enabled: bool = False, enabled_tools: list = None, user_id: int = None, username: str = None, workspace: str = None, user_permission_level: int = 1 ) -> AsyncGenerator[Dict[str, str], None]: """ Streaming response generator Yields raw SSE event strings for direct forwarding. """ try: messages = self.build_messages(conversation) messages.append({ "role": "user", "content": json.dumps({"text": user_message, "attachments": []}) }) # Get tools based on enabled_tools filter if enabled_tools: tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] else: tools = [] llm, provider_max_tokens = get_llm_client(conversation) model = conversation.model or llm.default_model or "gpt-4" # 直接使用 provider 的 max_tokens max_tokens = provider_max_tokens # Token usage tracking total_usage = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } actual_token_count = 0 # Streaming context for state management ctx = StreamContext() for iteration in range(MAX_ITERATIONS): # Reset streaming context for this iteration ctx.reset_iteration() async for sse_line in llm.stream_call( model=model, messages=messages, tools=tools, temperature=conversation.temperature, max_tokens=max_tokens or 8192, thinking_enabled=thinking_enabled or conversation.thinking_enabled ): # Parse SSE line # Format: "event: xxx\ndata: {...}\n\n" event_type = None data_str = None for line in sse_line.strip().split('\n'): if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): data_str = line[6:].strip() if data_str is None: continue # Handle error events from LLM if event_type == 'error': try: error_data = json.loads(data_str) yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) except json.JSONDecodeError: yield _sse_event("error", {"content": data_str}) return # Parse the data try: chunk = json.loads(data_str) except json.JSONDecodeError: yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"}) return # 提取 API 返回的 usage 信息 if "usage" in chunk: usage = chunk["usage"] total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) total_usage["completion_tokens"] = usage.get("completion_tokens", 0) total_usage["total_tokens"] = usage.get("total_tokens", 0) # Check for error in response if "error" in chunk: error_msg = chunk["error"].get("message", str(chunk["error"])) yield _sse_event("error", {"content": f"API Error: {error_msg}"}) return # Get delta choices = chunk.get("choices", []) if not choices: # Check if there's any content in the response (for non-standard LLM responses) if chunk.get("content") or chunk.get("message"): content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: prev_len = len(ctx.full_content) ctx.full_content += content if prev_len == 0: # New text stream started ctx.start_stream_step("text") yield _sse_event("process_step", { "step": { "id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}", "index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1, "type": "text", "content": ctx.full_content } }) continue delta = choices[0].get("delta", {}) # Handle reasoning (thinking) yield_obj = ctx.handle_thinking_stream(delta) if yield_obj: yield yield_obj # Handle content yield_obj = ctx.handle_text_stream(delta) if yield_obj: yield yield_obj # Accumulate tool calls tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) if idx >= len(ctx.tool_calls_list): ctx.tool_calls_list.append({ "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""} }) func = tc.get("function", {}) if func.get("name"): ctx.tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"] # Save streaming step (thinking or text) ctx.save_streaming_step() # Handle tool calls if ctx.tool_calls_list: ctx.all_tool_calls.extend(ctx.tool_calls_list) # Handle tool_call steps tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call() ctx.all_steps.extend(tool_call_steps) for yield_obj in yield_objs: yield yield_obj # Execute tools tool_context = { "workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level } tool_results = self.tool_executor.process_tool_calls_parallel( ctx.tool_calls_list, tool_context ) # Handle tool_result steps for i, tr in enumerate(tool_results): tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id) ctx.all_steps.append(result_step) yield yield_obj ctx.all_tool_results.append({ "role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": tr.get("content", "") }) # Add assistant message with tool calls for next iteration messages.append({ "role": "assistant", "content": ctx.full_content or "", "tool_calls": ctx.tool_calls_list }) messages.extend(ctx.all_tool_results[-len(tool_results):]) ctx.all_tool_results = [] continue # No tool calls - final iteration, save message msg_id = str(uuid.uuid4()) # 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值 actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4 logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") self._save_message( conversation.id, msg_id, ctx.full_content, ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps, actual_token_count, total_usage ) yield _sse_event("done", { "message_id": msg_id, "token_count": actual_token_count, "usage": total_usage }) return # Max iterations exceeded - save message before error if ctx.full_content or ctx.all_tool_calls: msg_id = str(uuid.uuid4()) self._save_message( conversation.id, msg_id, ctx.full_content, ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps, actual_token_count, total_usage ) yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) except Exception as e: logger.error(f"Stream error: {e}") yield _sse_event("error", {"content": str(e)}) def _save_message( self, conversation_id: str, msg_id: str, full_content: str, all_tool_calls: list, all_tool_results: list, all_steps: list, token_count: int = 0, usage: dict = None ): """Save the assistant message to database.""" from luxx.database import SessionLocal from luxx.models import Message content_json = { "text": full_content, "steps": all_steps } if all_tool_calls: content_json["tool_calls"] = all_tool_calls db = SessionLocal() try: msg = Message( id=msg_id, conversation_id=conversation_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, usage=json.dumps(usage) if usage else None ) db.add(msg) db.commit() except Exception as e: db.rollback() raise finally: db.close() # Global chat service chat_service = ChatService()