"""Chat service module - Refactored with step-by-step flow""" import json import uuid import logging from typing import List, Dict, Any, AsyncGenerator, Tuple, Optional # For Python < 3.9 compatibility try: from typing import List except ImportError: pass from luxx.models import Conversation, Message, LLMProvider from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry from luxx.services.llm_client import LLMClient from luxx.database import SessionLocal logger = logging.getLogger(__name__) MAX_ITERATIONS = 20 def _sse_event(event: str, data: dict) -> str: """Format SSE event string.""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def get_llm_client(conversation: Conversation = None) -> Tuple[LLMClient, Optional[int]]: """Get LLM client from conversation's provider.""" if conversation and conversation.provider_id: db = SessionLocal() try: provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() if provider: return LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=provider.default_model ), provider.max_tokens finally: db.close() return LLMClient(), None class StreamState: """Holds streaming state across iterations.""" def __init__(self): self.messages: List[Dict] = [] self.all_steps: List[Dict] = [] self.all_tool_calls: List[Dict] = [] self.all_tool_results: List[Dict] = [] self.step_index: int = 0 self.total_usage: Dict[str, int] = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } # Current iteration state self.full_content: str = "" self.full_thinking: str = "" self.tool_calls_list: List[Dict] = [] self.thinking_step_id: Optional[str] = None self.thinking_step_idx: Optional[int] = None self.text_step_id: Optional[str] = None self.text_step_idx: Optional[int] = None class ChatService: """Chat service with step-by-step flow architecture.""" def __init__(self): self.tool_executor = ToolExecutor() # ==================== Step 1: Initialize ==================== def build_messages(self, conversation: Conversation, user_message: str) -> List[Dict[str, str]]: """Build message list including user message.""" messages = [] if conversation.system_prompt: messages.append({"role": "system", "content": conversation.system_prompt}) db = SessionLocal() try: for msg in db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at).all(): try: obj = json.loads(msg.content) if msg.content else {} content = obj.get("text", msg.content) if isinstance(obj, dict) else msg.content except (json.JSONDecodeError, TypeError): content = msg.content messages.append({"role": msg.role, "content": content}) finally: db.close() messages.append({"role": "user", "content": json.dumps({"text": user_message, "attachments": []})}) return messages def init_stream_state(self, conversation: Conversation, user_message: str, enabled_tools: list) -> Tuple[StreamState, LLMClient, Dict, str, Optional[int]]: """Initialize streaming state. Returns: (state, llm, tools, model, max_tokens)""" state = StreamState() state.messages = self.build_messages(conversation, user_message) tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else [] llm, max_tokens = get_llm_client(conversation) model = conversation.model or llm.default_model or "gpt-4" tool_context = {"workspace": None, "user_id": None, "username": None, "user_permission_level": 1} return state, llm, tools, model, max_tokens, tool_context # ==================== Step 2: Stream LLM ==================== def parse_sse_line(self, sse_line: str) -> Tuple[Optional[str], Optional[str]]: """Parse SSE line into (event_type, data_str).""" event_type = data_str = None for line in sse_line.strip().split('\n'): if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): data_str = line[6:].strip() return event_type, data_str def stream_llm_response(self, llm, model: str, messages: List[Dict], tools: list, temperature: float, max_tokens: int, thinking_enabled: bool): """ Stream LLM response and yield (sse_line, parsed_chunk) pairs. """ for sse_line in llm.stream_call( model=model, messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens or 8192, thinking_enabled=thinking_enabled ): _, data_str = self.parse_sse_line(sse_line) chunk = None if data_str: try: chunk = json.loads(data_str) except json.JSONDecodeError: pass yield sse_line, chunk def process_delta(self, state: StreamState, delta: dict) -> List[str]: """ Process a single delta, return list of SSE event strings. """ events = [] # Handle thinking/reasoning reasoning = delta.get("reasoning_content", "") if reasoning: if not state.full_thinking: state.thinking_step_idx = state.step_index state.thinking_step_id = f"step-{state.step_index}" state.step_index += 1 state.full_thinking += reasoning events.append(_sse_event("process_step", { "step": {"id": state.thinking_step_id, "index": state.thinking_step_idx, "type": "thinking", "content": state.full_thinking} })) # Handle content content = delta.get("content", "") if content: if not state.full_content: state.text_step_idx = state.step_index state.text_step_id = f"step-{state.step_index}" state.step_index += 1 state.full_content += content events.append(_sse_event("process_step", { "step": {"id": state.text_step_id, "index": state.text_step_idx, "type": "text", "content": state.full_content} })) # Handle tool calls for tc in delta.get("tool_calls", []): idx = tc.get("index", 0) if idx >= len(state.tool_calls_list): state.tool_calls_list.append({"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}}) func = tc.get("function", {}) if func.get("name"): state.tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): state.tool_calls_list[idx]["function"]["arguments"] += func["arguments"] return events def save_steps(self, state: StreamState): """Save current iteration steps to all_steps.""" if state.thinking_step_id: state.all_steps.append({"id": state.thinking_step_id, "index": state.thinking_step_idx, "type": "thinking", "content": state.full_thinking}) if state.text_step_id: state.all_steps.append({"id": state.text_step_id, "index": state.text_step_idx, "type": "text", "content": state.full_content}) # ==================== Step 3: Execute Tools ==================== def execute_tools(self, state: StreamState, tool_context: Dict) -> Tuple[List[Dict], List[str]]: """ Execute tools and return (results, events). """ if not state.tool_calls_list: return [], [] state.all_tool_calls.extend(state.tool_calls_list) tool_call_ids = [] events = [] # Yield tool_call steps for tc in state.tool_calls_list: step_id = f"step-{state.step_index}" tool_call_ids.append(step_id) state.step_index += 1 step = { "id": step_id, "index": len(state.all_steps), "type": "tool_call", "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] } state.all_steps.append(step) events.append(_sse_event("process_step", {"step": step})) # Execute tools results = self.tool_executor.process_tool_calls_parallel(state.tool_calls_list, tool_context) # Yield tool_result steps for i, tr in enumerate(results): ref_id = tool_call_ids[i] if i < len(tool_call_ids) else f"tool-{i}" step_id = f"step-{state.step_index}" state.step_index += 1 content = tr.get("content", "") success = True try: obj = json.loads(content) if isinstance(obj, dict): success = obj.get("success", True) except: pass step = { "id": step_id, "index": len(state.all_steps), "type": "tool_result", "id_ref": ref_id, "name": tr.get("name", ""), "content": content, "success": success } state.all_steps.append(step) events.append(_sse_event("process_step", {"step": step})) state.all_tool_results.append({"role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": content}) return results, events def update_messages_for_next_iteration(self, state: StreamState, results: List[Dict]): """Update messages list with assistant response and tool results for next iteration.""" state.messages.append({"role": "assistant", "content": state.full_content or "", "tool_calls": state.tool_calls_list}) if results: state.messages.extend(state.all_tool_results[-len(results):]) state.all_tool_results = [] def reset_iteration_state(self, state: StreamState): """Reset state for next iteration.""" state.full_content = state.full_thinking = "" state.tool_calls_list = [] state.thinking_step_id = state.thinking_step_idx = state.text_step_id = state.text_step_idx = None # ==================== Step 4: Finalize ==================== def save_message(self, conversation_id: str, state: StreamState, token_count: int): """Save assistant message to database.""" content_json = {"text": state.full_content, "steps": state.all_steps} if state.all_tool_calls: content_json["tool_calls"] = state.all_tool_calls db = SessionLocal() try: db.add(Message( id=str(uuid.uuid4()), conversation_id=conversation_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, usage=json.dumps(state.total_usage) if state.total_usage else None )) db.commit() except Exception: db.rollback() raise finally: db.close() # ==================== Main Orchestrator ==================== async def stream_response( self, conversation: Conversation, user_message: str, thinking_enabled: bool = False, enabled_tools: list = None, user_id: int = None, username: str = None, workspace: str = None, user_permission_level: int = 1 ) -> AsyncGenerator[str, None]: """Main streaming orchestrator - step-by-step flow.""" try: # Step 1: Initialize state, llm, tools, model, max_tokens, tool_context = self.init_stream_state( conversation, user_message, enabled_tools or [] ) tool_context.update ({ "user_id": user_id, "username": username, "workspace": workspace, "user_permission_level": user_permission_level }) # ReAct loop for _ in range(MAX_ITERATIONS): self.reset_iteration_state(state) # Step 2: Stream LLM async for sse_line, chunk in self.stream_llm_response( llm, model, state.messages, tools, conversation.temperature, max_tokens, thinking_enabled or conversation.thinking_enabled ): # Handle error events event_type, data_str = self.parse_sse_line(sse_line) if event_type == 'error': error_data = json.loads(data_str) if data_str else {} yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) return if not chunk: continue # Extract usage if "usage" in chunk: u = chunk["usage"] state.total_usage = { "prompt_tokens": u.get("prompt_tokens", 0), "completion_tokens": u.get("completion_tokens", 0), "total_tokens": u.get("total_tokens", 0) } # Check for API errors if "error" in chunk: yield _sse_event("error", {"content": f"API Error: {chunk['error'].get('message', str(chunk['error']))}"}) return # Get delta delta = None choices = chunk.get("choices", []) if choices: delta = choices[0].get("delta", {}) if not delta: content = choices[0].get("message", {}).get("content") if content: delta = {"content": content} if not delta: content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: delta = {"content": content} if delta: # Step 2b: Process delta for event in self.process_delta(state, delta): yield event # Save steps after streaming self.save_steps(state) # Step 3: Execute tools if present if state.tool_calls_list: results, events = self.execute_tools(state, tool_context) for event in events: yield event self.update_messages_for_next_iteration(state, results) continue # Step 4: Finalize (no tool calls) token_count = state.total_usage.get("completion_tokens", 0) self.save_message(conversation.id, state, token_count) yield _sse_event("done", {"message_id": str(uuid.uuid4()), "token_count": token_count, "usage": state.total_usage}) return # Max iterations exceeded if state.full_content or state.all_tool_calls: self.save_message(conversation.id, state, state.total_usage.get("completion_tokens", 0)) yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) except Exception as e: yield _sse_event("error", {"content": str(e)}) # Global service chat_service = ChatService()