Luxx/luxx/agents/dag.py

419 lines
14 KiB
Python

"""DAG (Directed Acyclic Graph) and TaskNode models"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from enum import Enum
from datetime import datetime
import uuid
class TaskNodeStatus(str, Enum):
    """Lifecycle states of a task node.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value (e.g. ``TaskNodeStatus.READY == "ready"``).
    """

    # Not yet started
    PENDING = "pending"
    # Dependencies satisfied, can start
    READY = "ready"
    # Currently executing
    RUNNING = "running"
    # Successfully completed
    COMPLETED = "completed"
    # Execution failed
    FAILED = "failed"
    # Cancelled by user
    CANCELLED = "cancelled"
    # Blocked by failed dependency
    BLOCKED = "blocked"
@dataclass
class TaskResult:
    """Outcome of one task execution.

    Use the ``ok`` / ``fail`` constructors to build success and failure
    results respectively.
    """

    # True for results built with ok(), False for results built with fail().
    success: bool
    # Arbitrary payload attached to the result.
    data: Any = None
    # Error description; stays None unless set (fail() always sets it).
    error: Optional[str] = None
    # Structured output for supervisor
    output_data: Optional[Dict[str, Any]] = None
    # Duration in seconds
    execution_time: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of the result as a plain dictionary."""
        field_names = ("success", "data", "error", "output_data", "execution_time")
        return {field_name: getattr(self, field_name) for field_name in field_names}

    @classmethod
    def ok(
        cls,
        data: Any = None,
        output_data: Optional[Dict[str, Any]] = None,
        execution_time: float = 0.0,
    ) -> "TaskResult":
        """Build a successful result.

        Args:
            data: Payload to attach.
            output_data: Structured output to attach.
            execution_time: Duration in seconds.

        Returns:
            A result with ``success=True`` and ``error`` left as ``None``.
        """
        init_kwargs = {
            "success": True,
            "data": data,
            "output_data": output_data,
            "execution_time": execution_time,
        }
        return cls(**init_kwargs)

    @classmethod
    def fail(cls, error: str, data: Any = None) -> "TaskResult":
        """Build a failed result.

        Args:
            error: Error description to store.
            data: Optional payload to attach.

        Returns:
            A result with ``success=False``; ``output_data`` and
            ``execution_time`` keep their defaults.
        """
        init_kwargs = {"success": False, "error": error, "data": data}
        return cls(**init_kwargs)
