"""DAG Scheduler - orchestrates parallel task execution"""

import asyncio
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Set

from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.tools.executor import ToolExecutor

# Module-level logger, one per module per the standard logging convention.
logger = logging.getLogger(__name__)

class DAGScheduler:
    """
    DAG Scheduler

    Orchestrates parallel execution of tasks based on DAG structure.

    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support

    NOTE(review): per-run state is reset at the start of each `execute()`
    call, so a single instance must not run two DAG executions concurrently.
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: Optional[ToolExecutor] = None
    ):
        """
        Initialize DAG Scheduler

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a fresh ToolExecutor is
                created when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()

        # Execution state
        self._running_nodes: Set[str] = set()           # node ids reserved/executing
        self._completed_nodes: Set[str] = set()         # node ids that succeeded
        self._node_results: Dict[str, TaskResult] = {}  # per-node results (success or failure)
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}

        # Strong references to in-flight asyncio tasks.  asyncio.create_task()
        # keeps only a weak reference, so without this set a running task
        # could be garbage-collected before it finishes.
        self._tasks: Set[asyncio.Task] = set()

        # Control flags.  cancel() may be invoked from another thread, hence
        # the lock around the flag.
        self._cancelled = False
        self._cancel_lock = threading.Lock()

        # Progress callback
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates"""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Emit progress update via the registered callback, if any"""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description (currently unused here; kept for
                interface stability — workers take their task from the node)

        Returns:
            Execution results
        """
        # Reset all per-run state so the scheduler instance can be reused.
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()
        self._tasks.clear()

        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")

        # Initialize nodes - mark root nodes as ready
        for node in self.dag.root_nodes:
            node.mark_ready()

        # Main execution loop
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()

            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()

            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break

            # BUG FIX: the previous code tested len(self._running_nodes)
            # inside a list comprehension, but _running_nodes only grew once
            # a task actually started running, so every ready node was
            # launched at once regardless of max_workers.  Compute the free
            # capacity up front and launch at most that many nodes.
            capacity = self.max_workers - len(self._running_nodes)
            for node in ready_nodes[:max(capacity, 0)]:
                # Reserve the slot synchronously so the next loop iteration
                # cannot schedule the same node twice before its task runs.
                self._running_nodes.add(node.id)
                task_handle = asyncio.create_task(self._execute_node(node, context))
                # Hold a strong reference until the task completes.
                self._tasks.add(task_handle)
                task_handle.add_done_callback(self._tasks.discard)

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)

        # Let any still-running tasks settle before reading final DAG state.
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)

        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")

        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """Check if execution is complete (every node reached a terminal state)"""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Check if execution was cancelled (thread-safe read)"""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Cancel execution.  Safe to call from any thread."""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Get nodes that are ready to execute and not already scheduled"""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node

        Args:
            node: Node to execute
            context: Execution context
        """
        try:
            if self._is_cancelled():
                node.mark_cancelled()
                return

            # Mark node as running (the id may already be reserved by execute()).
            self._running_nodes.add(node.id)
            node.mark_running(self.supervisor.agent.id)

            self._emit_progress(node.id, 0.0, f"Starting: {node.name}")

            # Collect parent outputs for this node
            parent_outputs = {}
            for dep_id in node.dependencies:
                if dep_id in self._node_results:
                    parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}

            # Create worker for this task
            worker = self.worker_factory()

            # Per-node progress callback: mirrors progress onto the node and
            # forwards it to the scheduler-level callback.
            def node_progress(progress: float, message: str = ""):
                node.update_progress(progress, message)
                self._emit_progress(node.id, progress, message)

            try:
                # Execute task
                result = await worker.execute_task(
                    node,
                    context,
                    parent_outputs=parent_outputs,
                    progress_callback=node_progress
                )

                # Store result
                self._node_results[node.id] = result

                if result.success:
                    node.mark_completed(result)
                    self._completed_nodes.add(node.id)
                    self._emit_progress(node.id, 1.0, f"Completed: {node.name}")

                    # Check if any blocked nodes can now run
                    self._unblock_nodes()
                else:
                    node.mark_failed(result.error or "Unknown error")
                    self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")

                    # Block dependent nodes
                    self._block_dependent_nodes(node.id)

            except Exception as e:
                logger.error(f"Node {node.id} execution error: {e}")
                node.mark_failed(str(e))
                self._node_results[node.id] = TaskResult.fail(error=str(e))
                self._block_dependent_nodes(node.id)

        finally:
            # BUG FIX: release the slot on EVERY exit path.  The old code
            # only discarded after the try block, so the early-return
            # cancellation path would have leaked the reservation.
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Unblock nodes whose dependencies are now satisfied"""
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block nodes that depend on a failed node"""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Cancel all running nodes and the asyncio tasks driving them"""
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

        # Also cancel the underlying asyncio tasks and wait for them to
        # settle, so no worker keeps mutating state after execute() returns.
        for pending in list(self._tasks):
            pending.cancel()
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()

    def _result_payload(self) -> Dict[str, Any]:
        """Common result fields shared by normal and cancelled results"""
        return {
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_result(self) -> Dict[str, Any]:
        """Build final result"""
        return {"success": self.dag.is_success, **self._result_payload()}

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build cancelled result"""
        return {"success": False, "cancelled": True, **self._result_payload()}

class SchedulerPool:
    """
    Pool of DAG schedulers for managing multiple concurrent DAG executions

    All access to the internal registry is serialized through a lock so the
    pool may be used from multiple threads.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create and register a new scheduler

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            Created scheduler

        Raises:
            RuntimeError: if the pool is already at max_concurrent capacity.
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")

            new_scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )
            self._schedulers[task_id] = new_scheduler
            return new_scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Get scheduler by task ID (None when unknown)"""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Remove scheduler; returns True when an entry was actually removed"""
        with self._lock:
            return self._schedulers.pop(task_id, None) is not None

    def cancel(self, task_id: str) -> bool:
        """Cancel a scheduler; returns False when the task id is unknown"""
        target = self.get(task_id)
        if target is None:
            return False
        target.cancel()
        return True

    @property
    def active_count(self) -> int:
        """Number of active schedulers"""
        with self._lock:
            return len(self._schedulers)