# --- luxx/agents/__init__.py ---
"""Multi-Agent system module"""
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.registry import AgentRegistry

__all__ = [
    "Agent",
    "AgentConfig",
    "AgentType",
    "AgentStatus",
    "DAG",
    "TaskNode",
    "TaskNodeStatus",
    "TaskResult",
    "AgentRegistry",
]

# --- luxx/agents/core.py ---
"""Agent core models"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from enum import Enum
from datetime import datetime

from luxx.tools.core import CommandPermission


class AgentType(str, Enum):
    """Agent type enumeration"""
    SUPERVISOR = "supervisor"
    WORKER = "worker"


class AgentStatus(str, Enum):
    """Agent status enumeration"""
    IDLE = "idle"
    PLANNING = "planning"
    EXECUTING = "executing"
    WAITING = "waiting"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class AgentConfig:
    """Static agent configuration (immutable intent; not enforced)."""
    name: str
    agent_type: AgentType
    description: str = ""
    max_permission: CommandPermission = CommandPermission.EXECUTE
    max_turns: int = 10  # Sliding-window size in user/assistant turn pairs
    model: str = "deepseek-chat"
    temperature: float = 0.7
    max_tokens: int = 4096
    system_prompt: str = ""
    tools: List[str] = field(default_factory=list)  # Tool names available to this agent
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Agent:
    """
    Agent entity.

    Represents an AI agent with its configuration, runtime status, and
    conversation context. The context window is a list of chat messages
    (``{"role": ..., "content": ...}``) trimmed with a sliding window.
    """
    id: str
    config: AgentConfig
    status: AgentStatus = AgentStatus.IDLE
    user_id: Optional[int] = None
    conversation_id: Optional[str] = None
    workspace: Optional[str] = None

    # Runtime state.
    # NOTE(review): datetime.utcnow() produces naive datetimes and is
    # deprecated in Python 3.12; kept for consistency with the rest of
    # the codebase and with existing isoformat() output.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Context management
    context_window: List[Dict[str, Any]] = field(default_factory=list)
    accumulated_result: Dict[str, Any] = field(default_factory=dict)

    # Progress tracking
    current_task_id: Optional[str] = None
    progress: float = 0.0  # 0.0 - 1.0

    # Effective permission = min(user_permission, config.max_permission)
    effective_permission: CommandPermission = field(default_factory=lambda: CommandPermission.EXECUTE)

    def __post_init__(self):
        """Seed the context window with the configured system prompt, if any."""
        if self.config.system_prompt and not self.context_window:
            self.context_window = [
                {"role": "system", "content": self.config.system_prompt}
            ]

    def add_message(self, role: str, content: str) -> None:
        """
        Append a message to the context window and trim it.

        Args:
            role: Message role (user/assistant/system)
            content: Message content
        """
        self.context_window.append({"role": role, "content": content})
        self._trim_context()
        self.updated_at = datetime.utcnow()

    def _trim_context(self) -> None:
        """
        Trim the context window using a sliding-window strategy.

        Keeps the leading system prompt (when one is actually present) plus
        the most recent ``max_turns`` user/assistant pairs.

        Bugfix: the previous implementation unconditionally pinned index 0
        as "the system prompt". When no system prompt was configured, the
        oldest ordinary message was kept forever and the budget was off by
        one. We now check the role explicitly.
        """
        has_system = bool(self.context_window) and self.context_window[0].get("role") == "system"
        # system prompt (if any) + max_turns * (user + assistant)
        budget = (1 if has_system else 0) + self.config.max_turns * 2
        if len(self.context_window) <= budget:
            return
        if has_system:
            head = self.context_window[:1]
            tail = self.context_window[1:]
            keep = budget - 1
        else:
            head = []
            tail = self.context_window
            keep = budget
        self.context_window = head + tail[-keep:]

    def get_context(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the current context window."""
        return self.context_window.copy()

    def set_user_permission(self, user_permission: CommandPermission) -> None:
        """
        Set effective permission from the user's level and the agent cap.

        Effective permission = min(user_permission, config.max_permission).
        NOTE(review): assumes CommandPermission members are orderable
        (e.g. an IntEnum); min() would raise TypeError otherwise — confirm.

        Args:
            user_permission: User's permission level
        """
        self.effective_permission = min(user_permission, self.config.max_permission)

    def store_result(self, key: str, value: Any) -> None:
        """
        Store a result for the supervisor's result-based context management.

        Args:
            key: Result key
            value: Result value
        """
        self.accumulated_result[key] = value
        self.updated_at = datetime.utcnow()

    def get_result(self, key: str) -> Optional[Any]:
        """Return a stored result by key, or None when absent."""
        return self.accumulated_result.get(key)

    def clear_context(self) -> None:
        """Clear the context window, preserving a leading system prompt."""
        if self.context_window and self.context_window[0]["role"] == "system":
            system_prompt = self.context_window[0]
            self.context_window = [system_prompt]
        else:
            self.context_window = []

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the agent's public state to a plain dictionary."""
        return {
            "id": self.id,
            "name": self.config.name,
            "type": self.config.agent_type.value,
            "status": self.status.value,
            "user_id": self.user_id,
            "conversation_id": self.conversation_id,
            "workspace": self.workspace,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "current_task_id": self.current_task_id,
            "progress": self.progress,
            "effective_permission": self.effective_permission.name,
        }

