386 lines
12 KiB
Python
386 lines
12 KiB
Python
"""WebSocket routes for agent real-time communication"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import threading
|
|
from typing import Any, Dict, Optional, Set
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
|
import uuid
|
|
|
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
|
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
|
from luxx.agents.supervisor import SupervisorAgent
|
|
from luxx.agents.worker import WorkerAgent
|
|
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
|
|
from luxx.agents.registry import AgentRegistry
|
|
from luxx.services.llm_client import llm_client
|
|
from luxx.tools.executor import ToolExecutor
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
# Router that collects this module's WebSocket endpoints; presumably mounted by
# the application via ``include_router`` (not visible here).
router: APIRouter = APIRouter()
|
|
|
|
|
|
class ConnectionManager:
    """
    WebSocket Connection Manager

    Manages WebSocket connections for real-time agent progress updates.

    Features:
    - Connection tracking by task_id
    - Heartbeat mechanism
    - Progress broadcasting

    The bookkeeping dictionaries are guarded by a ``threading.Lock`` that is
    only held around dictionary access, never across an ``await``.
    """

    def __init__(self):
        # task_id -> set of websocket connections
        self._connections: Dict[str, Set[WebSocket]] = {}
        # websocket -> task_id mapping
        self._ws_to_task: Dict[WebSocket, str] = {}
        # websocket -> running heartbeat task
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # guards the three dictionaries above
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = 30

    def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Register an accepted WebSocket connection as a subscriber of a task.

        ``WebSocket.accept()`` is a coroutine and cannot be completed from this
        synchronous method, so the caller must ``await websocket.accept()``
        before calling ``connect``.

        Args:
            websocket: WebSocket connection (already accepted)
            task_id: Task ID to subscribe to
        """
        with self._lock:
            self._connections.setdefault(task_id, set()).add(websocket)
            self._ws_to_task[websocket] = task_id

        logger.info("WebSocket connected for task %s", task_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and cancel its heartbeat task.

        Calling this for an unknown or already-removed connection is a no-op.

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id and task_id in self._connections:
                subscribers = self._connections[task_id]
                subscribers.discard(websocket)
                if not subscribers:
                    # Drop empty sets so finished tasks do not accumulate.
                    del self._connections[task_id]
            heartbeat = self._heartbeat_tasks.pop(websocket, None)

        # Task.cancel() only requests cancellation; no need to hold the lock.
        if heartbeat is not None:
            heartbeat.cancel()

        if task_id:
            logger.info("WebSocket disconnected for task %s", task_id)

    async def _send_all(
        self,
        websockets: Set[WebSocket],
        message: Dict[str, Any],
        failure_log: str,
    ) -> Set[WebSocket]:
        """
        Send ``message`` to each connection in turn.

        Args:
            websockets: Connections to send to
            message: JSON-serializable message
            failure_log: %-style log format with one ``%s`` for the exception

        Returns:
            The connections whose send raised.
        """
        failed: Set[WebSocket] = set()
        for websocket in websockets:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(failure_log, e)
                failed.add(websocket)
        return failed

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send message to all connections subscribed to a task.

        Connections whose send fails are unregistered via ``disconnect``.

        Args:
            task_id: Task ID
            message: Message to send
        """
        with self._lock:
            # Snapshot so sending (which awaits) happens without the lock.
            connections = set(self._connections.get(task_id, ()))

        if not connections:
            return

        dead = await self._send_all(
            connections, message, "Failed to send to websocket: %s"
        )
        for websocket in dead:
            self.disconnect(websocket)

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast message to all connections.

        Connections whose send fails are unregistered via ``disconnect``,
        matching ``send_to_task``.

        Args:
            message: Message to broadcast
        """
        with self._lock:
            all_connections = set(self._ws_to_task)

        dead = await self._send_all(all_connections, message, "Failed to broadcast: %s")
        for websocket in dead:
            self.disconnect(websocket)

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send message to a specific connection.

        Send failures are logged and otherwise ignored; the connection stays
        registered.

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning("Failed to send personal message: %s", e)

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a background task that sends ``{"type": "heartbeat", ...}`` to
        the connection every ``_heartbeat_interval`` seconds.

        Uses ``asyncio.create_task``, so it must be called while an event loop
        is running. A heartbeat already running for the same connection is
        cancelled and replaced. ``disconnect`` cancels the task.

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop():
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval,
                    })
                except Exception as e:
                    # Stop heartbeating; the socket stays registered until
                    # disconnect() is called for it.
                    logger.debug("Heartbeat send failed, stopping heartbeat: %s", e)
                    break

        task = asyncio.create_task(heartbeat_loop())
        with self._lock:
            previous = self._heartbeat_tasks.get(websocket)
            self._heartbeat_tasks[websocket] = task
        if previous is not None:
            # Avoid leaking an orphaned heartbeat task for the same socket.
            previous.cancel()

    @property
    def connection_count(self) -> int:
        """Total number of connections"""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Get number of connections for a task"""
        with self._lock:
            return len(self._connections.get(task_id, set()))
|
|
|
|
|
# Module-level singletons shared by the WebSocket endpoint and the emit_* helpers.

# Tracks which WebSocket connections are subscribed to which task_id.
connection_manager: ConnectionManager = ConnectionManager()

# Pool of DAG schedulers, looked up and cancelled by task_id in this module;
# max_concurrent=10 presumably caps concurrent schedulers (confirm in SchedulerPool).
scheduler_pool: SchedulerPool = SchedulerPool(max_concurrent=10)

# Shared ToolExecutor instance (not used directly by the routes in this module).
tool_executor: ToolExecutor = ToolExecutor()
|
|
|
|
|
|
class ProgressEmitter:
    """
    Progress emitter that sends updates via WebSocket

    Wraps DAGScheduler progress callback to emit WebSocket messages. Each call
    schedules ``connection_manager.send_to_task`` as an asyncio task and
    returns immediately; delivery happens later on the event loop.
    """

    def __init__(self, task_id: str, connection_manager: ConnectionManager):
        self.task_id = task_id
        self.connection_manager = connection_manager
        # Strong references to in-flight send tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be garbage
        # collected before it finishes; each task removes itself when done.
        self._pending_sends: Set[asyncio.Task] = set()

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """
        Progress callback.

        Args:
            node_id: ``"dag"`` for DAG-level progress, otherwise the id of the
                node reporting progress
            progress: Progress value, forwarded unchanged
            message: Optional status text, forwarded unchanged

        Raises:
            RuntimeError: from ``asyncio.create_task`` if no event loop is
                running in the calling thread.
        """
        if node_id == "dag":
            # DAG-level progress
            event = {
                "type": "dag_progress",
                "data": {"progress": progress, "message": message},
            }
        else:
            # Node-level progress
            event = {
                "type": "node_progress",
                "data": {"node_id": node_id, "progress": progress, "message": message},
            }

        task = asyncio.create_task(
            self.connection_manager.send_to_task(self.task_id, event)
        )
        self._pending_sends.add(task)
        task.add_done_callback(self._pending_sends.discard)
|
|
|
|
|
|
async def _handle_client_message(
    websocket: WebSocket, task_id: str, message: Dict[str, Any]
) -> None:
    """
    Dispatch one decoded client message for ``dag_websocket`` and send the reply.

    Args:
        websocket: Connection that sent the message (receives the reply)
        task_id: Task the connection is subscribed to
        message: Decoded JSON object sent by the client
    """
    msg_type = message.get("type")

    if msg_type == "subscribe":
        # Already subscribed on connect; just confirm again.
        reply: Dict[str, Any] = {"type": "subscribed", "task_id": task_id}
    elif msg_type == "get_status":
        scheduler = scheduler_pool.get(task_id)
        reply = {
            "type": "dag_status",
            "data": scheduler.dag.to_dict() if scheduler else None,
        }
    elif msg_type == "ping":
        reply = {"type": "pong"}
    elif msg_type == "cancel_task":
        if scheduler_pool.cancel(task_id):
            reply = {"type": "task_cancelled", "task_id": task_id}
        else:
            reply = {"type": "error", "message": "Task not found or already completed"}
    else:
        reply = {"type": "error", "message": f"Unknown message type: {msg_type}"}

    await connection_manager.send_personal(websocket, reply)


@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_status, dag_start, node_start,
      node_progress, dag_progress, node_complete, node_error, dag_complete,
      pong, task_cancelled, error

    Every client message must be a JSON object with a ``type`` field. The
    connection is unregistered (and its heartbeat stopped) when the client
    disconnects or an unexpected error occurs.
    """
    # WebSocket.accept() is a coroutine: it must be awaited to complete the
    # handshake before anything can be sent or received.
    await websocket.accept()
    connection_manager.connect(websocket, task_id)

    try:
        # Send subscribed confirmation
        await connection_manager.send_personal(websocket, {
            "type": "subscribed",
            "task_id": task_id,
        })

        connection_manager.start_heartbeat(websocket)

        # Push the current DAG state right away if a scheduler already exists.
        scheduler = scheduler_pool.get(task_id)
        if scheduler:
            await connection_manager.send_personal(websocket, {
                "type": "dag_status",
                "data": scheduler.dag.to_dict(),
            })

        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON",
                })
                continue

            # Valid JSON that is not an object (list, number, ...) has no .get().
            if not isinstance(message, dict):
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid message: expected a JSON object",
                })
                continue

            await _handle_client_message(websocket, task_id, message)

    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in `finally`.
        pass
    except Exception:
        logger.exception("WebSocket error for task %s", task_id)
    finally:
        connection_manager.disconnect(websocket)
