Luxx/luxx/routes/agents_ws.py

386 lines
12 KiB
Python

"""WebSocket routes for agent real-time communication"""
import asyncio
import json
import logging
import threading
from typing import Any, Dict, Optional, Set
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
import uuid
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
from luxx.agents.registry import AgentRegistry
from luxx.services.llm_client import llm_client
from luxx.tools.executor import ToolExecutor
# Module-level logger for this routes module
logger = logging.getLogger(__name__)
# Router holding the WebSocket endpoint(s) declared in this module; presumably
# mounted on the FastAPI app elsewhere — confirm at the app factory.
router = APIRouter()
class ConnectionManager:
    """
    WebSocket Connection Manager

    Manages WebSocket connections for real-time agent progress updates.

    Features:
    - Connection tracking by task_id
    - Heartbeat mechanism
    - Progress broadcasting

    The bookkeeping dictionaries are guarded by a ``threading.Lock``. The lock
    is only held for short sections that never ``await``, so it is never held
    across a network send.
    """

    def __init__(self, heartbeat_interval: float = 30):
        """
        Args:
            heartbeat_interval: Seconds between heartbeat frames sent by
                ``start_heartbeat``.
        """
        # task_id -> set of websocket connections
        self._connections: Dict[str, Set[WebSocket]] = {}
        # websocket -> task_id mapping (reverse index of _connections)
        self._ws_to_task: Dict[WebSocket, str] = {}
        # websocket -> running heartbeat task
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # Guards the three dictionaries above
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = heartbeat_interval

    def _discard_locked(self, websocket: WebSocket, task_id: str) -> None:
        """Remove ``websocket`` from ``task_id``'s subscribers; caller holds ``self._lock``."""
        subscribers = self._connections.get(task_id)
        if subscribers is None:
            return
        subscribers.discard(websocket)
        if not subscribers:
            # Drop empty sets so finished tasks do not accumulate keys
            del self._connections[task_id]

    def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Register an accepted WebSocket connection for a task.

        This method does NOT perform the handshake. ``WebSocket.accept()`` is a
        coroutine and cannot be awaited from synchronous code. Call
        ``await websocket.accept()`` first, or use ``accept_and_connect``.

        If the connection is already registered under a different task, it is
        moved to the new task instead of staying subscribed to both.

        Args:
            websocket: WebSocket connection (already accepted)
            task_id: Task ID to subscribe to
        """
        with self._lock:
            previous = self._ws_to_task.get(websocket)
            if previous is not None and previous != task_id:
                self._discard_locked(websocket, previous)
            self._connections.setdefault(task_id, set()).add(websocket)
            self._ws_to_task[websocket] = task_id
        logger.info("WebSocket connected for task %s", task_id)

    async def accept_and_connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Complete the WebSocket handshake, then register the connection.

        Args:
            websocket: WebSocket connection that has not been accepted yet
            task_id: Task ID to subscribe to
        """
        await websocket.accept()
        self.connect(websocket, task_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and cancel its heartbeat.

        Calling this for an unknown or already-removed connection is a no-op.

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id is not None:
                self._discard_locked(websocket, task_id)
            heartbeat = self._heartbeat_tasks.pop(websocket, None)
        if heartbeat is not None:
            heartbeat.cancel()
        if task_id is not None:
            logger.info("WebSocket disconnected for task %s", task_id)

    async def _send_all(
        self,
        connections: Set[WebSocket],
        message: Dict[str, Any],
        failure_text: str,
    ) -> None:
        """
        Send ``message`` to every connection and unregister those that fail.

        Args:
            connections: Snapshot of connections to send to (not the live set)
            message: JSON-serialisable message
            failure_text: Prefix of the warning logged for each failed send
        """
        dead_connections = []
        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                # Any send failure means the socket is unusable; log and drop it
                logger.warning("%s: %s", failure_text, e)
                dead_connections.append(websocket)
        for websocket in dead_connections:
            self.disconnect(websocket)

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send message to all connections subscribed to a task.

        Connections whose send fails are unregistered.

        Args:
            task_id: Task ID
            message: Message to send
        """
        with self._lock:
            # Copy under the lock; sending happens outside it
            connections = set(self._connections.get(task_id, ()))
        if connections:
            await self._send_all(connections, message, "Failed to send to websocket")

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast message to all connections.

        Connections whose send fails are unregistered, as in ``send_to_task``.

        Args:
            message: Message to broadcast
        """
        with self._lock:
            all_connections = set(self._ws_to_task)
        await self._send_all(all_connections, message, "Failed to broadcast")

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send message to a specific connection.

        Failures are logged and swallowed; the connection is not unregistered.

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning("Failed to send personal message: %s", e)

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a periodic heartbeat for a connection.

        Must be called from inside a running event loop (uses
        ``asyncio.create_task``). A heartbeat already running for the same
        connection is cancelled and replaced, so repeated calls do not leak tasks.

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop() -> None:
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval
                    })
                except Exception:
                    # Socket is gone; stop quietly. Unregistration happens in
                    # disconnect() via the receive loop or a failed broadcast.
                    break

        task = asyncio.create_task(heartbeat_loop())
        with self._lock:
            previous = self._heartbeat_tasks.get(websocket)
            self._heartbeat_tasks[websocket] = task
        if previous is not None:
            previous.cancel()

    @property
    def connection_count(self) -> int:
        """Total number of connections"""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Get number of connections for a task"""
        with self._lock:
            return len(self._connections.get(task_id, set()))
# Global connection manager
# Shared by dag_websocket, the ProgressEmitter built in create_progress_emitter,
# and the emit_* helpers below.
connection_manager = ConnectionManager()
# Global scheduler pool
# Looked up by task_id in dag_websocket (get / cancel). The meaning of
# max_concurrent is defined in luxx.agents.dag_scheduler — presumably a cap on
# concurrently running DAGs; confirm there.
scheduler_pool = SchedulerPool(max_concurrent=10)
# Global tool executor
# NOTE(review): not referenced in the visible part of this module; presumably
# imported by other modules — confirm before removing.
tool_executor = ToolExecutor()
class ProgressEmitter:
    """
    Progress emitter that sends updates via WebSocket

    Wraps DAGScheduler progress callback to emit WebSocket messages.
    Each call schedules ``send_to_task`` with ``asyncio.create_task``, so the
    callback must be invoked from code running inside the event loop.
    """

    def __init__(self, task_id: str, connection_manager: ConnectionManager):
        """
        Args:
            task_id: Task whose subscribers receive the progress messages
            connection_manager: Manager used to deliver the messages
        """
        self.task_id = task_id
        self.connection_manager = connection_manager
        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced fire-and-forget task can
        # be garbage-collected before it finishes sending.
        self._pending: Set[asyncio.Task] = set()

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """
        Progress callback.

        Args:
            node_id: ``"dag"`` for DAG-level progress, otherwise a node id
            progress: Progress value forwarded as-is to clients
            message: Optional human-readable status text
        """
        if node_id == "dag":
            # DAG-level progress
            payload = {
                "type": "dag_progress",
                "data": {
                    "progress": progress,
                    "message": message
                }
            }
        else:
            # Node-level progress
            payload = {
                "type": "node_progress",
                "data": {
                    "node_id": node_id,
                    "progress": progress,
                    "message": message
                }
            }
        task = asyncio.create_task(
            self.connection_manager.send_to_task(self.task_id, payload)
        )
        self._pending.add(task)
        # Release the reference once the send completes (or fails)
        task.add_done_callback(self._pending.discard)
def _build_reply(task_id: str, message: Any) -> Dict[str, Any]:
    """
    Compute the server reply for one decoded client message.

    Args:
        task_id: Task the connection is subscribed to
        message: Result of ``json.loads`` on the client frame

    Returns:
        The JSON-serialisable reply to send back to the same client.
    """
    if not isinstance(message, dict):
        # e.g. a bare JSON list or number; ``.get`` below would raise
        return {"type": "error", "message": "Message must be a JSON object"}

    msg_type = message.get("type")
    if msg_type == "subscribe":
        # Already subscribed on connect; just re-confirm
        return {"type": "subscribed", "task_id": task_id}
    if msg_type == "get_status":
        scheduler = scheduler_pool.get(task_id)
        return {
            "type": "dag_status",
            "data": scheduler.dag.to_dict() if scheduler else None,
        }
    if msg_type == "ping":
        return {"type": "pong"}
    if msg_type == "cancel_task":
        if scheduler_pool.cancel(task_id):
            return {"type": "task_cancelled", "task_id": task_id}
        return {"type": "error", "message": "Task not found or already completed"}
    return {"type": "error", "message": f"Unknown message type: {msg_type}"}


@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_status, dag_start, node_start,
      node_progress, node_complete, node_error, dag_complete, pong,
      task_cancelled, error
    """
    # WebSocket.accept() is a coroutine and must be awaited before anything
    # is sent; otherwise the handshake never completes and every send fails.
    await websocket.accept()
    connection_manager.connect(websocket, task_id)
    try:
        # Send subscribed confirmation
        await connection_manager.send_personal(websocket, {
            "type": "subscribed",
            "task_id": task_id
        })
        connection_manager.start_heartbeat(websocket)

        # Send the current DAG state if a scheduler already exists
        scheduler = scheduler_pool.get(task_id)
        if scheduler:
            await connection_manager.send_personal(websocket, {
                "type": "dag_status",
                "data": scheduler.dag.to_dict()
            })

        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                reply = {"type": "error", "message": "Invalid JSON"}
            else:
                reply = _build_reply(task_id, message)
            await connection_manager.send_personal(websocket, reply)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        # Single cleanup path: unregisters the socket and cancels its heartbeat
        # however the session ended (client disconnect, error, or cancellation).
        connection_manager.disconnect(websocket)
