"""WebSocket routes for agent real-time communication"""
import asyncio
import json
import logging
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, Optional, Set

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect

from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
from luxx.agents.registry import AgentRegistry
from luxx.services.llm_client import llm_client
from luxx.tools.executor import ToolExecutor

logger = logging.getLogger(__name__)

router = APIRouter()


class ConnectionManager:
    """
    WebSocket Connection Manager

    Manages WebSocket connections for real-time agent progress updates.

    Features:
    - Connection tracking by task_id
    - Heartbeat mechanism
    - Progress broadcasting

    The bookkeeping dicts are guarded by a ``threading.Lock`` that is only
    held for dict manipulation, never across an ``await``.
    """

    def __init__(self, heartbeat_interval: float = 30) -> None:
        """
        Args:
            heartbeat_interval: Seconds between heartbeat messages sent by
                ``start_heartbeat`` (defaults to 30).
        """
        # task_id -> set of websocket connections
        self._connections: Dict[str, Set[WebSocket]] = {}
        # websocket -> task_id mapping
        self._ws_to_task: Dict[WebSocket, str] = {}
        # websocket -> running heartbeat task
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # guards the three dicts above
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = heartbeat_interval

    async def accept_and_connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Complete the WebSocket handshake, then register the connection

        Args:
            websocket: WebSocket connection (not yet accepted)
            task_id: Task ID to subscribe to
        """
        # WebSocket.accept() is a coroutine: it must be awaited, otherwise the
        # handshake never happens and later send/receive calls fail.
        await websocket.accept()
        self.connect(websocket, task_id)

    def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Register an already-accepted WebSocket connection under ``task_id``

        This method is synchronous and therefore cannot await
        ``websocket.accept()``; from async code use ``accept_and_connect``,
        or await ``websocket.accept()`` yourself before calling this.

        Args:
            websocket: WebSocket connection
            task_id: Task ID to subscribe to
        """
        with self._lock:
            self._connections.setdefault(task_id, set()).add(websocket)
            self._ws_to_task[websocket] = task_id
        logger.info("WebSocket connected for task %s", task_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and cancel its heartbeat

        Safe to call more than once for the same connection.

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id is not None:
                subscribers = self._connections.get(task_id)
                if subscribers is not None:
                    subscribers.discard(websocket)
                    if not subscribers:
                        del self._connections[task_id]
            heartbeat = self._heartbeat_tasks.pop(websocket, None)

        if heartbeat is not None:
            heartbeat.cancel()
        if task_id:
            logger.info("WebSocket disconnected for task %s", task_id)

    @staticmethod
    async def _try_send(
        websocket: WebSocket, message: Dict[str, Any], failure_msg: str
    ) -> bool:
        """Send ``message`` as JSON; log and return False on any failure."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning("%s: %s", failure_msg, e)
            return False
        return True

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send message to all connections subscribed to a task

        Connections whose send fails are unregistered.

        Args:
            task_id: Task ID
            message: Message to send
        """
        with self._lock:
            # Snapshot so sends happen without holding the lock.
            connections = list(self._connections.get(task_id, ()))

        dead_connections = []
        for websocket in connections:
            if not await self._try_send(
                websocket, message, "Failed to send to websocket"
            ):
                dead_connections.append(websocket)

        for websocket in dead_connections:
            self.disconnect(websocket)

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast message to all connections

        Connections whose send fails are unregistered, as in ``send_to_task``.

        Args:
            message: Message to broadcast
        """
        with self._lock:
            all_connections = list(self._ws_to_task)

        dead_connections = []
        for websocket in all_connections:
            if not await self._try_send(websocket, message, "Failed to broadcast"):
                dead_connections.append(websocket)

        for websocket in dead_connections:
            self.disconnect(websocket)

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send message to a specific connection (failures are logged, not raised)

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        await self._try_send(websocket, message, "Failed to send personal message")

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a background task that periodically sends heartbeat messages

        Must be called from inside a running event loop. Any heartbeat task
        already registered for this connection is cancelled and replaced.

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop() -> None:
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval
                    })
                except Exception:
                    # Socket is unusable; stop quietly. Unregistering the
                    # connection is left to disconnect().
                    break

        task = asyncio.create_task(heartbeat_loop())
        with self._lock:
            previous = self._heartbeat_tasks.get(websocket)
            self._heartbeat_tasks[websocket] = task
        if previous is not None:
            previous.cancel()

    @property
    def connection_count(self) -> int:
        """Total number of connections"""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Get number of connections for a task"""
        with self._lock:
            return len(self._connections.get(task_id, ()))


# Global connection manager
connection_manager = ConnectionManager()

# Global scheduler pool
scheduler_pool = SchedulerPool(max_concurrent=10)

# Global tool executor
tool_executor = ToolExecutor()


class ProgressEmitter:
    """
    Progress emitter that sends updates via WebSocket

    Wraps DAGScheduler progress callback to emit WebSocket messages.
    """

    def __init__(self, task_id: str, connection_manager: ConnectionManager):
        self.task_id = task_id
        self.connection_manager = connection_manager
        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task can be garbage
        # collected before it finishes.
        self._pending: Set[asyncio.Task] = set()

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """
        Progress callback; schedules a WebSocket send and returns immediately

        Must be invoked from a thread running the event loop
        (``asyncio.create_task`` raises RuntimeError otherwise).

        Args:
            node_id: Node ID, or the literal ``"dag"`` for DAG-level progress
            progress: Progress value forwarded as-is
            message: Optional human-readable message
        """
        if node_id == "dag":
            # DAG-level progress
            payload = {
                "type": "dag_progress",
                "data": {
                    "progress": progress,
                    "message": message
                }
            }
        else:
            # Node-level progress
            payload = {
                "type": "node_progress",
                "data": {
                    "node_id": node_id,
                    "progress": progress,
                    "message": message
                }
            }

        task = asyncio.create_task(
            self.connection_manager.send_to_task(self.task_id, payload)
        )
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)


