# Luxx/luxx/agents/dag_scheduler.py
"""DAG Scheduler - orchestrates parallel task execution"""
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor
import threading
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.tools.executor import ToolExecutor
logger = logging.getLogger(__name__)
class DAGScheduler:
    """
    DAG Scheduler

    Orchestrates parallel execution of tasks based on DAG structure.

    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: Optional[ToolExecutor] = None
    ):
        """
        Initialize DAG Scheduler

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a default ToolExecutor is
                created when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()
        # Execution state
        self._running_nodes: Set[str] = set()
        self._completed_nodes: Set[str] = set()
        self._node_results: Dict[str, TaskResult] = {}
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}
        # Strong references to in-flight asyncio tasks. The event loop only
        # keeps weak references to tasks, so without this set a running task
        # could be garbage-collected mid-flight; it also lets cancellation
        # reach the underlying coroutines.
        self._tasks: Set[asyncio.Task] = set()
        # Control flags
        self._cancelled = False
        self._cancel_lock = threading.Lock()
        # Progress callback
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates"""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Emit progress update via the registered callback, if any"""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description

        Returns:
            Execution results (see _build_result / _build_cancelled_result)
        """
        # Reset per-run state so the scheduler can be reused.
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()
        self._tasks.clear()
        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")
        # Initialize nodes - mark root nodes as ready
        for node in self.dag.root_nodes:
            node.mark_ready()
        # Main execution loop
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()
            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()
            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break
            # BUG FIX: launch at most (max_workers - running) nodes. The
            # previous per-node check did not count nodes launched earlier in
            # the same pass, so one pass could exceed max_workers.
            free_slots = max(0, self.max_workers - len(self._running_nodes))
            for node in ready_nodes[:free_slots]:
                # Reserve the slot synchronously so the next loop pass cannot
                # launch the same node again before its task starts running.
                self._running_nodes.add(node.id)
                runner = asyncio.create_task(self._execute_node(node, context))
                self._tasks.add(runner)
                runner.add_done_callback(self._tasks.discard)
            # Small delay to prevent busy waiting (also yields to the tasks)
            await asyncio.sleep(0.1)
        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")
        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """True when every node has reached a terminal status"""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Check if execution was cancelled (thread-safe read)"""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cancellation; the execute() loop honors it on its next pass"""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Get nodes that are READY, not already running, and whose deps are met"""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node

        Args:
            node: Node to execute
            context: Execution context
        """
        if self._is_cancelled():
            node.mark_cancelled()
            # Release the slot reserved by execute() at launch time.
            self._running_nodes.discard(node.id)
            return
        # Mark node as running (idempotent re-add: execute() pre-reserved it)
        self._running_nodes.add(node.id)
        node.mark_running(self.supervisor.agent.id)
        self._emit_progress(node.id, 0.0, f"Starting: {node.name}")
        # Collect parent outputs for this node
        parent_outputs = {}
        for dep_id in node.dependencies:
            if dep_id in self._node_results:
                parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}
        # Create worker for this task
        worker = self.worker_factory()

        # Per-node progress callback forwards to both the node and the
        # scheduler-level callback.
        def node_progress(progress: float, message: str = ""):
            node.update_progress(progress, message)
            self._emit_progress(node.id, progress, message)

        try:
            # Execute task
            result = await worker.execute_task(
                node,
                context,
                parent_outputs=parent_outputs,
                progress_callback=node_progress
            )
            # Store result
            self._node_results[node.id] = result
            if result.success:
                node.mark_completed(result)
                self._completed_nodes.add(node.id)
                self._emit_progress(node.id, 1.0, f"Completed: {node.name}")
                # Check if any blocked nodes can now run
                self._unblock_nodes()
            else:
                node.mark_failed(result.error or "Unknown error")
                self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")
                # Block dependent nodes
                self._block_dependent_nodes(node.id)
        except Exception as e:
            # Note: asyncio.CancelledError is a BaseException and is NOT
            # swallowed here, so task cancellation propagates normally.
            logger.error(f"Node {node.id} execution error: {e}")
            node.mark_failed(str(e))
            self._node_results[node.id] = TaskResult.fail(error=str(e))
            self._block_dependent_nodes(node.id)
        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Unblock nodes whose dependencies are now satisfied"""
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block nodes that depend on a failed node"""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Cancel all in-flight tasks and mark their nodes cancelled"""
        # BUG FIX: cancel the asyncio tasks themselves, not just the node
        # status - previously the worker coroutines kept running in the
        # background after cancellation.
        for runner in list(self._tasks):
            runner.cancel()
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)
        for node_id in list(self._running_nodes):
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()
        self._tasks.clear()

    def _build_result(self) -> Dict[str, Any]:
        """Build final result summary for a completed run"""
        return {
            "success": self.dag.is_success,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build result summary for a cancelled run"""
        return {
            "success": False,
            "cancelled": True,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }
class SchedulerPool:
    """
    Pool of DAG schedulers for managing multiple concurrent DAG executions

    All registry operations are guarded by a single lock so the pool can be
    shared across threads.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create and register a new scheduler

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            Created scheduler

        Raises:
            RuntimeError: when the pool already holds max_concurrent schedulers.
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")
            new_scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )
            self._schedulers[task_id] = new_scheduler
        return new_scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Look up a scheduler by task ID; None when not registered"""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Deregister a scheduler; True if one was actually removed"""
        with self._lock:
            return self._schedulers.pop(task_id, None) is not None

    def cancel(self, task_id: str) -> bool:
        """Request cancellation of a scheduler; True if it exists"""
        target = self.get(task_id)
        if target is None:
            return False
        target.cancel()
        return True

    @property
    def active_count(self) -> int:
        """Number of active schedulers"""
        with self._lock:
            return len(self._schedulers)