# Helper functions to emit progress from scheduler
def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """
    Build a ProgressEmitter bound to the module-wide connection manager.

    Args:
        task_id: Task whose WebSocket subscribers receive the progress updates

    Returns:
        A ProgressEmitter callable for use as a scheduler progress callback.
    """
    emitter = ProgressEmitter(task_id=task_id, connection_manager=connection_manager)
    return emitter
async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """
    Notify the task's subscribers that the DAG has started.

    Sends a ``dag_start`` message whose ``data.graph`` is ``dag.to_dict()``.
    """
    graph = dag.to_dict()
    event = {"type": "dag_start", "data": {"graph": graph}}
    await connection_manager.send_to_task(task_id, event)
async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """
    Notify the task's subscribers that a node has started.

    The payload carries the node's id, name and ``status.value``.
    """
    node_info = {
        "node_id": node.id,
        "name": node.name,
        "status": node.status.value,
    }
    await connection_manager.send_to_task(
        task_id, {"type": "node_start", "data": node_info}
    )
async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """
    Notify the task's subscribers that a node has completed.

    ``result`` is ``node.result.to_dict()`` when the node has a truthy result,
    otherwise ``None``; ``output_data`` is forwarded unchanged.
    """
    completion = {
        "node_id": node.id,
        "result": None if not node.result else node.result.to_dict(),
        "output_data": node.output_data,
    }
    await connection_manager.send_to_task(
        task_id, {"type": "node_complete", "data": completion}
    )
async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """
    Notify the task's subscribers that a node failed.

    Args:
        task_id: Task whose subscribers are notified
        node: The failing node (only its ``id`` is sent)
        error: Error text forwarded to clients
    """
    failure = {"node_id": node.id, "error": error}
    await connection_manager.send_to_task(
        task_id, {"type": "node_error", "data": failure}
    )
async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """
    Notify the task's subscribers that the whole DAG has finished.

    Args:
        task_id: Task whose subscribers are notified
        success: Overall outcome flag forwarded to clients
        results: Result mapping forwarded to clients unchanged
    """
    await connection_manager.send_to_task(
        task_id,
        dict(type="dag_complete", data=dict(success=success, results=results)),
    )