@dataclass
class TaskNode:
    """
    One executable task inside the agent workflow graph.

    Bundles the task definition, the IDs of the nodes it depends on, and the
    mutable execution state (status, progress, timestamps, result).
    """

    id: str
    name: str
    description: str = ""

    # Task definition: a type tag (e.g., "code", "shell", "file", "llm")
    # plus task-specific parameters.
    task_type: str = "generic"
    task_data: Dict[str, Any] = field(default_factory=dict)

    # IDs of the nodes that must complete before this node.
    dependencies: List[str] = field(default_factory=list)

    # Current lifecycle state and the final outcome, once there is one.
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Optional[TaskResult] = None

    # Progress fraction (update_progress clamps it to 0.0 - 1.0) and a note.
    progress: float = 0.0
    progress_message: str = ""

    # Execution bookkeeping.
    assigned_agent_id: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Key-value pairs made available to dependent tasks.
    output_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a short random ID when the caller passed a falsy one."""
        if self.id:
            return
        self.id = str(uuid.uuid4())[:8]

    @property
    def is_root(self) -> bool:
        """True when the node lists no dependencies."""
        return not len(self.dependencies)

    @property
    def is_leaf(self) -> bool:
        """Leaf indicator (no dependents); currently always ``False``.

        A node does not know which other nodes depend on it, so this
        property cannot compute a real answer. ``DAG.leaf_nodes`` reports
        the actual leaves of a graph.
        """
        return False

    @property
    def execution_time(self) -> float:
        """Seconds between start and completion, or 0.0 if either is unset."""
        begin, end = self.started_at, self.completed_at
        if not (begin and end):
            return 0.0
        return (end - begin).total_seconds()

    def mark_ready(self) -> None:
        """Move the node to the READY state."""
        self.status = TaskNodeStatus.READY

    def mark_running(self, agent_id: str) -> None:
        """Move the node to RUNNING, recording the agent and the start time.

        Args:
            agent_id: ID of the agent executing this node.
        """
        self.status = TaskNodeStatus.RUNNING
        self.started_at = datetime.utcnow()
        self.assigned_agent_id = agent_id

    def mark_completed(self, result: TaskResult) -> None:
        """Move the node to COMPLETED and store its result.

        Progress is forced to 1.0. The node's ``output_data`` is replaced
        only when the result carries a non-empty ``output_data``.

        Args:
            result: Outcome to attach to the node.
        """
        self.status = TaskNodeStatus.COMPLETED
        self.result = result
        self.progress = 1.0
        self.completed_at = datetime.utcnow()
        produced = result.output_data
        if produced:
            self.output_data = produced

    def mark_failed(self, error: str) -> None:
        """Move the node to FAILED with a failure result built from ``error``.

        Args:
            error: Error description stored in the result.
        """
        self.status = TaskNodeStatus.FAILED
        failure = TaskResult.fail(error)
        self.result = failure
        self.completed_at = datetime.utcnow()

    def mark_cancelled(self) -> None:
        """Move the node to CANCELLED and stamp the completion time."""
        self.completed_at = datetime.utcnow()
        self.status = TaskNodeStatus.CANCELLED

    def update_progress(self, progress: float, message: str = "") -> None:
        """Record execution progress.

        Args:
            progress: New progress value; clamped into 0.0 - 1.0.
            message: Optional note; an empty string keeps the previous note.
        """
        capped = min(1.0, progress)
        self.progress = max(0.0, capped)
        if not message:
            return
        self.progress_message = message

    def can_execute(self, completed_nodes: Set[str]) -> bool:
        """
        Tell whether the node may start right now.

        The node must be in the READY state and every dependency ID must be
        present in ``completed_nodes``.

        Args:
            completed_nodes: Set of completed node IDs
        """
        if not (self.status == TaskNodeStatus.READY):
            return False
        for dep_id in self.dependencies:
            if dep_id not in completed_nodes:
                return False
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Return the node's definition and state as a plain dictionary.

        Timestamps are ISO-formatted strings (or None); the result, when
        present, is serialised through ``TaskResult.to_dict``.
        """
        serialized_result = self.result.to_dict() if self.result else None
        started = self.started_at.isoformat() if self.started_at else None
        finished = self.completed_at.isoformat() if self.completed_at else None
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "task_data": self.task_data,
            "dependencies": self.dependencies,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "assigned_agent_id": self.assigned_agent_id,
            "result": serialized_result,
            "output_data": self.output_data,
            "started_at": started,
            "completed_at": finished,
        }
