"""Chat service module""" import json import uuid import logging from typing import List, Dict, Any, AsyncGenerator, Optional from luxx.models import Conversation, Message from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry from luxx.services.llm_client import LLMClient from luxx.config import config logger = logging.getLogger(__name__) # Maximum iterations to prevent infinite loops MAX_ITERATIONS = 10 def _sse_event(event: str, data: dict) -> str: """Format a Server-Sent Event string.""" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" def get_llm_client(conversation: Conversation = None): """Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)""" max_tokens = None if conversation and conversation.provider_id: from luxx.models import LLMProvider from luxx.database import SessionLocal db = SessionLocal() try: provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() if provider: max_tokens = provider.max_tokens client = LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=provider.default_model ) return client, max_tokens finally: db.close() # Fallback to global config client = LLMClient() return client, max_tokens class ChatService: """Chat service with tool support""" def __init__(self): self.tool_executor = ToolExecutor() def build_messages( self, conversation: Conversation, include_system: bool = True ) -> List[Dict[str, str]]: """Build message list""" from luxx.database import SessionLocal from luxx.models import Message messages = [] if include_system and conversation.system_prompt: messages.append({ "role": "system", "content": conversation.system_prompt }) db = SessionLocal() try: db_messages = db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at).all() for msg in db_messages: # Parse JSON content if possible try: content_obj = json.loads(msg.content) if msg.content else {} if isinstance(content_obj, dict): content = content_obj.get("text", msg.content) else: content = msg.content except (json.JSONDecodeError, TypeError): content = msg.content messages.append({ "role": msg.role, "content": content }) finally: db.close() return messages async def stream_response( self, conversation: Conversation, user_message: str, thinking_enabled: bool = False, enabled_tools: list = None ) -> AsyncGenerator[Dict[str, str], None]: """ Streaming response generator Yields raw SSE event strings for direct forwarding. """ try: messages = self.build_messages(conversation) messages.append({ "role": "user", "content": json.dumps({"text": user_message, "attachments": []}) }) # Get tools based on enabled_tools filter if enabled_tools: tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] else: tools = [] llm, provider_max_tokens = get_llm_client(conversation) model = conversation.model or llm.default_model or "gpt-4" # 使用 provider 的 max_tokens,如果 conversation 有自己的 max_tokens 则覆盖 max_tokens = conversation.max_tokens if hasattr(conversation, 'max_tokens') and conversation.max_tokens else provider_max_tokens # State tracking all_steps = [] all_tool_calls = [] all_tool_results = [] step_index = 0 # Token usage tracking total_usage = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } # Global step IDs for thinking and text (persist across iterations) thinking_step_id = None thinking_step_idx = None text_step_id = None text_step_idx = None for iteration in range(MAX_ITERATIONS): # Stream from LLM full_content = "" full_thinking = "" tool_calls_list = [] # Step tracking - use unified step-{index} format thinking_step_id = None thinking_step_idx = None text_step_id = None text_step_idx = None async for sse_line in llm.stream_call( model=model, messages=messages, tools=tools, temperature=conversation.temperature, max_tokens=max_tokens or 8192, thinking_enabled=thinking_enabled or conversation.thinking_enabled ): # Parse SSE line # Format: "event: xxx\ndata: {...}\n\n" event_type = None data_str = None for line in sse_line.strip().split('\n'): if line.startswith('event: '): event_type = line[7:].strip() elif line.startswith('data: '): data_str = line[6:].strip() if data_str is None: continue # Handle error events from LLM if event_type == 'error': try: error_data = json.loads(data_str) yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) except json.JSONDecodeError: yield _sse_event("error", {"content": data_str}) return # Parse the data try: chunk = json.loads(data_str) except json.JSONDecodeError: yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"}) return # 提取 API 返回的 usage 信息 if "usage" in chunk: usage = chunk["usage"] total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) total_usage["completion_tokens"] = usage.get("completion_tokens", 0) total_usage["total_tokens"] = usage.get("total_tokens", 0) # Check for error in response if "error" in chunk: error_msg = chunk["error"].get("message", str(chunk["error"])) yield _sse_event("error", {"content": f"API Error: {error_msg}"}) return # Get delta choices = chunk.get("choices", []) if not choices: # Check if there's any content in the response (for non-standard LLM responses) if chunk.get("content") or chunk.get("message"): content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: # BUG FIX: Update full_content so it gets saved to database prev_content_len = len(full_content) full_content += content if prev_content_len == 0: # New text stream started text_step_idx = step_index text_step_id = f"step-{step_index}" step_index += 1 yield _sse_event("process_step", { "step": { "id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}", "index": text_step_idx if prev_content_len == 0 else step_index - 1, "type": "text", "content": full_content # Always send accumulated content } }) continue delta = choices[0].get("delta", {}) # Handle reasoning (thinking) reasoning = delta.get("reasoning_content", "") if reasoning: prev_thinking_len = len(full_thinking) full_thinking += reasoning if prev_thinking_len == 0: # New thinking stream started thinking_step_idx = step_index thinking_step_id = f"step-{step_index}" step_index += 1 yield _sse_event("process_step", { "step": { "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking } }) # Handle content content = delta.get("content", "") if content: prev_content_len = len(full_content) full_content += content if prev_content_len == 0: # New text stream started text_step_idx = step_index text_step_id = f"step-{step_index}" step_index += 1 yield _sse_event("process_step", { "step": { "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content } }) # Accumulate tool calls tool_calls_delta = delta.get("tool_calls", []) for tc in tool_calls_delta: idx = tc.get("index", 0) if idx >= len(tool_calls_list): tool_calls_list.append({ "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""} }) func = tc.get("function", {}) if func.get("name"): tool_calls_list[idx]["function"]["name"] += func["name"] if func.get("arguments"): tool_calls_list[idx]["function"]["arguments"] += func["arguments"] # Save thinking step if thinking_step_id is not None: all_steps.append({ "id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking }) # Save text step if text_step_id is not None: all_steps.append({ "id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content }) # Handle tool calls if tool_calls_list: all_tool_calls.extend(tool_calls_list) # Yield tool_call steps - use unified step-{index} format tool_call_step_ids = [] # Track step IDs for tool calls for tc in tool_calls_list: call_step_idx = step_index call_step_id = f"step-{step_index}" tool_call_step_ids.append(call_step_id) step_index += 1 call_step = { "id": call_step_id, "index": call_step_idx, "type": "tool_call", "id_ref": tc.get("id", ""), "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] } all_steps.append(call_step) yield _sse_event("process_step", {"step": call_step}) # Execute tools tool_results = self.tool_executor.process_tool_calls_parallel( tool_calls_list, {} ) # Yield tool_result steps - use unified step-{index} format for i, tr in enumerate(tool_results): tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" result_step_idx = step_index result_step_id = f"step-{step_index}" step_index += 1 # 解析 content 中的 success 状态 content = tr.get("content", "") success = True try: content_obj = json.loads(content) if isinstance(content_obj, dict): success = content_obj.get("success", True) except: pass result_step = { "id": result_step_id, "index": result_step_idx, "type": "tool_result", "id_ref": tool_call_step_id, # Reference to the tool_call step "name": tr.get("name", ""), "content": content, "success": success } all_steps.append(result_step) yield _sse_event("process_step", {"step": result_step}) all_tool_results.append({ "role": "tool", "tool_call_id": tr.get("tool_call_id", ""), "content": tr.get("content", "") }) # Add assistant message with tool calls for next iteration messages.append({ "role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list }) messages.extend(all_tool_results[-len(tool_results):]) all_tool_results = [] continue # No tool calls - final iteration, save message msg_id = str(uuid.uuid4()) # 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值 actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") self._save_message( conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, actual_token_count, total_usage ) yield _sse_event("done", { "message_id": msg_id, "token_count": actual_token_count, "usage": total_usage }) return # Max iterations exceeded - save message before error if full_content or all_tool_calls: msg_id = str(uuid.uuid4()) self._save_message( conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, actual_token_count, total_usage ) yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) except Exception as e: yield _sse_event("error", {"content": str(e)}) def _save_message( self, conversation_id: str, msg_id: str, full_content: str, all_tool_calls: list, all_tool_results: list, all_steps: list, token_count: int = 0, usage: dict = None ): """Save the assistant message to database.""" from luxx.database import SessionLocal from luxx.models import Message content_json = { "text": full_content, "steps": all_steps } if all_tool_calls: content_json["tool_calls"] = all_tool_calls db = SessionLocal() try: msg = Message( id=msg_id, conversation_id=conversation_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, usage=json.dumps(usage) if usage else None ) db.add(msg) db.commit() except Exception as e: db.rollback() raise finally: db.close() # Global chat service chat_service = ChatService()