"""DAG Scheduler - orchestrates parallel task execution"""
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor
import threading

from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.tools.executor import ToolExecutor

logger = logging.getLogger(__name__)


class DAGScheduler:
    """
    DAG Scheduler

    Orchestrates parallel execution of tasks based on DAG structure.

    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: Optional[ToolExecutor] = None
    ):
        """
        Initialize DAG Scheduler

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a default one is created
                when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()

        # Execution state
        self._running_nodes: Set[str] = set()
        self._completed_nodes: Set[str] = set()
        self._node_results: Dict[str, TaskResult] = {}
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}
        # Strong references to in-flight asyncio tasks. The event loop keeps
        # only weak references to tasks, so without this set a fire-and-forget
        # task could be garbage-collected mid-flight (see asyncio.create_task
        # documentation).
        self._pending_tasks: Set["asyncio.Task"] = set()

        # Control flags
        self._cancelled = False
        self._cancel_lock = threading.Lock()

        # Progress callback
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates"""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Emit progress update to the registered callback, if any"""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description

        Returns:
            Execution results
        """
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()
        self._pending_tasks.clear()

        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")

        # Initialize nodes - mark root nodes as ready
        for node in self.dag.root_nodes:
            node.mark_ready()

        # Main execution loop
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()

            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()

            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break

            # Launch ready nodes, never exceeding max_workers.
            # The launched coroutine does not run until the event loop yields,
            # so the slot must be reserved synchronously here (by adding the
            # node id to _running_nodes) -- otherwise every ready node would
            # pass a "len(running) < max_workers" check and the concurrency
            # limit would not be enforced.
            free_slots = self.max_workers - len(self._running_nodes)
            for node in ready_nodes[:max(free_slots, 0)]:
                self._running_nodes.add(node.id)
                node_task = asyncio.create_task(self._execute_node(node, context))
                self._pending_tasks.add(node_task)
                node_task.add_done_callback(self._pending_tasks.discard)

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)

        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0,
                            "DAG execution complete" if success else "DAG execution failed")

        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """Check if every node has reached a terminal status"""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED,
                         TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Check if execution was cancelled (thread-safe)"""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cancellation; the execute() loop observes the flag"""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Get nodes that are ready to execute"""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node

        Args:
            node: Node to execute
            context: Execution context
        """
        if self._is_cancelled():
            node.mark_cancelled()
            # Release the slot reserved by execute() before create_task.
            self._running_nodes.discard(node.id)
            return

        # Mark node as running (idempotent if execute() already reserved it)
        self._running_nodes.add(node.id)
        node.mark_running(self.supervisor.agent.id)
        self._emit_progress(node.id, 0.0, f"Starting: {node.name}")

        # Collect parent outputs for this node
        parent_outputs = {}
        for dep_id in node.dependencies:
            if dep_id in self._node_results:
                parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}

        # Create worker for this task
        worker = self.worker_factory()

        # Define progress callback for this node
        def node_progress(progress: float, message: str = ""):
            node.update_progress(progress, message)
            self._emit_progress(node.id, progress, message)

        try:
            # Execute task
            result = await worker.execute_task(
                node,
                context,
                parent_outputs=parent_outputs,
                progress_callback=node_progress
            )

            # Store result
            self._node_results[node.id] = result

            if result.success:
                node.mark_completed(result)
                self._completed_nodes.add(node.id)
                self._emit_progress(node.id, 1.0, f"Completed: {node.name}")

                # Check if any blocked nodes can now run
                self._unblock_nodes()
            else:
                node.mark_failed(result.error or "Unknown error")
                self._emit_progress(node.id, 0.0,
                                    f"Failed: {node.name} - {result.error}")

                # Block dependent nodes
                self._block_dependent_nodes(node.id)

        except Exception as e:
            logger.error(f"Node {node.id} execution error: {e}")
            node.mark_failed(str(e))
            self._node_results[node.id] = TaskResult.fail(error=str(e))
            self._block_dependent_nodes(node.id)

        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Unblock nodes whose dependencies are now satisfied"""
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block nodes that depend on a failed node"""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Cancel in-flight asyncio tasks and mark their nodes cancelled"""
        # Cancel the tasks themselves so workers actually stop, then wait for
        # them to settle (return_exceptions swallows the CancelledErrors).
        for pending in list(self._pending_tasks):
            pending.cancel()
        if self._pending_tasks:
            await asyncio.gather(*self._pending_tasks, return_exceptions=True)
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

    def _build_summary(self) -> Dict[str, Any]:
        """Common summary fields shared by normal and cancelled results"""
        return {
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_result(self) -> Dict[str, Any]:
        """Build final result"""
        return {
            "success": self.dag.is_success,
            **self._build_summary()
        }

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build cancelled result"""
        return {
            "success": False,
            "cancelled": True,
            **self._build_summary()
        }


class SchedulerPool:
    """
    Pool of DAG schedulers for managing multiple concurrent DAG executions
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create and register a new scheduler

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            Created scheduler

        Raises:
            RuntimeError: If the pool is already at max_concurrent capacity
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")

            scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )

            self._schedulers[task_id] = scheduler
            return scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Get scheduler by task ID"""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Remove scheduler; returns True if one was registered"""
        with self._lock:
            if task_id in self._schedulers:
                del self._schedulers[task_id]
                return True
            return False

    def cancel(self, task_id: str) -> bool:
        """Cancel a scheduler; returns True if one was found"""
        scheduler = self.get(task_id)
        if scheduler:
            scheduler.cancel()
            return True
        return False

    @property
    def active_count(self) -> int:
        """Number of active schedulers"""
        with self._lock:
            return len(self._schedulers)