@dataclass
class DAG:
    """
    Directed graph of TaskNode objects used for task scheduling.

    Tracks nodes, dependency edges and per-node execution state, and offers
    helpers for finding runnable nodes, propagating failures, computing a
    parallel-friendly execution order and (de)serialising the graph.

    Note: edges are not checked for cycles when they are added;
    ``get_execution_order`` stops early when the remaining nodes cannot be
    ordered (circular or unknown dependencies).
    """

    id: str
    name: str = ""
    description: str = ""

    # Nodes keyed by node ID, and edges as (from_id, to_id) tuples where
    # ``from_id`` must complete before ``to_id``.
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[tuple] = field(default_factory=list)

    # Metadata
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Cached IDs of root nodes (no dependencies) and leaf nodes (no dependents).
    _root_nodes: Set[str] = field(default_factory=set)
    _leaf_nodes: Set[str] = field(default_factory=set)

    def __post_init__(self):
        """Assign a short random ID if none was given and build the caches."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]
        # Nodes may be handed straight to the constructor, so the root/leaf
        # cache must not wait for the first add_node()/add_edge() call.
        self._update_root_leaf_cache()

    def add_node(self, node: TaskNode) -> None:
        """
        Add a task node to the DAG.

        A node already stored under the same ID is replaced.

        Args:
            node: TaskNode to add
        """
        self.nodes[node.id] = node
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def add_edge(self, from_id: str, to_id: str) -> None:
        """
        Add an edge (dependency) between two existing nodes.

        The call is idempotent: the dependency on the target node and the
        entry in ``edges`` are each recorded at most once.

        Args:
            from_id: Source node ID (must complete first)
            to_id: Target node ID (depends on source)

        Raises:
            ValueError: If either node ID is not present in the DAG.
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found")

        target_dependencies = self.nodes[to_id].dependencies
        if from_id not in target_dependencies:
            target_dependencies.append(from_id)

        # Checked separately from the dependency list: a node may already
        # list the dependency (e.g. when rebuilt from a dict) while the edge
        # itself has not been recorded yet.
        edge = (from_id, to_id)
        if edge not in self.edges:
            self.edges.append(edge)

        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def _update_root_leaf_cache(self) -> None:
        """Recompute the cached root and leaf node ID sets in one pass."""
        depended_upon: Set[str] = set()
        for node in self.nodes.values():
            depended_upon.update(node.dependencies)

        self._root_nodes = {
            node_id for node_id, node in self.nodes.items() if not node.dependencies
        }
        self._leaf_nodes = {
            node_id for node_id in self.nodes if node_id not in depended_upon
        }

    @property
    def root_nodes(self) -> List[TaskNode]:
        """All root nodes (nodes with no dependencies), in insertion order."""
        return [node for nid, node in self.nodes.items() if nid in self._root_nodes]

    @property
    def leaf_nodes(self) -> List[TaskNode]:
        """All leaf nodes (nodes no one depends on), in insertion order."""
        return [node for nid, node in self.nodes.items() if nid in self._leaf_nodes]

    def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]:
        """
        Get all nodes that can be executed right now.

        A node qualifies when ``TaskNode.can_execute`` accepts it, i.e.:
        1. Its status is READY
        2. All its dependencies are in completed_nodes

        Args:
            completed_nodes: Set of completed node IDs
        """
        return [
            node for node in self.nodes.values() if node.can_execute(completed_nodes)
        ]

    def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]:
        """
        Get the direct dependents of a failed node that have not finished.

        Nodes that are already COMPLETED, FAILED or CANCELLED are excluded.

        Args:
            failed_node_id: ID of the failed node
        """
        finished = (
            TaskNodeStatus.COMPLETED,
            TaskNodeStatus.FAILED,
            TaskNodeStatus.CANCELLED,
        )
        return [
            node
            for node in self.nodes.values()
            if failed_node_id in node.dependencies and node.status not in finished
        ]

    def mark_node_completed(self, node_id: str, result: TaskResult) -> None:
        """Mark a node as completed with result; unknown IDs are ignored."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_completed(result)
            self.updated_at = datetime.utcnow()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Mark a node as failed and block its dependents; unknown IDs are ignored."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_failed(error)
            self._propagate_failure(node_id)
            self.updated_at = datetime.utcnow()

    def _propagate_failure(self, failed_node_id: str) -> None:
        """
        Mark every node downstream of a failed node as BLOCKED.

        The walk is transitive. Only nodes that have not started yet
        (PENDING or READY) change status; nodes in any other state are left
        untouched but are still traversed so their own dependents are reached.
        """
        dependents: Dict[str, List[TaskNode]] = {}
        for node in self.nodes.values():
            for dep_id in node.dependencies:
                dependents.setdefault(dep_id, []).append(node)

        not_started = (TaskNodeStatus.PENDING, TaskNodeStatus.READY)
        visited: Set[str] = {failed_node_id}
        frontier: List[str] = [failed_node_id]
        while frontier:
            upstream_id = frontier.pop()
            for node in dependents.get(upstream_id, []):
                if node.id in visited:
                    continue
                visited.add(node.id)
                if node.status in not_started:
                    node.status = TaskNodeStatus.BLOCKED
                frontier.append(node.id)

    @property
    def completed_count(self) -> int:
        """Count of completed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED)

    @property
    def failed_count(self) -> int:
        """Count of failed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED)

    @property
    def total_count(self) -> int:
        """Total number of nodes"""
        return len(self.nodes)

    @property
    def progress(self) -> float:
        """Overall DAG progress (0.0 - 1.0): the mean of all node progress values."""
        if not self.nodes:
            return 0.0
        return sum(n.progress for n in self.nodes.values()) / len(self.nodes)

    @property
    def is_complete(self) -> bool:
        """
        Check if DAG execution is complete (no node can make further progress).

        BLOCKED counts as terminal: a blocked node will never run, so leaving
        it out would keep a DAG with a failed node incomplete forever.
        """
        terminal = (
            TaskNodeStatus.COMPLETED,
            TaskNodeStatus.FAILED,
            TaskNodeStatus.CANCELLED,
            TaskNodeStatus.BLOCKED,
        )
        return all(n.status in terminal for n in self.nodes.values())

    @property
    def is_success(self) -> bool:
        """Check if DAG finished with every node COMPLETED or CANCELLED."""
        successful = (TaskNodeStatus.COMPLETED, TaskNodeStatus.CANCELLED)
        return all(n.status in successful for n in self.nodes.values())

    def get_execution_order(self) -> List[List[TaskNode]]:
        """
        Get nodes grouped by execution level (parallel-friendly order).

        The grouping is derived from dependencies only and does not look at
        node status, so it gives the same answer before, during and after a
        run. Nodes that can never be ordered (circular dependencies, or
        dependencies on IDs that are not in the DAG) are omitted.

        Returns:
            List of node groups, where each group can be executed in parallel
        """
        levels: List[List[TaskNode]] = []
        placed: Set[str] = set()
        remaining = list(self.nodes.values())
        while remaining:
            current_level = [
                node
                for node in remaining
                if all(dep_id in placed for dep_id in node.dependencies)
            ]
            if not current_level:
                break  # Circular or unsatisfiable dependencies
            levels.append(current_level)
            placed.update(node.id for node in current_level)
            remaining = [node for node in remaining if node.id not in placed]
        return levels

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"from": e[0], "to": e[1]} for e in self.edges],
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "completed_count": self.completed_count,
            "total_count": self.total_count,
        }

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Parse an ISO-format string; return None for missing or malformed input."""
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None

    @classmethod
    def _node_from_dict(cls, node_data: Dict[str, Any]) -> TaskNode:
        """Rebuild one TaskNode, including its execution state, from a dict."""
        node = TaskNode(
            id=node_data["id"],
            name=node_data["name"],
            description=node_data.get("description", ""),
            task_type=node_data.get("task_type", "generic"),
            # Copy the containers so later mutations of the node (e.g. via
            # add_edge) never write through to the caller's input data.
            task_data=dict(node_data.get("task_data") or {}),
            dependencies=list(node_data.get("dependencies") or []),
        )

        raw_status = node_data.get("status")
        if raw_status is not None:
            try:
                node.status = TaskNodeStatus(raw_status)
            except ValueError:
                # Unknown status value: keep the default PENDING so input
                # that loaded before this field was restored still loads.
                pass

        node.progress = node_data.get("progress", 0.0)
        node.progress_message = node_data.get("progress_message", "")
        node.assigned_agent_id = node_data.get("assigned_agent_id")
        node.output_data = dict(node_data.get("output_data") or {})
        node.started_at = cls._parse_timestamp(node_data.get("started_at"))
        node.completed_at = cls._parse_timestamp(node_data.get("completed_at"))

        result_data = node_data.get("result")
        if isinstance(result_data, dict):
            node.result = TaskResult(
                success=bool(result_data.get("success", False)),
                data=result_data.get("data"),
                error=result_data.get("error"),
                output_data=result_data.get("output_data"),
                execution_time=result_data.get("execution_time", 0.0),
            )
        return node

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """
        Create DAG from dictionary.

        Accepts the structure produced by ``to_dict`` and restores node
        execution state (status, progress, result, timestamps, output data)
        as well as the DAG timestamps when they are present. Derived keys
        such as ``progress`` and ``completed_count`` are ignored.

        Args:
            data: Dictionary with DAG data

        Raises:
            KeyError: If a node entry lacks ``id``/``name`` or an edge entry
                lacks ``from``/``to``.
            ValueError: If an edge references a node that is not in ``nodes``.
        """
        dag = cls(
            id=data.get("id", ""),  # __post_init__ generates an ID when empty
            name=data.get("name", ""),
            description=data.get("description", ""),
        )

        # Recreate nodes
        for node_data in data.get("nodes", []):
            dag.add_node(cls._node_from_dict(node_data))

        # Recreate edges
        for edge in data.get("edges", []):
            dag.add_edge(edge["from"], edge["to"])

        # Restore timestamps last: add_node/add_edge bump updated_at.
        created_at = cls._parse_timestamp(data.get("created_at"))
        if created_at is not None:
            dag.created_at = created_at
        updated_at = cls._parse_timestamp(data.get("updated_at"))
        if updated_at is not None:
            dag.updated_at = updated_at
        return dag