"""DAG (Directed Acyclic Graph) and TaskNode models""" from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set from enum import Enum from datetime import datetime import uuid class TaskNodeStatus(str, Enum): """Task node status enumeration""" PENDING = "pending" # Not yet started READY = "ready" # Dependencies satisfied, can start RUNNING = "running" # Currently executing COMPLETED = "completed" # Successfully completed FAILED = "failed" # Execution failed CANCELLED = "cancelled" # Cancelled by user BLOCKED = "blocked" # Blocked by failed dependency @dataclass class TaskResult: """Task execution result""" success: bool data: Any = None error: Optional[str] = None output_data: Optional[Dict[str, Any]] = None # Structured output for supervisor execution_time: float = 0.0 # seconds def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "success": self.success, "data": self.data, "error": self.error, "output_data": self.output_data, "execution_time": self.execution_time, } @classmethod def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None, execution_time: float = 0.0) -> "TaskResult": """Create success result""" return cls(success=True, data=data, output_data=output_data, execution_time=execution_time) @classmethod def fail(cls, error: str, data: Any = None) -> "TaskResult": """Create failure result""" return cls(success=False, error=error, data=data) @dataclass class TaskNode: """ Task node in the DAG Represents a single executable task within the agent workflow. """ id: str name: str description: str = "" # Task definition task_type: str = "generic" # e.g., "code", "shell", "file", "llm" task_data: Dict[str, Any] = field(default_factory=dict) # Task-specific parameters # Dependencies (node IDs that must complete before this node) dependencies: List[str] = field(default_factory=list) # Status tracking status: TaskNodeStatus = TaskNodeStatus.PENDING result: Optional[TaskResult] = None # Progress tracking progress: float = 0.0 # 0.0 - 1.0 progress_message: str = "" # Execution info assigned_agent_id: Optional[str] = None started_at: Optional[datetime] = None completed_at: Optional[datetime] = None # Output data (key-value pairs for dependent tasks) output_data: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): """Post-initialization""" if not self.id: self.id = str(uuid.uuid4())[:8] @property def is_root(self) -> bool: """Check if this is a root node (no dependencies)""" return len(self.dependencies) == 0 @property def is_leaf(self) -> bool: """Check if this is a leaf node (no dependents)""" return False # Will be set by DAG @property def execution_time(self) -> float: """Calculate execution time in seconds""" if self.started_at and self.completed_at: return (self.completed_at - self.started_at).total_seconds() return 0.0 def mark_ready(self) -> None: """Mark node as ready to execute""" self.status = TaskNodeStatus.READY def mark_running(self, agent_id: str) -> None: """Mark node as running""" self.status = TaskNodeStatus.RUNNING self.assigned_agent_id = agent_id self.started_at = datetime.utcnow() def mark_completed(self, result: TaskResult) -> None: """Mark node as completed""" self.status = TaskNodeStatus.COMPLETED self.result = result self.completed_at = datetime.utcnow() self.progress = 1.0 if result.output_data: self.output_data = result.output_data def mark_failed(self, error: str) -> None: """Mark node as failed""" self.status = TaskNodeStatus.FAILED self.result = TaskResult.fail(error) self.completed_at = datetime.utcnow() def mark_cancelled(self) -> None: """Mark node as cancelled""" self.status = TaskNodeStatus.CANCELLED self.completed_at = datetime.utcnow() def update_progress(self, progress: float, message: str = "") -> None: """Update execution progress""" self.progress = max(0.0, min(1.0, progress)) if message: self.progress_message = message def can_execute(self, completed_nodes: Set[str]) -> bool: """ Check if this node can execute (all dependencies completed) Args: completed_nodes: Set of completed node IDs """ if self.status != TaskNodeStatus.READY: return False return all(dep_id in completed_nodes for dep_id in self.dependencies) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "id": self.id, "name": self.name, "description": self.description, "task_type": self.task_type, "task_data": self.task_data, "dependencies": self.dependencies, "status": self.status.value, "progress": self.progress, "progress_message": self.progress_message, "assigned_agent_id": self.assigned_agent_id, "result": self.result.to_dict() if self.result else None, "output_data": self.output_data, "started_at": self.started_at.isoformat() if self.started_at else None, "completed_at": self.completed_at.isoformat() if self.completed_at else None, } @dataclass class DAG: """ Directed Acyclic Graph for task scheduling Manages the task workflow with dependency tracking and parallel execution support. """ id: str name: str = "" description: str = "" # Nodes and edges nodes: Dict[str, TaskNode] = field(default_factory=dict) edges: List[tuple] = field(default_factory=list) # (from_id, to_id) # Metadata created_at: datetime = field(default_factory=datetime.utcnow) updated_at: datetime = field(default_factory=datetime.utcnow) # Root and leaf tracking _root_nodes: Set[str] = field(default_factory=set) _leaf_nodes: Set[str] = field(default_factory=set) def __post_init__(self): """Post-initialization""" if not self.id: self.id = str(uuid.uuid4())[:8] def add_node(self, node: TaskNode) -> None: """ Add a task node to the DAG Args: node: TaskNode to add """ self.nodes[node.id] = node self._update_root_leaf_cache() self.updated_at = datetime.utcnow() def add_edge(self, from_id: str, to_id: str) -> None: """ Add an edge (dependency) between nodes Args: from_id: Source node ID (must complete first) to_id: Target node ID (depends on source) """ if from_id not in self.nodes: raise ValueError(f"Source node '{from_id}' not found") if to_id not in self.nodes: raise ValueError(f"Target node '{to_id}' not found") # Add dependency to target node if from_id not in self.nodes[to_id].dependencies: self.nodes[to_id].dependencies.append(from_id) self.edges.append((from_id, to_id)) self._update_root_leaf_cache() self.updated_at = datetime.utcnow() def _update_root_leaf_cache(self) -> None: """Update cached root and leaf node sets""" self._root_nodes = set() self._leaf_nodes = set() for node_id in self.nodes: if len(self.nodes[node_id].dependencies) == 0: self._root_nodes.add(node_id) # Check if node is a leaf (no one depends on it) is_leaf = True for other in self.nodes.values(): if node_id in other.dependencies: is_leaf = False break if is_leaf: self._leaf_nodes.add(node_id) @property def root_nodes(self) -> List[TaskNode]: """Get all root nodes (nodes with no dependencies)""" return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes] @property def leaf_nodes(self) -> List[TaskNode]: """Get all leaf nodes (nodes no one depends on)""" return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes] def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]: """ Get all nodes that are ready to execute A node is ready if: 1. All its dependencies are in completed_nodes 2. It is not already running or completed Args: completed_nodes: Set of completed node IDs """ ready = [] for node in self.nodes.values(): if node.status == TaskNodeStatus.READY and node.can_execute(completed_nodes): ready.append(node) return ready def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]: """ Get all nodes blocked by a failed node Args: failed_node_id: ID of the failed node """ blocked = [] for node in self.nodes.values(): if failed_node_id in node.dependencies: if node.status not in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED): blocked.append(node) return blocked def mark_node_completed(self, node_id: str, result: TaskResult) -> None: """Mark a node as completed with result""" if node_id in self.nodes: self.nodes[node_id].mark_completed(result) self.updated_at = datetime.utcnow() def mark_node_failed(self, node_id: str, error: str) -> None: """Mark a node as failed""" if node_id in self.nodes: self.nodes[node_id].mark_failed(error) self._propagate_failure(node_id) self.updated_at = datetime.utcnow() def _propagate_failure(self, failed_node_id: str) -> None: """Propagate failure to dependent nodes""" for node in self.nodes.values(): if failed_node_id in node.dependencies: if node.status == TaskNodeStatus.READY: node.status = TaskNodeStatus.BLOCKED @property def completed_count(self) -> int: """Count of completed nodes""" return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED) @property def failed_count(self) -> int: """Count of failed nodes""" return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED) @property def total_count(self) -> int: """Total number of nodes""" return len(self.nodes) @property def progress(self) -> float: """Overall DAG progress (0.0 - 1.0)""" if not self.nodes: return 0.0 return sum(n.progress for n in self.nodes.values()) / len(self.nodes) @property def is_complete(self) -> bool: """Check if DAG execution is complete (all nodes done)""" return all( n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED) for n in self.nodes.values() ) @property def is_success(self) -> bool: """Check if DAG completed successfully""" return self.is_complete and self.failed_count == 0 def get_execution_order(self) -> List[List[TaskNode]]: """ Get nodes grouped by execution level (parallel-friendly order) Returns: List of node groups, where each group can be executed in parallel """ levels: List[List[TaskNode]] = [] completed: Set[str] = set() while len(completed) < len(self.nodes): # Find nodes whose dependencies are all completed current_level = [] for node in self.nodes.values(): if node.id not in completed and node.can_execute(completed): current_level.append(node) if not current_level: break # Circular dependency or all blocked levels.append(current_level) completed.update(n.id for n in current_level) return levels def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization""" return { "id": self.id, "name": self.name, "description": self.description, "nodes": [n.to_dict() for n in self.nodes.values()], "edges": [{"from": e[0], "to": e[1]} for e in self.edges], "created_at": self.created_at.isoformat(), "updated_at": self.updated_at.isoformat(), "progress": self.progress, "completed_count": self.completed_count, "total_count": self.total_count, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "DAG": """ Create DAG from dictionary Args: data: Dictionary with DAG data """ dag = cls( id=data.get("id", str(uuid.uuid4())[:8]), name=data.get("name", ""), description=data.get("description", ""), ) # Recreate nodes for node_data in data.get("nodes", []): node = TaskNode( id=node_data["id"], name=node_data["name"], description=node_data.get("description", ""), task_type=node_data.get("task_type", "generic"), task_data=node_data.get("task_data", {}), dependencies=node_data.get("dependencies", []), ) dag.add_node(node) # Recreate edges for edge in data.get("edges", []): dag.add_edge(edge["from"], edge["to"]) return dag