"""DAG (Directed Acyclic Graph) and TaskNode models"""

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set
import uuid


class TaskNodeStatus(str, Enum):
    """Lifecycle states for a task node.

    Inherits from ``str`` so members compare equal to their plain string
    values and serialize naturally.
    """
    PENDING = "pending"      # created, waiting for scheduling
    READY = "ready"          # all dependencies satisfied, eligible to run
    RUNNING = "running"      # execution in progress
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    CANCELLED = "cancelled"  # aborted on user request
    BLOCKED = "blocked"      # a dependency failed, so this node cannot run


@dataclass
class TaskResult:
    """Outcome of a single task execution.

    Attributes:
        success: Whether the task finished without error.
        data: Arbitrary payload produced by the task.
        error: Human-readable error description when ``success`` is False.
        output_data: Structured key/value output consumed by the supervisor
            and by dependent tasks.
        execution_time: Wall-clock duration of the task in seconds.
    """
    success: bool
    data: Any = None
    error: Optional[str] = None
    output_data: Optional[Dict[str, Any]] = None  # Structured output for supervisor
    execution_time: float = 0.0  # seconds

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary"""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "output_data": self.output_data,
            "execution_time": self.execution_time,
        }

    @classmethod
    def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None,
           execution_time: float = 0.0) -> "TaskResult":
        """Create success result"""
        return cls(success=True, data=data, output_data=output_data,
                   execution_time=execution_time)

    @classmethod
    def fail(cls, error: str, data: Any = None,
             execution_time: float = 0.0) -> "TaskResult":
        """Create failure result.

        ``execution_time`` added for parity with :meth:`ok` so failed runs
        can report how long they took; it defaults to 0.0, so existing
        callers are unaffected.
        """
        return cls(success=False, error=error, data=data,
                   execution_time=execution_time)


@dataclass
class TaskNode:
    """A single executable unit of work inside the agent workflow DAG.

    Holds the task definition, its dependency list, live status/progress
    tracking, and the eventual execution result.
    """
    id: str
    name: str
    description: str = ""

    # What to run: a type tag plus free-form parameters for that type.
    task_type: str = "generic"  # e.g., "code", "shell", "file", "llm"
    task_data: Dict[str, Any] = field(default_factory=dict)

    # IDs of nodes that must complete before this one may start.
    dependencies: List[str] = field(default_factory=list)

    # Lifecycle state and final outcome.
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Optional[TaskResult] = None

    # Live progress reporting: fraction in [0.0, 1.0] plus a short message.
    progress: float = 0.0
    progress_message: str = ""

    # Who is executing this node, and when it started/finished.
    assigned_agent_id: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Key/value output published for dependent tasks.
    output_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Assign a short random id when none was supplied."""
        self.id = self.id or str(uuid.uuid4())[:8]

    @property
    def is_root(self) -> bool:
        """True when the node has no dependencies."""
        return not self.dependencies

    @property
    def is_leaf(self) -> bool:
        """Always False at the node level.

        NOTE(review): a read-only property cannot be "set by the DAG" —
        the DAG derives leaf-ness in its own cache; confirm no caller
        relies on this property.
        """
        return False

    @property
    def execution_time(self) -> float:
        """Wall-clock duration in seconds, or 0.0 if not yet finished."""
        if self.started_at is None or self.completed_at is None:
            return 0.0
        delta = self.completed_at - self.started_at
        return delta.total_seconds()

    def mark_ready(self) -> None:
        """Transition to READY (dependencies satisfied)."""
        self.status = TaskNodeStatus.READY

    def mark_running(self, agent_id: str) -> None:
        """Transition to RUNNING and record the executing agent."""
        self.started_at = datetime.utcnow()
        self.assigned_agent_id = agent_id
        self.status = TaskNodeStatus.RUNNING

    def mark_completed(self, result: TaskResult) -> None:
        """Transition to COMPLETED, storing the result and its outputs."""
        self.completed_at = datetime.utcnow()
        self.status = TaskNodeStatus.COMPLETED
        self.result = result
        self.progress = 1.0
        # Publish structured output (if any) for dependent tasks.
        if result.output_data:
            self.output_data = result.output_data

    def mark_failed(self, error: str) -> None:
        """Transition to FAILED with a synthesized failure result."""
        self.completed_at = datetime.utcnow()
        self.status = TaskNodeStatus.FAILED
        self.result = TaskResult.fail(error)

    def mark_cancelled(self) -> None:
        """Transition to CANCELLED."""
        self.completed_at = datetime.utcnow()
        self.status = TaskNodeStatus.CANCELLED

    def update_progress(self, progress: float, message: str = "") -> None:
        """Record progress, clamped to [0.0, 1.0].

        An empty message leaves the previous progress_message in place.
        """
        self.progress = min(1.0, max(0.0, progress))
        if message:
            self.progress_message = message

    def can_execute(self, completed_nodes: Set[str]) -> bool:
        """Return True when this node is READY and every dependency id
        appears in *completed_nodes*.

        Args:
            completed_nodes: Set of completed node IDs
        """
        return (
            self.status == TaskNodeStatus.READY
            and all(dep in completed_nodes for dep in self.dependencies)
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the node (including any result) to a dictionary."""
        started = self.started_at.isoformat() if self.started_at else None
        finished = self.completed_at.isoformat() if self.completed_at else None
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "task_data": self.task_data,
            "dependencies": self.dependencies,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "assigned_agent_id": self.assigned_agent_id,
            "result": self.result.to_dict() if self.result else None,
            "output_data": self.output_data,
            "started_at": started,
            "completed_at": finished,
        }


@dataclass
class DAG:
    """
    Directed Acyclic Graph for task scheduling

    Manages the task workflow with dependency tracking and parallel execution support.
    """
    id: str
    name: str = ""
    description: str = ""

    # Node storage (id -> node) and explicit edge list of (from_id, to_id).
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[tuple] = field(default_factory=list)  # (from_id, to_id)

    # Bookkeeping timestamps.
    # NOTE(review): datetime.utcnow() is naive and deprecated since Python
    # 3.12; switching to datetime.now(timezone.utc) needs a caller audit
    # first (isoformat output and naive/aware comparisons would change).
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Cached ids of root nodes (no dependencies) and leaf nodes (no dependents).
    _root_nodes: Set[str] = field(default_factory=set)
    _leaf_nodes: Set[str] = field(default_factory=set)

    def __post_init__(self):
        """Assign a short random id when none was supplied"""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def add_node(self, node: TaskNode) -> None:
        """
        Add a task node to the DAG

        Args:
            node: TaskNode to add
        """
        self.nodes[node.id] = node
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def add_edge(self, from_id: str, to_id: str) -> None:
        """
        Add an edge (dependency) between nodes

        Idempotent: adding the same edge twice is a no-op (previously the
        duplicate tuple was appended to ``edges`` each time).

        Args:
            from_id: Source node ID (must complete first)
            to_id: Target node ID (depends on source)

        Raises:
            ValueError: If either endpoint is not a known node.
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found")

        # Record the dependency on the target node (idempotent).
        if from_id not in self.nodes[to_id].dependencies:
            self.nodes[to_id].dependencies.append(from_id)

        # Record the edge itself (idempotent).
        if (from_id, to_id) not in self.edges:
            self.edges.append((from_id, to_id))

        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def _update_root_leaf_cache(self) -> None:
        """Recompute the cached root and leaf node id sets.

        Builds the "depended upon" set once, making this O(V + E) instead
        of the previous O(V^2) all-pairs scan (it runs on every add).
        """
        depended_upon: Set[str] = set()
        for node in self.nodes.values():
            depended_upon.update(node.dependencies)

        # Roots: no dependencies. Leaves: nothing depends on them.
        self._root_nodes = {
            nid for nid, node in self.nodes.items() if not node.dependencies
        }
        self._leaf_nodes = set(self.nodes) - depended_upon

    @property
    def root_nodes(self) -> List[TaskNode]:
        """Get all root nodes (nodes with no dependencies)"""
        return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes]

    @property
    def leaf_nodes(self) -> List[TaskNode]:
        """Get all leaf nodes (nodes no one depends on)"""
        return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes]

    def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]:
        """
        Get all nodes that are ready to execute

        A node qualifies when its status is READY and all of its dependency
        ids are in completed_nodes (both checks live in TaskNode.can_execute,
        so the previous extra status check here was redundant).

        Args:
            completed_nodes: Set of completed node IDs
        """
        return [
            node for node in self.nodes.values() if node.can_execute(completed_nodes)
        ]

    def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]:
        """
        Get all nodes directly blocked by a failed node

        Only non-terminal direct dependents are returned.

        Args:
            failed_node_id: ID of the failed node
        """
        terminal = (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
        return [
            node
            for node in self.nodes.values()
            if failed_node_id in node.dependencies and node.status not in terminal
        ]

    def mark_node_completed(self, node_id: str, result: TaskResult) -> None:
        """Mark a node as completed with result (no-op for unknown ids)"""
        if node_id in self.nodes:
            self.nodes[node_id].mark_completed(result)
            self.updated_at = datetime.utcnow()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Mark a node as failed and block its dependents (no-op for unknown ids)"""
        if node_id in self.nodes:
            self.nodes[node_id].mark_failed(error)
            self._propagate_failure(node_id)
            self.updated_at = datetime.utcnow()

    def _propagate_failure(self, failed_node_id: str) -> None:
        """Mark every transitive dependent of a failed node as BLOCKED.

        Bug fix: the previous implementation blocked only *direct*
        dependents that were already READY, leaving PENDING descendants
        stranded in PENDING forever; PENDING and READY dependents are now
        blocked transitively.
        """
        frontier = [failed_node_id]
        while frontier:
            blocker = frontier.pop()
            for node in self.nodes.values():
                if blocker in node.dependencies and node.status in (
                    TaskNodeStatus.PENDING,
                    TaskNodeStatus.READY,
                ):
                    node.status = TaskNodeStatus.BLOCKED
                    frontier.append(node.id)

    @property
    def completed_count(self) -> int:
        """Count of completed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED)

    @property
    def failed_count(self) -> int:
        """Count of failed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED)

    @property
    def total_count(self) -> int:
        """Total number of nodes"""
        return len(self.nodes)

    @property
    def progress(self) -> float:
        """Overall DAG progress (0.0 - 1.0): mean of per-node progress"""
        if not self.nodes:
            return 0.0
        return sum(n.progress for n in self.nodes.values()) / len(self.nodes)

    @property
    def is_complete(self) -> bool:
        """Check if DAG execution is complete (all nodes in a terminal state).

        NOTE(review): BLOCKED is not treated as terminal, so a DAG with
        blocked nodes never reports complete — confirm the supervisor
        eventually resolves (cancels/fails) blocked nodes.
        """
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
            for n in self.nodes.values()
        )

    @property
    def is_success(self) -> bool:
        """Check if DAG completed successfully (complete with zero failures)"""
        return self.is_complete and self.failed_count == 0

    def get_execution_order(self) -> List[List[TaskNode]]:
        """
        Get nodes grouped by execution level (parallel-friendly order)

        This is a static topological grouping based purely on dependencies.
        Bug fix: the previous implementation used TaskNode.can_execute(),
        which also requires READY status, so a freshly built DAG of PENDING
        nodes always produced an empty order.

        Returns:
            List of node groups, where each group can be executed in parallel.
            A truncated (or empty) result indicates a dependency cycle or a
            dependency on an unknown node id.
        """
        levels: List[List[TaskNode]] = []
        placed: Set[str] = set()

        while len(placed) < len(self.nodes):
            # Nodes whose every dependency has already been placed.
            current_level = [
                node
                for node in self.nodes.values()
                if node.id not in placed
                and all(dep in placed for dep in node.dependencies)
            ]

            if not current_level:
                break  # cycle or unknown dependency id

            levels.append(current_level)
            placed.update(node.id for node in current_level)

        return levels

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"from": e[0], "to": e[1]} for e in self.edges],
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "completed_count": self.completed_count,
            "total_count": self.total_count,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """
        Create DAG from dictionary

        NOTE(review): only the structure (nodes, dependencies, edges) is
        restored; status, results, progress and timestamps reset to defaults.

        Args:
            data: Dictionary with DAG data
        """
        dag = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            name=data.get("name", ""),
            description=data.get("description", ""),
        )

        # Recreate nodes (dependency lists come back from serialization).
        for node_data in data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {}),
                dependencies=node_data.get("dependencies", []),
            )
            dag.add_node(node)

        # Recreate edges; add_edge is idempotent, so dependencies already
        # present on the nodes are not duplicated.
        for edge in data.get("edges", []):
            dag.add_edge(edge["from"], edge["to"])

        return dag