# --- luxx/agents/dag.py (module header) ---
"""DAG (Directed Acyclic Graph) and TaskNode models"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from enum import Enum
from datetime import datetime
import uuid
datetime +import uuid + + +class TaskNodeStatus(str, Enum): + """Task node status enumeration""" + PENDING = "pending" # Not yet started + READY = "ready" # Dependencies satisfied, can start + RUNNING = "running" # Currently executing + COMPLETED = "completed" # Successfully completed + FAILED = "failed" # Execution failed + CANCELLED = "cancelled" # Cancelled by user + BLOCKED = "blocked" # Blocked by failed dependency + + +@dataclass +class TaskResult: + """Task execution result""" + success: bool + data: Any = None + error: Optional[str] = None + output_data: Optional[Dict[str, Any]] = None # Structured output for supervisor + execution_time: float = 0.0 # seconds + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "success": self.success, + "data": self.data, + "error": self.error, + "output_data": self.output_data, + "execution_time": self.execution_time, + } + + @classmethod + def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None, + execution_time: float = 0.0) -> "TaskResult": + """Create success result""" + return cls(success=True, data=data, output_data=output_data, execution_time=execution_time) + + @classmethod + def fail(cls, error: str, data: Any = None) -> "TaskResult": + """Create failure result""" + return cls(success=False, error=error, data=data) + + +@dataclass +class TaskNode: + """ + Task node in the DAG + + Represents a single executable task within the agent workflow. 
+ """ + id: str + name: str + description: str = "" + + # Task definition + task_type: str = "generic" # e.g., "code", "shell", "file", "llm" + task_data: Dict[str, Any] = field(default_factory=dict) # Task-specific parameters + + # Dependencies (node IDs that must complete before this node) + dependencies: List[str] = field(default_factory=list) + + # Status tracking + status: TaskNodeStatus = TaskNodeStatus.PENDING + result: Optional[TaskResult] = None + + # Progress tracking + progress: float = 0.0 # 0.0 - 1.0 + progress_message: str = "" + + # Execution info + assigned_agent_id: Optional[str] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + # Output data (key-value pairs for dependent tasks) + output_data: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Post-initialization""" + if not self.id: + self.id = str(uuid.uuid4())[:8] + + @property + def is_root(self) -> bool: + """Check if this is a root node (no dependencies)""" + return len(self.dependencies) == 0 + + @property + def is_leaf(self) -> bool: + """Check if this is a leaf node (no dependents)""" + return False # Will be set by DAG + + @property + def execution_time(self) -> float: + """Calculate execution time in seconds""" + if self.started_at and self.completed_at: + return (self.completed_at - self.started_at).total_seconds() + return 0.0 + + def mark_ready(self) -> None: + """Mark node as ready to execute""" + self.status = TaskNodeStatus.READY + + def mark_running(self, agent_id: str) -> None: + """Mark node as running""" + self.status = TaskNodeStatus.RUNNING + self.assigned_agent_id = agent_id + self.started_at = datetime.utcnow() + + def mark_completed(self, result: TaskResult) -> None: + """Mark node as completed""" + self.status = TaskNodeStatus.COMPLETED + self.result = result + self.completed_at = datetime.utcnow() + self.progress = 1.0 + if result.output_data: + self.output_data = result.output_data + + def 
mark_failed(self, error: str) -> None: + """Mark node as failed""" + self.status = TaskNodeStatus.FAILED + self.result = TaskResult.fail(error) + self.completed_at = datetime.utcnow() + + def mark_cancelled(self) -> None: + """Mark node as cancelled""" + self.status = TaskNodeStatus.CANCELLED + self.completed_at = datetime.utcnow() + + def update_progress(self, progress: float, message: str = "") -> None: + """Update execution progress""" + self.progress = max(0.0, min(1.0, progress)) + if message: + self.progress_message = message + + def can_execute(self, completed_nodes: Set[str]) -> bool: + """ + Check if this node can execute (all dependencies completed) + + Args: + completed_nodes: Set of completed node IDs + """ + if self.status != TaskNodeStatus.READY: + return False + return all(dep_id in completed_nodes for dep_id in self.dependencies) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "task_type": self.task_type, + "task_data": self.task_data, + "dependencies": self.dependencies, + "status": self.status.value, + "progress": self.progress, + "progress_message": self.progress_message, + "assigned_agent_id": self.assigned_agent_id, + "result": self.result.to_dict() if self.result else None, + "output_data": self.output_data, + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + } + + +@dataclass +class DAG: + """ + Directed Acyclic Graph for task scheduling + + Manages the task workflow with dependency tracking and parallel execution support. 
+ """ + id: str + name: str = "" + description: str = "" + + # Nodes and edges + nodes: Dict[str, TaskNode] = field(default_factory=dict) + edges: List[tuple] = field(default_factory=list) # (from_id, to_id) + + # Metadata + created_at: datetime = field(default_factory=datetime.utcnow) + updated_at: datetime = field(default_factory=datetime.utcnow) + + # Root and leaf tracking + _root_nodes: Set[str] = field(default_factory=set) + _leaf_nodes: Set[str] = field(default_factory=set) + + def __post_init__(self): + """Post-initialization""" + if not self.id: + self.id = str(uuid.uuid4())[:8] + + def add_node(self, node: TaskNode) -> None: + """ + Add a task node to the DAG + + Args: + node: TaskNode to add + """ + self.nodes[node.id] = node + self._update_root_leaf_cache() + self.updated_at = datetime.utcnow() + + def add_edge(self, from_id: str, to_id: str) -> None: + """ + Add an edge (dependency) between nodes + + Args: + from_id: Source node ID (must complete first) + to_id: Target node ID (depends on source) + """ + if from_id not in self.nodes: + raise ValueError(f"Source node '{from_id}' not found") + if to_id not in self.nodes: + raise ValueError(f"Target node '{to_id}' not found") + + # Add dependency to target node + if from_id not in self.nodes[to_id].dependencies: + self.nodes[to_id].dependencies.append(from_id) + + self.edges.append((from_id, to_id)) + self._update_root_leaf_cache() + self.updated_at = datetime.utcnow() + + def _update_root_leaf_cache(self) -> None: + """Update cached root and leaf node sets""" + self._root_nodes = set() + self._leaf_nodes = set() + + for node_id in self.nodes: + if len(self.nodes[node_id].dependencies) == 0: + self._root_nodes.add(node_id) + + # Check if node is a leaf (no one depends on it) + is_leaf = True + for other in self.nodes.values(): + if node_id in other.dependencies: + is_leaf = False + break + if is_leaf: + self._leaf_nodes.add(node_id) + + @property + def root_nodes(self) -> List[TaskNode]: + """Get all root 
nodes (nodes with no dependencies)""" + return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes] + + @property + def leaf_nodes(self) -> List[TaskNode]: + """Get all leaf nodes (nodes no one depends on)""" + return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes] + + def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]: + """ + Get all nodes that are ready to execute + + A node is ready if: + 1. All its dependencies are in completed_nodes + 2. It is not already running or completed + + Args: + completed_nodes: Set of completed node IDs + """ + ready = [] + for node in self.nodes.values(): + if node.status == TaskNodeStatus.READY and node.can_execute(completed_nodes): + ready.append(node) + return ready + + def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]: + """ + Get all nodes blocked by a failed node + + Args: + failed_node_id: ID of the failed node + """ + blocked = [] + for node in self.nodes.values(): + if failed_node_id in node.dependencies: + if node.status not in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED): + blocked.append(node) + return blocked + + def mark_node_completed(self, node_id: str, result: TaskResult) -> None: + """Mark a node as completed with result""" + if node_id in self.nodes: + self.nodes[node_id].mark_completed(result) + self.updated_at = datetime.utcnow() + + def mark_node_failed(self, node_id: str, error: str) -> None: + """Mark a node as failed""" + if node_id in self.nodes: + self.nodes[node_id].mark_failed(error) + self._propagate_failure(node_id) + self.updated_at = datetime.utcnow() + + def _propagate_failure(self, failed_node_id: str) -> None: + """Propagate failure to dependent nodes""" + for node in self.nodes.values(): + if failed_node_id in node.dependencies: + if node.status == TaskNodeStatus.READY: + node.status = TaskNodeStatus.BLOCKED + + @property + def completed_count(self) -> int: + """Count of completed nodes""" + 
return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED) + + @property + def failed_count(self) -> int: + """Count of failed nodes""" + return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED) + + @property + def total_count(self) -> int: + """Total number of nodes""" + return len(self.nodes) + + @property + def progress(self) -> float: + """Overall DAG progress (0.0 - 1.0)""" + if not self.nodes: + return 0.0 + return sum(n.progress for n in self.nodes.values()) / len(self.nodes) + + @property + def is_complete(self) -> bool: + """Check if DAG execution is complete (all nodes done)""" + return all( + n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED) + for n in self.nodes.values() + ) + + @property + def is_success(self) -> bool: + """Check if DAG completed successfully""" + return self.is_complete and self.failed_count == 0 + + def get_execution_order(self) -> List[List[TaskNode]]: + """ + Get nodes grouped by execution level (parallel-friendly order) + + Returns: + List of node groups, where each group can be executed in parallel + """ + levels: List[List[TaskNode]] = [] + completed: Set[str] = set() + + while len(completed) < len(self.nodes): + # Find nodes whose dependencies are all completed + current_level = [] + for node in self.nodes.values(): + if node.id not in completed and node.can_execute(completed): + current_level.append(node) + + if not current_level: + break # Circular dependency or all blocked + + levels.append(current_level) + completed.update(n.id for n in current_level) + + return levels + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "nodes": [n.to_dict() for n in self.nodes.values()], + "edges": [{"from": e[0], "to": e[1]} for e in self.edges], + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), 
+ "progress": self.progress, + "completed_count": self.completed_count, + "total_count": self.total_count, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DAG": + """ + Create DAG from dictionary + + Args: + data: Dictionary with DAG data + """ + dag = cls( + id=data.get("id", str(uuid.uuid4())[:8]), + name=data.get("name", ""), + description=data.get("description", ""), + ) + + # Recreate nodes + for node_data in data.get("nodes", []): + node = TaskNode( + id=node_data["id"], + name=node_data["name"], + description=node_data.get("description", ""), + task_type=node_data.get("task_type", "generic"), + task_data=node_data.get("task_data", {}), + dependencies=node_data.get("dependencies", []), + ) + dag.add_node(node) + + # Recreate edges + for edge in data.get("edges", []): + dag.add_edge(edge["from"], edge["to"]) + + return dag diff --git a/luxx/agents/dag_scheduler.py b/luxx/agents/dag_scheduler.py new file mode 100644 index 0000000..ccd8f2d --- /dev/null +++ b/luxx/agents/dag_scheduler.py @@ -0,0 +1,360 @@ +"""DAG Scheduler - orchestrates parallel task execution""" +import asyncio +import logging +from typing import Any, Callable, Dict, List, Optional, Set +from concurrent.futures import ThreadPoolExecutor +import threading + +from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus +from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult +from luxx.agents.supervisor import SupervisorAgent +from luxx.agents.worker import WorkerAgent +from luxx.tools.executor import ToolExecutor + +logger = logging.getLogger(__name__) + + +class DAGScheduler: + """ + DAG Scheduler + + Orchestrates parallel execution of tasks based on DAG structure. 
class DAGScheduler:
    """
    DAG Scheduler.

    Orchestrates parallel execution of tasks based on the DAG structure.

    Features:
    - Parallel execution capped at max_workers
    - Dependency-aware scheduling
    - Real-time progress tracking via an optional callback
    - Cooperative cancellation
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: ToolExecutor = None
    ):
        """
        Initialize DAG Scheduler.

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (created when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()

        # Execution state
        self._running_nodes: Set[str] = set()
        self._completed_nodes: Set[str] = set()
        self._node_results: Dict[str, TaskResult] = {}
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}

        # Bugfix: asyncio keeps only weak references to tasks, so
        # fire-and-forget tasks can be garbage collected mid-flight.
        # Hold strong references here until each task finishes.
        self._inflight_tasks: Set[asyncio.Task] = set()

        # Control flags
        self._cancelled = False
        self._cancel_lock = threading.Lock()

        # Progress callback
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates."""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Invoke the progress callback, when one is registered."""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG to completion (or cancellation/deadlock).

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description (not consumed here; kept for
                  interface compatibility)

        Returns:
            Execution results dictionary
        """
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()

        self._emit_progress("dag", 0.0, "Starting DAG execution")

        # Root nodes have no dependencies and may start immediately
        for node in self.dag.root_nodes:
            node.mark_ready()

        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()

            ready_nodes = self._get_ready_nodes()

            if not ready_nodes and not self._running_nodes:
                # Nothing ready and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break

            # Bugfix: reserve the worker slot *before* create_task. The old
            # code checked len(_running_nodes) against max_workers, but the
            # id was only added once the coroutine actually started, so one
            # loop pass could launch more than max_workers tasks.
            capacity = max(0, self.max_workers - len(self._running_nodes))
            for node in ready_nodes[:capacity]:
                self._running_nodes.add(node.id)
                runner = asyncio.create_task(self._execute_node(node, context))
                self._inflight_tasks.add(runner)
                runner.add_done_callback(self._inflight_tasks.discard)

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)

        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")

        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """True when every node has reached a terminal or blocked status."""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Thread-safe read of the cancellation flag."""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cooperative cancellation of the running execution."""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Nodes that are READY, not already launched, and unblocked."""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node.

        The caller has already reserved node.id in ``_running_nodes``;
        this method releases the slot in all exit paths.

        Args:
            node: Node to execute
            context: Execution context
        """
        try:
            if self._is_cancelled():
                node.mark_cancelled()
                return

            node.mark_running(self.supervisor.agent.id)
            self._emit_progress(node.id, 0.0, f"Starting: {node.name}")

            # Collect outputs of completed parents for this node
            parent_outputs = {
                dep_id: (self._node_results[dep_id].output_data or {})
                for dep_id in node.dependencies
                if dep_id in self._node_results
            }

            # A fresh worker per task
            worker = self.worker_factory()

            def node_progress(progress: float, message: str = ""):
                # Forward per-node progress to the node and the global callback
                node.update_progress(progress, message)
                self._emit_progress(node.id, progress, message)

            try:
                result = await worker.execute_task(
                    node,
                    context,
                    parent_outputs=parent_outputs,
                    progress_callback=node_progress
                )

                self._node_results[node.id] = result

                if result.success:
                    node.mark_completed(result)
                    self._completed_nodes.add(node.id)
                    self._emit_progress(node.id, 1.0, f"Completed: {node.name}")
                    # Dependents may now be runnable
                    self._unblock_nodes()
                else:
                    node.mark_failed(result.error or "Unknown error")
                    self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")
                    self._block_dependent_nodes(node.id)

            except Exception as e:
                logger.error(f"Node {node.id} execution error: {e}")
                node.mark_failed(str(e))
                self._node_results[node.id] = TaskResult.fail(error=str(e))
                self._block_dependent_nodes(node.id)
        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Move BLOCKED nodes back to READY once their dependencies are met."""
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block unfinished nodes that depend on a failed node."""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Mark all launched-but-unfinished nodes as cancelled."""
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

    def _result_payload(self) -> Dict[str, Any]:
        """Shared body of the final and cancelled result dictionaries."""
        return {
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_result(self) -> Dict[str, Any]:
        """Build the final result dictionary."""
        return {"success": self.dag.is_success, **self._result_payload()}

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build the result dictionary for a cancelled execution."""
        return {"success": False, "cancelled": True, **self._result_payload()}


class SchedulerPool:
    """
    Pool of DAG schedulers for managing multiple concurrent DAG executions.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool.

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create and register a new scheduler.

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            Created scheduler

        Raises:
            RuntimeError: when the pool is at capacity
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")

            scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )
            self._schedulers[task_id] = scheduler
            return scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Get scheduler by task ID, or None."""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Remove a scheduler; returns True when one was removed."""
        with self._lock:
            if task_id in self._schedulers:
                del self._schedulers[task_id]
                return True
            return False

    def cancel(self, task_id: str) -> bool:
        """Cancel a scheduler's execution; returns True when it exists."""
        scheduler = self.get(task_id)
        if scheduler:
            scheduler.cancel()
            return True
        return False

    @property
    def active_count(self) -> int:
        """Number of registered schedulers."""
        with self._lock:
            return len(self._schedulers)