def _reply_for_client_message(task_id: str, message: Dict[str, Any]) -> Dict[str, Any]:
    """Build the server reply for one decoded client message."""
    msg_type = message.get("type")

    if msg_type == "subscribe":
        # Already subscribed on connect
        return {"type": "subscribed", "task_id": task_id}

    if msg_type == "get_status":
        scheduler = scheduler_pool.get(task_id)
        return {
            "type": "dag_status",
            "data": scheduler.dag.to_dict() if scheduler else None
        }

    if msg_type == "ping":
        return {"type": "pong"}

    if msg_type == "cancel_task":
        if scheduler_pool.cancel(task_id):
            return {"type": "task_cancelled", "task_id": task_id}
        return {"type": "error", "message": "Task not found or already completed"}

    return {"type": "error", "message": f"Unknown message type: {msg_type}"}


@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends (directly from this endpoint): subscribed, heartbeat,
      dag_status, pong, task_cancelled, error
    - Server sends (via the emit_* helpers / ProgressEmitter): dag_start,
      node_start, node_progress, node_complete, node_error, dag_complete,
      dag_progress

    The connection is always unregistered (and its heartbeat cancelled) when
    this handler exits, whatever the reason.
    """
    # Accept the handshake and register the connection
    await connection_manager.accept_and_connect(websocket, task_id)

    try:
        # Send subscribed confirmation
        await connection_manager.send_personal(websocket, {
            "type": "subscribed",
            "task_id": task_id
        })

        # Start heartbeat
        connection_manager.start_heartbeat(websocket)

        # Send current DAG state if a scheduler already exists
        scheduler = scheduler_pool.get(task_id)
        if scheduler:
            await connection_manager.send_personal(websocket, {
                "type": "dag_status",
                "data": scheduler.dag.to_dict()
            })

        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON"
                })
                continue

            # Valid JSON that is not an object (e.g. a list or number) has no
            # .get(); reject it instead of crashing the receive loop.
            if not isinstance(message, dict):
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Message must be a JSON object"
                })
                continue

            await connection_manager.send_personal(
                websocket, _reply_for_client_message(task_id, message)
            )

    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("WebSocket error for task %s", task_id)
    finally:
        connection_manager.disconnect(websocket)


# Helper functions to emit progress from scheduler

def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """Create a progress emitter for a task"""
    return ProgressEmitter(task_id, connection_manager)


async def _emit(task_id: str, event_type: str, data: Dict[str, Any]) -> None:
    """Send a ``{"type": event_type, "data": data}`` event to a task's subscribers."""
    await connection_manager.send_to_task(task_id, {
        "type": event_type,
        "data": data
    })


async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """Emit DAG start event"""
    await _emit(task_id, "dag_start", {
        "graph": dag.to_dict()
    })


async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """Emit node start event"""
    await _emit(task_id, "node_start", {
        "node_id": node.id,
        "name": node.name,
        "status": node.status.value
    })


async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """Emit node complete event (``result`` is None when the node has none)"""
    await _emit(task_id, "node_complete", {
        "node_id": node.id,
        "result": node.result.to_dict() if node.result else None,
        "output_data": node.output_data
    })


async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """Emit node error event"""
    await _emit(task_id, "node_error", {
        "node_id": node.id,
        "error": error
    })


async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """Emit DAG complete event"""
    await _emit(task_id, "dag_complete", {
        "success": success,
        "results": results
    })