|
|
|
|
|
# Helper functions to emit progress from scheduler
|
|
|
|
def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """
    Build a progress emitter for one task.

    The emitter is bound to the module-level ``connection_manager``, so the
    updates it produces go to every WebSocket subscribed to ``task_id``.
    """
    return ProgressEmitter(task_id=task_id, connection_manager=connection_manager)
|
|
|
|
|
|
async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """
    Announce that a DAG has started.

    Sends a ``dag_start`` message whose ``data.graph`` is ``dag.to_dict()``
    to every connection subscribed to ``task_id``.
    """
    graph = dag.to_dict()
    event = {"type": "dag_start", "data": {"graph": graph}}
    await connection_manager.send_to_task(task_id, event)
|
|
|
|
|
|
async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """
    Announce that a DAG node has started.

    The ``node_start`` message carries the node's id, name and the ``.value``
    of its current status.
    """
    node_info = {
        "node_id": node.id,
        "name": node.name,
        "status": node.status.value,
    }
    await connection_manager.send_to_task(
        task_id, {"type": "node_start", "data": node_info}
    )
|
|
|
|
|
|
async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """
    Announce that a DAG node has completed.

    The ``node_complete`` message carries the node id, ``node.result.to_dict()``
    (or ``None`` when the node has no truthy result) and ``node.output_data``
    as-is.
    """
    node_id = node.id
    if node.result:
        result_payload = node.result.to_dict()
    else:
        result_payload = None

    event = {
        "type": "node_complete",
        "data": {
            "node_id": node_id,
            "result": result_payload,
            "output_data": node.output_data,
        },
    }
    await connection_manager.send_to_task(task_id, event)
|
|
|
|
|
|
async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """
    Announce that a DAG node failed.

    The ``node_error`` message carries the node id and the given error text.
    """
    event = dict(type="node_error", data=dict(node_id=node.id, error=error))
    await connection_manager.send_to_task(task_id, event)
|
|
|
|
|
|
async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """
    Announce that the whole DAG has finished.

    The ``dag_complete`` message carries the success flag and the ``results``
    mapping unchanged.
    """
    summary = {"success": success, "results": results}
    await connection_manager.send_to_task(
        task_id, {"type": "dag_complete", "data": summary}
    )
|