# --- luxx/agents/registry.py (module header) ---
"""Agent Registry - manages agent lifecycle and access"""
from typing import Dict, List, Optional, Set
import threading
import uuid

from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
class AgentRegistry:
    """
    Agent Registry (Singleton).

    Thread-safe registry for managing all agents in the system.
    Provides agent creation, retrieval, and lifecycle management.
    """
    _instance: Optional["AgentRegistry"] = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: only the first caller constructs the instance
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentRegistry() call; initialize only once
        if self._initialized:
            return
        self._agents: Dict[str, Agent] = {}
        self._user_agents: Dict[int, Set[str]] = {}  # user_id -> set of agent_ids
        self._conversation_agents: Dict[str, Set[str]] = {}  # conversation_id -> set of agent_ids
        self._registry_lock = threading.Lock()
        self._initialized = True

    def _generate_agent_id(self) -> str:
        """Generate a unique agent ID."""
        return f"agent_{uuid.uuid4().hex[:12]}"

    def create_agent(
        self,
        config: AgentConfig,
        user_id: Optional[int] = None,
        conversation_id: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Agent:
        """
        Create and register a new agent.

        Args:
            config: Agent configuration
            user_id: Associated user ID
            conversation_id: Associated conversation ID
            workspace: Agent's workspace path

        Returns:
            Created Agent instance
        """
        with self._registry_lock:
            agent_id = self._generate_agent_id()
            agent = Agent(
                id=agent_id,
                config=config,
                user_id=user_id,
                conversation_id=conversation_id,
                workspace=workspace,
            )

            self._agents[agent_id] = agent

            # Bugfix: 'if user_id:' silently skipped tracking for the
            # valid user id 0; compare against None explicitly.
            if user_id is not None:
                self._user_agents.setdefault(user_id, set()).add(agent_id)

            # Track by conversation
            if conversation_id:
                self._conversation_agents.setdefault(conversation_id, set()).add(agent_id)

            return agent

    def get(self, agent_id: str) -> Optional[Agent]:
        """
        Get agent by ID.

        Args:
            agent_id: Agent ID

        Returns:
            Agent if found, None otherwise
        """
        with self._registry_lock:
            return self._agents.get(agent_id)

    def list_user_agents(self, user_id: int) -> List[Agent]:
        """
        List all agents for a user.

        Args:
            user_id: User ID

        Returns:
            List of user's agents
        """
        with self._registry_lock:
            agent_ids = self._user_agents.get(user_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_conversation_agents(self, conversation_id: str) -> List[Agent]:
        """
        List all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            List of conversation's agents
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.get(conversation_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_by_type(self, agent_type: AgentType) -> List[Agent]:
        """
        List all agents of a specific type.

        Args:
            agent_type: Type of agent to filter

        Returns:
            List of agents of the specified type
        """
        with self._registry_lock:
            return [
                a for a in self._agents.values()
                if a.config.agent_type == agent_type
            ]

    def list_by_status(self, status: AgentStatus) -> List[Agent]:
        """
        List all agents with a specific status.

        Args:
            status: Status to filter

        Returns:
            List of agents with the specified status
        """
        with self._registry_lock:
            return [a for a in self._agents.values() if a.status == status]

    def update_status(self, agent_id: str, status: AgentStatus) -> bool:
        """
        Update agent status.

        Args:
            agent_id: Agent ID
            status: New status

        Returns:
            True if updated, False if agent not found
        """
        with self._registry_lock:
            agent = self._agents.get(agent_id)
            if agent:
                agent.status = status
                return True
            return False

    def remove(self, agent_id: str) -> bool:
        """
        Remove agent from registry.

        Args:
            agent_id: Agent ID

        Returns:
            True if removed, False if not found
        """
        with self._registry_lock:
            agent = self._agents.pop(agent_id, None)
            if agent is None:
                return False

            # Bugfix: compare user_id against None so user id 0 is cleaned up
            if agent.user_id is not None and agent.user_id in self._user_agents:
                self._user_agents[agent.user_id].discard(agent_id)

            if agent.conversation_id and agent.conversation_id in self._conversation_agents:
                self._conversation_agents[agent.conversation_id].discard(agent_id)

            return True

    def remove_user_agents(self, user_id: int) -> int:
        """
        Remove all agents for a user.

        Args:
            user_id: User ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._user_agents.pop(user_id, set())
            count = 0
            for agent_id in agent_ids:
                if agent_id in self._agents:
                    del self._agents[agent_id]
                    count += 1
            return count

    def remove_conversation_agents(self, conversation_id: str) -> int:
        """
        Remove all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.pop(conversation_id, set())
            count = 0
            for agent_id in agent_ids:
                if agent_id in self._agents:
                    del self._agents[agent_id]
                    count += 1
            return count

    def clear(self) -> None:
        """Clear all agents from the registry."""
        with self._registry_lock:
            self._agents.clear()
            self._user_agents.clear()
            self._conversation_agents.clear()

    @property
    def agent_count(self) -> int:
        """Total number of registered agents."""
        with self._registry_lock:
            return len(self._agents)

    def get_stats(self) -> Dict:
        """Return counts of agents grouped by status and type."""
        with self._registry_lock:
            status_counts = {}
            type_counts = {}
            for agent in self._agents.values():
                status_counts[agent.status.value] = status_counts.get(agent.status.value, 0) + 1
                type_counts[agent.config.agent_type.value] = type_counts.get(agent.config.agent_type.value, 0) + 1

            return {
                "total": len(self._agents),
                "by_status": status_counts,
                "by_type": type_counts,
            }


# Global registry instance
registry = AgentRegistry()
class SupervisorAgent:
    """
    Supervisor Agent.

    Responsible for:
    - Task decomposition using the LLM
    - Generating a DAG (task graph) of subtasks
    - Integrating results produced by worker agents
    """

    # System prompt for task decomposition
    DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.

Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration

Output format for task decomposition:
{
    "task_name": "Overall task name",
    "task_description": "Description of what needs to be accomplished",
    "nodes": [
        {
            "id": "task_001",
            "name": "Task name",
            "description": "What this task does",
            "task_type": "code|shell|file|llm|generic",
            "task_data": {...},
            "dependencies": []
        }
    ]
}

Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        max_subtasks: int = 10
    ):
        """
        Initialize Supervisor Agent.

        Args:
            agent: Agent instance (should be SUPERVISOR type)
            llm_client: LLM client instance; defaults to the shared client
            max_subtasks: Maximum number of subtasks to generate
        """
        self.agent = agent
        if llm_client is None:
            # BUG FIX: the parameter shadows the module-level import, so the
            # original `llm_client or llm_client` could never fall back to the
            # shared client and stored None. Import the default explicitly.
            from luxx.services.llm_client import llm_client as _default_client
            llm_client = _default_client
        self.llm_client = llm_client
        self.max_subtasks = max_subtasks

        # Ensure the agent has a supervisor system prompt.
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def decompose_task(
        self,
        task: str,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> DAG:
        """
        Decompose a task into subtasks using the LLM.

        Args:
            task: User's task description
            context: Execution context (workspace, user info, etc.)
            progress_callback: Optional callback for progress updates

        Returns:
            DAG representing the task decomposition

        Raises:
            Exception: re-raises any LLM/parse failure after marking the
                agent FAILED.
        """
        self.agent.status = AgentStatus.PLANNING

        if progress_callback:
            progress_callback(0.1, "Analyzing task...")

        # Build messages for the LLM from the agent's rolling context.
        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"Decompose this task into subtasks:\n{task}"
        })

        if progress_callback:
            progress_callback(0.2, "Calling LLM for task decomposition...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            if progress_callback:
                progress_callback(0.5, "Processing decomposition...")

            # Parse the LLM response into a DAG.
            dag = self._parse_dag_from_response(response.content, task)

            # Record the assistant turn so later calls see the plan.
            self.agent.add_message("assistant", response.content)

            if progress_callback:
                progress_callback(0.9, "Task decomposition complete")

            self.agent.status = AgentStatus.IDLE
            return dag

        except Exception as e:
            logger.error(f"Task decomposition failed: {e}")
            self.agent.status = AgentStatus.FAILED
            raise

    def _parse_dag_from_response(self, content: str, original_task: str) -> DAG:
        """
        Parse an LLM response into a DAG structure.

        Falls back to a single-node "llm" DAG when no JSON plan can be
        extracted from the response.

        Args:
            content: LLM response content
            original_task: Original task description

        Returns:
            DAG instance
        """
        dag_data = self._extract_json(content)

        if not dag_data:
            # Fallback: single node that executes the whole task via the LLM.
            logger.warning("Could not parse DAG from LLM response, creating simple DAG")
            dag = DAG(
                id=f"dag_{self.agent.id}",
                name=original_task[:50],
                description=original_task
            )
            node = TaskNode(
                id="task_001",
                name="Execute Task",
                description=original_task,
                task_type="llm",
                task_data={"prompt": original_task}
            )
            dag.add_node(node)
            return dag

        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=dag_data.get("task_name", original_task[:50]),
            description=dag_data.get("task_description", original_task)
        )

        # First pass: create all nodes.
        for node_data in dag_data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {})
            )
            dag.add_node(node)

        # Second pass: wire edges, skipping dependencies on unknown node IDs.
        for node_data in dag_data.get("nodes", []):
            node_id = node_data["id"]
            for dep_id in node_data.get("dependencies", []):
                if dep_id in dag.nodes:
                    dag.add_edge(dep_id, node_id)

        return dag

    def _extract_json(self, content: str) -> Optional[Dict]:
        """
        Extract the first JSON object from an LLM response.

        Tries a fenced ```json block first, then any raw {...} span.

        Args:
            content: Raw LLM response

        Returns:
            Parsed JSON dict or None
        """
        import re

        # Fenced code block: ```json { ... } ```
        json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Bare JSON object anywhere in the text.
        json_match = re.search(r"\{.*\}", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        return None

    def integrate_results(self, dag: DAG) -> Dict[str, Any]:
        """
        Integrate output data from all completed tasks in *dag*.

        Args:
            dag: DAG with completed tasks

        Returns:
            Summary dict with per-node results and completion counts.
        """
        self.agent.status = AgentStatus.EXECUTING

        # Collect output data from every successfully completed node.
        results = {}
        for node in dag.nodes.values():
            if node.status == TaskNodeStatus.COMPLETED and node.output_data:
                results[node.id] = node.output_data

        self.agent.accumulated_result = results
        self.agent.status = AgentStatus.COMPLETED

        return {
            "success": True,
            "dag_id": dag.id,
            "total_tasks": dag.total_count,
            "completed_tasks": dag.completed_count,
            "failed_tasks": dag.failed_count,
            "results": results
        }

    async def review_and_refine(
        self,
        dag: DAG,
        task: str,
        progress_callback: Optional[Callable] = None
    ) -> Optional[DAG]:
        """
        Review DAG execution and ask the LLM for a refined plan if needed.

        Args:
            dag: Current DAG state
            task: Original task
            progress_callback: Progress callback (currently unused here)

        Returns:
            Refined DAG, or None when no refinement is needed or parsing fails.
        """
        if dag.is_success:
            return None  # Nothing to refine.

        failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]
        if not failed_nodes:
            return None

        # Summarize failures and successes for the refinement prompt.
        context = {
            "original_task": task,
            "failed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "error": n.result.error if n.result else "Unknown error"
                }
                for n in failed_nodes
            ],
            "completed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "output": n.output_data
                }
                for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
            ]
        }

        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"""Review the task execution and suggest refinements:

Task: {task}

Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}

Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}

If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical

Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
        })

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature
            )

            refined_dag = self._parse_dag_from_response(response.content, task)

            # Only return a non-empty refinement.
            if refined_dag and refined_dag.nodes:
                return refined_dag

        except Exception as e:
            logger.error(f"DAG refinement failed: {e}")

        return None
+ """ + + # System prompt for worker tasks + DEFAULT_SYSTEM_PROMPT = """You are a Worker Agent that executes specific tasks efficiently. + +Your responsibilities: +1. Execute tasks assigned to you by the Supervisor +2. Use appropriate tools when needed +3. Report results clearly with structured output_data for dependent tasks +4. Be concise and focused on the task at hand + +Output format: +- Provide clear, structured results +- Include output_data for any data that dependent tasks might need +- If a tool fails, explain the error clearly +""" + + def __init__( + self, + agent: Agent, + llm_client=None, + tool_executor=None + ): + """ + Initialize Worker Agent + + Args: + agent: Agent instance (should be WORKER type) + llm_client: LLM client instance + tool_executor: Tool executor instance + """ + self.agent = agent + self.llm_client = llm_client or llm_client + self.tool_executor = tool_executor + + # Ensure agent has worker system prompt + if not self.agent.config.system_prompt: + self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT + + async def execute_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + parent_outputs: Dict[str, Dict[str, Any]] = None, + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """ + Execute a task node + + Args: + task_node: Task node to execute + context: Execution context (workspace, user info, etc.) 
+ parent_outputs: Output data from parent tasks (dependency results) + progress_callback: Optional callback for progress updates + + Returns: + TaskResult with execution outcome + """ + self.agent.status = AgentStatus.EXECUTING + self.agent.current_task_id = task_node.id + + start_time = time.time() + + if progress_callback: + progress_callback(0.0, f"Starting task: {task_node.name}") + + try: + # Merge parent outputs into context + execution_context = self._prepare_context(context, parent_outputs) + + # Execute based on task type + if task_node.task_type == "llm": + result = await self._execute_llm_task(task_node, execution_context, progress_callback) + elif task_node.task_type == "code": + result = await self._execute_code_task(task_node, execution_context, progress_callback) + elif task_node.task_type == "shell": + result = await self._execute_shell_task(task_node, execution_context, progress_callback) + elif task_node.task_type == "file": + result = await self._execute_file_task(task_node, execution_context, progress_callback) + else: + result = await self._execute_generic_task(task_node, execution_context, progress_callback) + + execution_time = time.time() - start_time + result.execution_time = execution_time + + if progress_callback: + progress_callback(1.0, f"Task complete: {task_node.name}") + + self.agent.status = AgentStatus.IDLE + return result + + except Exception as e: + logger.error(f"Task execution failed: {e}") + execution_time = time.time() - start_time + self.agent.status = AgentStatus.FAILED + return TaskResult.fail(error=str(e)) + + def _prepare_context( + self, + context: Dict[str, Any], + parent_outputs: Dict[str, Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Prepare execution context by merging parent outputs + + Args: + context: Base context + parent_outputs: Output from parent tasks + + Returns: + Merged context + """ + execution_context = context.copy() + + if parent_outputs: + # Merge parent outputs into context + merged = {} + 
for parent_id, outputs in parent_outputs.items(): + merged.update(outputs) + execution_context["parent_outputs"] = parent_outputs + execution_context["merged_data"] = merged + + # Add user permission level + if "user_permission_level" not in execution_context: + execution_context["user_permission_level"] = self.agent.effective_permission.value + + return execution_context + + async def _execute_llm_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """Execute LLM reasoning task""" + task_data = task_node.task_data + + # Build prompt + prompt = task_data.get("prompt", task_node.description) + system_prompt = task_data.get("system", self.agent.config.system_prompt) + + messages = [{"role": "system", "content": system_prompt}] + + # Add parent data if available + if "merged_data" in context: + merged = context["merged_data"] + context_info = "\n".join([f"{k}: {v}" for k, v in merged.items()]) + messages.append({ + "role": "system", + "content": f"Context from dependent tasks:\n{context_info}" + }) + + messages.append({"role": "user", "content": prompt}) + + if progress_callback: + progress_callback(0.3, "Calling LLM...") + + try: + response = await self.llm_client.sync_call( + model=self.agent.config.model, + messages=messages, + temperature=self.agent.config.temperature, + max_tokens=self.agent.config.max_tokens + ) + + return TaskResult.ok( + data=response.content, + output_data=task_node.task_data.get("output_template", {}).copy() + ) + + except Exception as e: + return TaskResult.fail(error=str(e)) + + async def _execute_code_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """Execute code generation/writing task""" + task_data = task_node.task_data + + # Build prompt for code generation + prompt = task_data.get("prompt", task_node.description) + language = task_data.get("language", "python") + 
requirements = task_data.get("requirements", "") + + messages = [ + {"role": "system", "content": f"You are a {language} programmer. Write clean, efficient code."}, + {"role": "user", "content": f"Task: {prompt}\n\nRequirements: {requirements}"} + ] + + if progress_callback: + progress_callback(0.3, "Generating code...") + + try: + response = await self.llm_client.sync_call( + model=self.agent.config.model, + messages=messages, + temperature=0.2, # Lower temp for code + max_tokens=4096 + ) + + return TaskResult.ok( + data=response.content, + output_data={ + "code": response.content, + "language": language + } + ) + + except Exception as e: + return TaskResult.fail(error=str(e)) + + async def _execute_shell_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """Execute shell command task""" + task_data = task_node.task_data + + command = task_data.get("command") + if not command: + return TaskResult.fail(error="No command specified") + + if progress_callback: + progress_callback(0.3, f"Executing: {command[:50]}...") + + # Build tool context + tool_ctx = ToolContext( + workspace=context.get("workspace"), + user_id=context.get("user_id"), + username=context.get("username"), + extra={"user_permission_level": context.get("user_permission_level", 1)} + ) + + try: + # Execute shell command via tool + result = tool_registry.execute( + "shell_exec", + {"command": command}, + context=tool_ctx + ) + + if result.get("success"): + return TaskResult.ok( + data=result.get("data", {}).get("output", ""), + output_data={"output": result.get("data", {}).get("output", "")} + ) + else: + return TaskResult.fail(error=result.get("error", "Shell execution failed")) + + except Exception as e: + return TaskResult.fail(error=str(e)) + + async def _execute_file_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """Execute file 
operation task""" + task_data = task_node.task_data + + operation = task_data.get("operation") + file_path = task_data.get("path") + content = task_data.get("content", "") + + if not operation or not file_path: + return TaskResult.fail(error="Missing operation or path") + + if progress_callback: + progress_callback(0.3, f"File operation: {operation} {file_path}") + + tool_ctx = ToolContext( + workspace=context.get("workspace"), + user_id=context.get("user_id"), + username=context.get("username"), + extra={"user_permission_level": context.get("user_permission_level", 1)} + ) + + try: + tool_name = f"file_{operation}" + result = tool_registry.execute( + tool_name, + {"path": file_path, "content": content}, + context=tool_ctx + ) + + if result.get("success"): + return TaskResult.ok( + data=result.get("data"), + output_data={"path": file_path, "operation": operation} + ) + else: + return TaskResult.fail(error=result.get("error", "File operation failed")) + + except Exception as e: + return TaskResult.fail(error=str(e)) + + async def _execute_generic_task( + self, + task_node: TaskNode, + context: Dict[str, Any], + progress_callback: Optional[Callable] = None + ) -> TaskResult: + """Execute generic task using LLM with tools""" + task_data = task_node.task_data + + # Build prompt + prompt = task_data.get("prompt", task_node.description) + tools = task_data.get("tools", []) + + messages = [ + {"role": "system", "content": self.agent.config.system_prompt}, + {"role": "user", "content": prompt} + ] + + # Get tool definitions if specified + tool_defs = None + if tools: + tool_defs = [tool_registry.get(t).to_openai_format() for t in tools if tool_registry.get(t)] + + if progress_callback: + progress_callback(0.2, "Processing task...") + + max_iterations = 5 + iteration = 0 + + while iteration < max_iterations: + try: + response = await self.llm_client.sync_call( + model=self.agent.config.model, + messages=messages, + tools=tool_defs, + 
temperature=self.agent.config.temperature, + max_tokens=self.agent.config.max_tokens + ) + + # Add assistant response + messages.append({"role": "assistant", "content": response.content}) + + # Check for tool calls + if response.tool_calls: + if progress_callback: + progress_callback(0.5, f"Executing {len(response.tool_calls)} tools...") + + # Execute tools + tool_results = self.tool_executor.process_tool_calls( + response.tool_calls, + context + ) + + # Add tool results + for tr in tool_results: + messages.append({ + "role": "tool", + "tool_call_id": tr["tool_call_id"], + "content": tr["content"] + }) + + if progress_callback: + progress_callback(0.8, "Tools executed") + else: + # No tool calls, task complete + return TaskResult.ok( + data=response.content, + output_data=task_data.get("output_template", {}) + ) + + except Exception as e: + return TaskResult.fail(error=str(e)) + + iteration += 1 + + return TaskResult.fail(error="Max iterations exceeded") diff --git a/luxx/routes/__init__.py b/luxx/routes/__init__.py index c6519c8..6464f6d 100644 --- a/luxx/routes/__init__.py +++ b/luxx/routes/__init__.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from luxx.routes import auth, conversations, messages, tools, providers +from luxx.routes.agents_ws import router as agents_ws_router api_router = APIRouter() @@ -12,3 +13,4 @@ api_router.include_router(conversations.router) api_router.include_router(messages.router) api_router.include_router(tools.router) api_router.include_router(providers.router) +api_router.include_router(agents_ws_router) diff --git a/luxx/routes/agents_ws.py b/luxx/routes/agents_ws.py new file mode 100644 index 0000000..b036287 --- /dev/null +++ b/luxx/routes/agents_ws.py @@ -0,0 +1,385 @@ +"""WebSocket routes for agent real-time communication""" +import asyncio +import json +import logging +import threading +from typing import Any, Dict, Optional, Set +from datetime import datetime +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, 
logger = logging.getLogger(__name__)

router = APIRouter()


class ConnectionManager:
    """
    WebSocket connection manager.

    Tracks WebSocket connections per task_id for real-time agent progress
    updates, with a periodic heartbeat and fan-out broadcasting.
    """

    def __init__(self):
        # task_id -> set of websocket connections subscribed to it
        self._connections: Dict[str, Set[WebSocket]] = {}
        # reverse mapping: websocket -> task_id
        self._ws_to_task: Dict[WebSocket, str] = {}
        # per-connection heartbeat asyncio tasks
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # guards all three dicts
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = 30

    async def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Accept and register a WebSocket connection.

        Args:
            websocket: WebSocket connection
            task_id: Task ID to subscribe to
        """
        # BUG FIX: WebSocket.accept() is a coroutine; it was previously called
        # without await, so the handshake never completed and every later
        # send failed. connect() is therefore async now.
        await websocket.accept()

        with self._lock:
            self._connections.setdefault(task_id, set()).add(websocket)
            self._ws_to_task[websocket] = task_id

        logger.info(f"WebSocket connected for task {task_id}")

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and stop its heartbeat.

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id and task_id in self._connections:
                self._connections[task_id].discard(websocket)
                if not self._connections[task_id]:
                    del self._connections[task_id]

            # Cancel the heartbeat task so it does not leak.
            if websocket in self._heartbeat_tasks:
                self._heartbeat_tasks[websocket].cancel()
                del self._heartbeat_tasks[websocket]

        if task_id:
            logger.info(f"WebSocket disconnected for task {task_id}")

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send a message to every connection subscribed to *task_id*.

        Connections that fail to receive are disconnected.

        Args:
            task_id: Task ID
            message: Message to send
        """
        with self._lock:
            # Copy so we can send outside the lock.
            connections = self._connections.get(task_id, set()).copy()

        if not connections:
            return

        dead_connections = set()

        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket: {e}")
                dead_connections.add(websocket)

        # Prune connections that errored.
        for ws in dead_connections:
            self.disconnect(ws)

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast a message to every connection, pruning dead ones.

        Args:
            message: Message to broadcast
        """
        with self._lock:
            all_connections = list(self._ws_to_task.keys())

        dead_connections = set()
        for websocket in all_connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to broadcast: {e}")
                dead_connections.add(websocket)

        # FIX: broadcast previously kept dead connections around forever;
        # clean them up the same way send_to_task does.
        for ws in dead_connections:
            self.disconnect(ws)

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send a message to one specific connection (best-effort).

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"Failed to send personal message: {e}")

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a periodic heartbeat for a connection.

        Must be called from within a running event loop (uses
        asyncio.create_task).

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop():
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval
                    })
                except Exception:
                    break  # Connection gone; loop ends, disconnect() cancels later.

        task = asyncio.create_task(heartbeat_loop())
        with self._lock:
            self._heartbeat_tasks[websocket] = task

    @property
    def connection_count(self) -> int:
        """Total number of live connections."""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Number of connections subscribed to *task_id*."""
        with self._lock:
            return len(self._connections.get(task_id, set()))


# Global connection manager
connection_manager = ConnectionManager()

# Global scheduler pool
scheduler_pool = SchedulerPool(max_concurrent=10)

# Global tool executor
tool_executor = ToolExecutor()


class ProgressEmitter:
    """
    Progress emitter that forwards scheduler callbacks over WebSocket.

    NOTE(review): __call__ schedules sends with asyncio.create_task, which
    requires a running event loop in the calling thread — confirm the
    scheduler invokes this callback from the loop thread.
    """

    def __init__(self, task_id: str, connection_manager: ConnectionManager):
        self.task_id = task_id
        self.connection_manager = connection_manager

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """Progress callback; node_id == "dag" means DAG-level progress."""
        if node_id == "dag":
            asyncio.create_task(self.connection_manager.send_to_task(
                self.task_id,
                {
                    "type": "dag_progress",
                    "data": {
                        "progress": progress,
                        "message": message
                    }
                }
            ))
        else:
            asyncio.create_task(self.connection_manager.send_to_task(
                self.task_id,
                {
                    "type": "node_progress",
                    "data": {
                        "node_id": node_id,
                        "progress": progress,
                        "message": message
                    }
                }
            ))


@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates.

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_start, node_start,
      node_progress, node_complete, node_error, dag_complete, pong
    """
    # Accept connection (connect() is async — it awaits the handshake).
    await connection_manager.connect(websocket, task_id)

    await connection_manager.send_personal(websocket, {
        "type": "subscribed",
        "task_id": task_id
    })

    connection_manager.start_heartbeat(websocket)

    # If the task is already running, send its current DAG state up front.
    scheduler = scheduler_pool.get(task_id)
    if scheduler:
        await connection_manager.send_personal(websocket, {
            "type": "dag_status",
            "data": scheduler.dag.to_dict()
        })

    try:
        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON"
                })
                continue

            msg_type = message.get("type")

            if msg_type == "subscribe":
                # Already subscribed on connect; just confirm.
                await connection_manager.send_personal(websocket, {
                    "type": "subscribed",
                    "task_id": task_id
                })

            elif msg_type == "get_status":
                # Re-fetch: the scheduler may have appeared since connect.
                scheduler = scheduler_pool.get(task_id)
                if scheduler:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": scheduler.dag.to_dict()
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": None
                    })

            elif msg_type == "ping":
                await connection_manager.send_personal(websocket, {
                    "type": "pong"
                })

            elif msg_type == "cancel_task":
                if scheduler_pool.cancel(task_id):
                    await connection_manager.send_personal(websocket, {
                        "type": "task_cancelled",
                        "task_id": task_id
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "error",
                        "message": "Task not found or already completed"
                    })

            else:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })

    except WebSocketDisconnect:
        connection_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        connection_manager.disconnect(websocket)


# Helper functions to emit progress from the scheduler

def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """Create a progress emitter bound to *task_id*."""
    return ProgressEmitter(task_id, connection_manager)


async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """Emit DAG start event with the full graph."""
    await connection_manager.send_to_task(task_id, {
        "type": "dag_start",
        "data": {
            "graph": dag.to_dict()
        }
    })


async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """Emit node start event."""
    await connection_manager.send_to_task(task_id, {
        "type": "node_start",
        "data": {
            "node_id": node.id,
            "name": node.name,
            "status": node.status.value
        }
    })


async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """Emit node complete event with result and output data."""
    await connection_manager.send_to_task(task_id, {
        "type": "node_complete",
        "data": {
            "node_id": node.id,
            "result": node.result.to_dict() if node.result else None,
            "output_data": node.output_data
        }
    })


async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """Emit node error event."""
    await connection_manager.send_to_task(task_id, {
        "type": "node_error",
        "data": {
            "node_id": node.id,
            "error": error
        }
    })


async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """Emit DAG complete event with the aggregated results."""
    await connection_manager.send_to_task(task_id, {
        "type": "dag_complete",
        "data": {
            "success": success,
            "results": results
        }
    })