"""Chat service module""" import json import uuid import logging from typing import List, Dict,AsyncGenerator from luxx.models import Conversation, Message, LLMProvider from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry from luxx.services.llm_client import LLMClient from luxx.database import SessionLocal logger = logging.getLogger(__name__) # Maximum iterations to prevent infinite loops MAX_ITERATIONS = 20 def _sse_event(event: str, data: dict) -> str: """Format a Server-Sent Event string.""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def get_llm_client(conversation: Conversation = None): """Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)""" max_tokens = None if conversation and conversation.provider_id: db = SessionLocal() try: provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() if provider: max_tokens = provider.max_tokens client = LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=provider.default_model ) return client, max_tokens finally: db.close() # Fallback to global config client = LLMClient() return client, max_tokens class ChatService: """Chat service with tool support""" def __init__(self): self.tool_executor = ToolExecutor() def build_messages( self, conversation: Conversation, include_system: bool = True ) -> List[Dict[str, str]]: """Build message list""" messages = [] if include_system and conversation.system_prompt: messages.append({ "role": "system", "content": conversation.system_prompt }) db = SessionLocal() try: db_messages = db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at).all() for msg in db_messages: # Parse JSON content if possible try: content_obj = json.loads(msg.content) if msg.content else {} if isinstance(content_obj, dict): content = content_obj.get("text", msg.content) else: content = msg.content except (json.JSONDecodeError, TypeError): content = msg.content messages.append({ "role": msg.role, "content": content }) finally: db.close() return messages async def stream_response( self, conversation: Conversation, user_message: str, thinking_enabled: bool = False, enabled_tools: list = None, user_id: int = None, username: str = None, workspace: str = None, user_permission_level: int = 1 ) -> AsyncGenerator[Dict[str, str], None]: """ Streaming response generator Yields raw SSE event strings for direct forwarding. """ try: messages = self.build_messages(conversation) messages.append({ "role": "user", "content": json.dumps({"text": user_message, "attachments": []}) }) # Get tools based on enabled_tools filter if enabled_tools: tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] else: tools = [] llm, provider_max_tokens = get_llm_client(conversation) model = conversation.model or llm.default_model or "gpt-4" # 直接使用 provider 的 max_tokens max_tokens = provider_max_tokens # State tracking all_steps = [] all_tool_calls = [] all_tool_results = [] step_index = 0 # Token usage tracking total_usage = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } # Global step IDs for thinking and text (persist across iterations) thinking_step_id = None thinking_step_idx = None text_step_id = None text_step_idx = None for _ in range(MAX_ITERATIONS): # Stream from LLM full_content = "" full_thinking = "" tool_calls_list = [] # Step tracking - use unified step-{index} format thinking_step_id = None thinking_step_idx = None text_step_id = None text_step_idx = None async for sse_line in llm.stream_call( model=model, messages=messages, tools=tools, temperature=conversation.temperature, max_tokens=max_tokens or 8192, thinking_enabled=thinking_enabled or conversation.thinking_enabled ): # Parse SSE line # Format: "event: xxx\ndata: {...}\n\n" event_type = None data_str = None for line in sse_line.strip().split('\n'): if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): data_str = line[6:].strip() if data_str is None: continue # Handle error events from LLM if event_type == 'error': try: error_data = json.loads(data_str) yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) except json.JSONDecodeError: yield _sse_event("error", {"content": data_str}) return # Parse the data try: chunk = json.loads(data_str) except json.JSONDecodeError: yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"}) return # 提取 API 返回的 usage 信息 if "usage" in chunk: usage = chunk["usage"] total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) total_usage["completion_tokens"] = usage.get("completion_tokens", 0) total_usage["total_tokens"] = usage.get("total_tokens", 0) # Check for error in response if "error" in chunk: error_msg = chunk["error"].get("message", str(chunk["error"])) yield _sse_event("error", {"content": f"API Error: {error_msg}"}) return # Get delta choices = chunk.get("choices", []) delta = None if choices: delta = choices[0].get("delta", {}) # If no delta but has message (non-streaming response) if not delta: message = choices[0].get("message", {}) if message.get("content"): delta = {"content": message["content"]} if not delta: # Check if there's any content in the response (for non-standard LLM responses) content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: delta = {"content": content} if delta: # Handle reasoning (thinking) reasoning = delta.get("reasoning_content", "") if reasoning: prev_thinking_len = len(full_thinking) full_thinking += reasoning if prev_thinking_len == 0: # New thinking stream started thinking_step_idx = step_index thinking_step_id = f"step-{step_index}" step_index += 1 yield _sse_event("process_step", { "step": { "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking } }) # Handle content content = delta.get("content", "") if content: prev_content_len = len(full_content) full_content += content if prev_content_len == 0: # New text stream started text_step_idx = step_index text_step_id = f"step-{step_index}" step_index += 1 yield _sse_event("process_step", { "step": { "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content } }) # Accumulate tool calls tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) if idx >= len(tool_calls_list): tool_calls_list.append({ "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""} }) func = tc.get("function", {}) if func.get("name"): tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): tool_calls_list[idx]["function"]["arguments"] += func["arguments"] # Save thinking step if thinking_step_id is not None: all_steps.append({ "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking }) # Save text step if text_step_id is not None: all_steps.append({ "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content }) # Handle tool calls if tool_calls_list: all_tool_calls.extend(tool_calls_list) # Yield tool_call steps - use unified step-{index} format tool_call_step_ids = [] # Track step IDs for tool calls for tc in tool_calls_list: call_step_idx = step_index call_step_id = f"step-{step_index}" tool_call_step_ids.append(call_step_id) step_index += 1 call_step = { "id": call_step_id, "index": call_step_idx, "type": "tool_call", "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] } all_steps.append(call_step) yield _sse_event("process_step", {"step": call_step}) # Execute tools tool_context = { "workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level } tool_results = self.tool_executor.process_tool_calls_parallel( tool_calls_list, tool_context ) # Yield tool_result steps - use unified step-{index} format for i, tr in enumerate(tool_results): tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" result_step_idx = step_index result_step_id = f"step-{step_index}" step_index += 1 # 解析 content 中的 success 状态 content = tr.get("content", "") success = True try: content_obj = json.loads(content) if isinstance(content_obj, dict): success = content_obj.get("success", True) except: pass result_step = { "id": result_step_id, "index": result_step_idx, "type": "tool_result", "id_ref": tool_call_step_id, # Reference to the tool_call step "name": tr.get("name", ""), "content": content, "success": success } all_steps.append(result_step) yield _sse_event("process_step", {"step": result_step}) all_tool_results.append({ "role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": tr.get("content", "") }) # Add assistant message with tool calls for next iteration messages.append({ "role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list }) messages.extend(all_tool_results[-len(tool_results):]) all_tool_results = [] continue # No tool calls - final iteration, save message msg_id = str(uuid.uuid4()) actual_token_count = total_usage.get("completion_tokens", 0) logger.info(f"total_usage: {total_usage}") self._save_message( conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, actual_token_count, total_usage ) yield _sse_event("done", { "message_id": msg_id, "token_count": actual_token_count, "usage": total_usage }) return # Max iterations exceeded - save message before error if full_content or all_tool_calls: msg_id = str(uuid.uuid4()) self._save_message( conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, actual_token_count, total_usage ) yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) except Exception as e: yield _sse_event("error", {"content": str(e)}) def _save_message( self, conversation_id: str, msg_id: str, full_content: str, all_tool_calls: list, all_tool_results: list, all_steps: list, token_count: int = 0, usage: dict = None ): """Save the assistant message to database.""" content_json = { "text": full_content, "steps": all_steps } if all_tool_calls: content_json["tool_calls"] = all_tool_calls db = SessionLocal() try: msg = Message( id=msg_id, conversation_id=conversation_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, usage=json.dumps(usage) if usage else None ) db.add(msg) db.commit() except Exception as e: db.rollback() raise finally: db.close() # Global chat service chat_service = ChatService()