"""Chat service module""" import json import uuid from typing import List, Dict, Any, AsyncGenerator from luxx.models import Conversation, Message from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry from luxx.services.llm_client import LLMClient from luxx.config import config # Maximum iterations to prevent infinite loops MAX_ITERATIONS = 10 def _sse_event(event: str, data: dict) -> str: """Format a Server-Sent Event string.""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def get_llm_client(conversation: Conversation = None): """Get LLM client, optionally using conversation's provider""" if conversation and conversation.provider_id: from luxx.models import LLMProvider from luxx.database import SessionLocal db = SessionLocal() try: provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() if provider: client = LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=provider.default_model ) return client finally: db.close() # Fallback to global config client = LLMClient() return client class ChatService: """Chat service with tool support""" def __init__(self): self.tool_executor = ToolExecutor() def build_messages( self, conversation: Conversation, include_system: bool = True ) -> List[Dict[str, str]]: """Build message list""" from luxx.database import SessionLocal from luxx.models import Message messages = [] if include_system and conversation.system_prompt: messages.append({ "role": "system", "content": conversation.system_prompt }) db = SessionLocal() try: db_messages = db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at).all() for msg in db_messages: # Parse JSON content if possible try: content_obj = json.loads(msg.content) if msg.content else {} if isinstance(content_obj, dict): content = content_obj.get("text", msg.content) else: content = msg.content except (json.JSONDecodeError, TypeError): content = msg.content messages.append({ "role": msg.role, "content": content }) finally: db.close() return messages async def stream_response( self, conversation: Conversation, user_message: str, tools_enabled: bool = True ) -> AsyncGenerator[Dict[str, str], None]: """ Streaming response generator Yields raw SSE event strings for direct forwarding. """ try: messages = self.build_messages(conversation) messages.append({ "role": "user", "content": json.dumps({"text": user_message, "attachments": []}) }) tools = registry.list_all() if tools_enabled else None llm = get_llm_client(conversation) model = conversation.model or llm.default_model or "gpt-4" # State tracking all_steps = [] all_tool_calls = [] all_tool_results = [] step_index = 0 # Global step IDs for thinking and text (persist across iterations) thinking_step_id = None thinking_step_idx = None text_step_id = None text_step_idx = None for iteration in range(MAX_ITERATIONS): print(f"[CHAT] Starting iteration {iteration + 1}, messages: {len(messages)}") # Stream from LLM full_content = "" full_thinking = "" tool_calls_list = [] # Generate new step IDs for each iteration to track multiple thoughts/tools iteration_thinking_step_id = f"thinking-{iteration}" iteration_text_step_id = f"text-{iteration}" async for sse_line in llm.stream_call( model=model, messages=messages, tools=tools, temperature=conversation.temperature, max_tokens=conversation.max_tokens ): # Parse SSE line # Format: "event: xxx\ndata: {...}\n\n" event_type = None data_str = None for line in sse_line.strip().split('\n'): if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): data_str = line[6:].strip() if data_str is None: continue # Handle error events from LLM if event_type == 'error': try: error_data = json.loads(data_str) yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) except json.JSONDecodeError: yield _sse_event("error", {"content": data_str}) return # Parse the data try: chunk = json.loads(data_str) except json.JSONDecodeError: continue # Get delta choices = chunk.get("choices", []) if not choices: continue delta = choices[0].get("delta", {}) # Handle reasoning (thinking) reasoning = delta.get("reasoning_content", "") if reasoning: full_thinking += reasoning if thinking_step_id is None: thinking_step_id = iteration_thinking_step_id thinking_step_idx = step_index step_index += 1 yield _sse_event("process_step", { "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking }) # Handle content content = delta.get("content", "") if content: full_content += content if text_step_id is None: text_step_idx = step_index text_step_id = iteration_text_step_id step_index += 1 yield _sse_event("process_step", { "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content }) # Accumulate tool calls tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) if idx >= len(tool_calls_list): tool_calls_list.append({ "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""} }) func = tc.get("function", {}) if func.get("name"): tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): tool_calls_list[idx]["function"]["arguments"] += func["arguments"] # Save thinking step if thinking_step_id is not None: all_steps.append({ "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking }) # Save text step if text_step_id is not None: all_steps.append({ "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content }) # Handle tool calls if tool_calls_list: all_tool_calls.extend(tool_calls_list) # Yield tool_call steps tool_call_step_ids = [] # Track step IDs for tool calls for tc in tool_calls_list: call_step_id = f"tool-{iteration}-{tc.get('function', {}).get('name', 'unknown')}" tool_call_step_ids.append(call_step_id) call_step = { "id": call_step_id, "index": step_index, "type": "tool_call", "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] } all_steps.append(call_step) yield _sse_event("process_step", call_step) step_index += 1 # Execute tools tool_results = self.tool_executor.process_tool_calls_parallel( tool_calls_list, {} ) # Yield tool_result steps for i, tr in enumerate(tool_results): tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"tool-{i}" result_step = { "id": f"result-{iteration}-{tr.get('name', 'unknown')}", "index": step_index, "type": "tool_result", "id_ref": tool_call_step_id, # Reference to the tool_call step "name": tr.get("name", ""), "content": tr.get("content", "") } all_steps.append(result_step) yield _sse_event("process_step", result_step) step_index += 1 all_tool_results.append({ "role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": tr.get("content", "") }) # Add assistant message with tool calls for next iteration messages.append({ "role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list }) messages.extend(all_tool_results[-len(tool_results):]) all_tool_results = [] continue # No tool calls - final iteration, save message msg_id = str(uuid.uuid4()) self._save_message( conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps ) yield _sse_event("done", { "message_id": msg_id, "token_count": len(full_content) // 4 }) return # Max iterations exceeded yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) except Exception as e: print(f"[CHAT] Exception: {type(e).__name__}: {str(e)}") yield _sse_event("error", {"content": str(e)}) def _save_message( self, conversation_id: str, msg_id: str, full_content: str, all_tool_calls: list, all_tool_results: list, all_steps: list ): """Save the assistant message to database.""" from luxx.database import SessionLocal from luxx.models import Message content_json = { "text": full_content, "steps": all_steps } if all_tool_calls: content_json["tool_calls"] = all_tool_calls db = SessionLocal() try: msg = Message( id=msg_id, conversation_id=conversation_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=len(full_content) // 4 ) db.add(msg) db.commit() except Exception as e: print(f"[CHAT] Failed to save message: {e}") db.rollback() finally: db.close() # Global chat service chat_service = ChatService()