feat: add the agent backend (增加 agent 的后端)
This commit is contained in:
parent
22a4b8a4bb
commit
8089d94e78
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Multi-Agent system module"""
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||||
|
from luxx.agents.registry import AgentRegistry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Agent",
|
||||||
|
"AgentConfig",
|
||||||
|
"AgentType",
|
||||||
|
"AgentStatus",
|
||||||
|
"DAG",
|
||||||
|
"TaskNode",
|
||||||
|
"TaskNodeStatus",
|
||||||
|
"TaskResult",
|
||||||
|
"AgentRegistry",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,161 @@
|
||||||
|
"""Agent core models"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from luxx.tools.core import CommandPermission
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(str, Enum):
    """Enumeration of agent roles within the multi-agent system."""

    SUPERVISOR = "supervisor"  # plans and coordinates worker agents
    WORKER = "worker"          # executes individual tasks
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStatus(str, Enum):
    """Enumeration of an agent's lifecycle states."""

    IDLE = "idle"            # no active work
    PLANNING = "planning"    # building a task plan
    EXECUTING = "executing"  # running a task
    WAITING = "waiting"      # paused, waiting on something external
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    CANCELLED = "cancelled"  # aborted on request
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentConfig:
    """Static configuration for an agent.

    Captures identity, LLM settings, tool access, and the permission and
    context-window limits applied at runtime.
    """
    # Human-readable agent name.
    name: str
    # Role of the agent (supervisor or worker).
    agent_type: AgentType
    description: str = ""
    # Upper bound on command permission; the runtime effective permission is
    # min(user permission, this value).
    max_permission: CommandPermission = CommandPermission.EXECUTE
    max_turns: int = 10  # Context window: sliding window size
    # LLM settings. NOTE(review): defaults presume a DeepSeek-compatible
    # backend — confirm against the model client configuration.
    model: str = "deepseek-chat"
    temperature: float = 0.7
    max_tokens: int = 4096
    # Optional system prompt injected as the first context message.
    system_prompt: str = ""
    tools: List[str] = field(default_factory=list)  # Tool names available to this agent
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Agent:
    """
    Agent entity

    Represents an AI agent with its configuration, runtime state, and
    conversation context. The context window is a list of chat messages
    (``{"role": ..., "content": ...}``) bounded by a sliding window of
    ``config.max_turns`` user/assistant pairs plus an optional system prompt.
    """
    id: str
    config: AgentConfig
    status: AgentStatus = AgentStatus.IDLE
    user_id: Optional[int] = None
    conversation_id: Optional[str] = None
    workspace: Optional[str] = None

    # Runtime state (naive UTC timestamps, matching the rest of the module)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Context management
    context_window: List[Dict[str, Any]] = field(default_factory=list)
    accumulated_result: Dict[str, Any] = field(default_factory=dict)

    # Progress tracking
    current_task_id: Optional[str] = None
    progress: float = 0.0  # 0.0 - 1.0

    # Permission (effective permission = min(user_permission, agent.max_permission))
    effective_permission: CommandPermission = field(default_factory=lambda: CommandPermission.EXECUTE)

    def __post_init__(self):
        """Seed the context window with the configured system prompt."""
        if self.config.system_prompt and not self.context_window:
            # Initialize with system prompt
            self.context_window = [
                {"role": "system", "content": self.config.system_prompt}
            ]

    def add_message(self, role: str, content: str) -> None:
        """
        Add a message to the context window, trim to the sliding window,
        and bump ``updated_at``.

        Args:
            role: Message role (user/assistant/system)
            content: Message content
        """
        self.context_window.append({"role": role, "content": content})
        self._trim_context()
        self.updated_at = datetime.utcnow()

    def _trim_context(self) -> None:
        """
        Trim the context window using a sliding-window strategy.

        Keeps the leading system prompt (when present) plus the most recent
        messages, bounded to ``1 + max_turns * 2`` entries.
        """
        max_items = 1 + (self.config.max_turns * 2)  # system + (max_turns * 2)
        if len(self.context_window) <= max_items:
            return

        # Pin index 0 only when it really is the system prompt. Previously the
        # first message was always pinned, which wrongly preserved a stale
        # user message when no system prompt was configured.
        if self.context_window[0].get("role") == "system":
            head = self.context_window[:1]
        else:
            head = []
        tail = self.context_window[len(head):]
        self.context_window = head + tail[-(max_items - len(head)):]

    def get_context(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the current context window."""
        return self.context_window.copy()

    def set_user_permission(self, user_permission: CommandPermission) -> None:
        """
        Set effective permission based on user and agent limits.

        Effective permission = min(user_permission, agent.max_permission).
        NOTE(review): assumes CommandPermission defines an ordering
        (e.g. an IntEnum) — confirm in luxx.tools.core.

        Args:
            user_permission: User's permission level
        """
        self.effective_permission = min(user_permission, self.config.max_permission)

    def store_result(self, key: str, value: Any) -> None:
        """
        Store a result for the supervisor's result-based context management.

        Args:
            key: Result key
            value: Result value
        """
        self.accumulated_result[key] = value
        self.updated_at = datetime.utcnow()

    def get_result(self, key: str) -> Optional[Any]:
        """Return a stored result by key, or None when absent."""
        return self.accumulated_result.get(key)

    def clear_context(self) -> None:
        """Clear the context window, keeping the system prompt if present."""
        if self.context_window and self.context_window[0]["role"] == "system":
            system_prompt = self.context_window[0]
            self.context_window = [system_prompt]
        else:
            self.context_window = []

    def to_dict(self) -> Dict[str, Any]:
        """Serialize runtime state to a JSON-friendly dictionary."""
        return {
            "id": self.id,
            "name": self.config.name,
            "type": self.config.agent_type.value,
            "status": self.status.value,
            "user_id": self.user_id,
            "conversation_id": self.conversation_id,
            "workspace": self.workspace,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "current_task_id": self.current_task_id,
            "progress": self.progress,
            "effective_permission": self.effective_permission.name,
        }
|
||||||
|
|
@ -0,0 +1,418 @@
|
||||||
|
"""DAG (Directed Acyclic Graph) and TaskNode models"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNodeStatus(str, Enum):
    """Lifecycle states of a task node in the DAG."""

    PENDING = "pending"      # Not yet started
    READY = "ready"          # Dependencies satisfied, can start
    RUNNING = "running"      # Currently executing
    COMPLETED = "completed"  # Successfully completed
    FAILED = "failed"        # Execution failed
    CANCELLED = "cancelled"  # Cancelled by user
    BLOCKED = "blocked"      # Blocked by failed dependency
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskResult:
    """Outcome of executing a single task node."""
    success: bool
    data: Any = None
    error: Optional[str] = None
    output_data: Optional[Dict[str, Any]] = None  # Structured output for supervisor
    execution_time: float = 0.0  # seconds

    @classmethod
    def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None,
           execution_time: float = 0.0) -> "TaskResult":
        """Build a successful result."""
        return cls(
            success=True,
            data=data,
            output_data=output_data,
            execution_time=execution_time,
        )

    @classmethod
    def fail(cls, error: str, data: Any = None) -> "TaskResult":
        """Build a failed result carrying *error*."""
        return cls(success=False, error=error, data=data)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return dict(
            success=self.success,
            data=self.data,
            error=self.error,
            output_data=self.output_data,
            execution_time=self.execution_time,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskNode:
    """
    Task node in the DAG

    Represents a single executable task within the agent workflow.
    Carries its own dependency list (parent node ids), status, progress,
    and — once finished — a TaskResult plus output data for downstream nodes.
    """
    id: str
    name: str
    description: str = ""

    # Task definition
    task_type: str = "generic"  # e.g., "code", "shell", "file", "llm"
    task_data: Dict[str, Any] = field(default_factory=dict)  # Task-specific parameters

    # Dependencies (node IDs that must complete before this node)
    dependencies: List[str] = field(default_factory=list)

    # Status tracking
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Optional[TaskResult] = None

    # Progress tracking
    progress: float = 0.0  # 0.0 - 1.0
    progress_message: str = ""

    # Execution info
    assigned_agent_id: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Output data (key-value pairs for dependent tasks)
    output_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            # 8 hex chars of a UUID4 — short but not guaranteed collision-free
            self.id = str(uuid.uuid4())[:8]

    @property
    def is_root(self) -> bool:
        """Check if this is a root node (no dependencies)"""
        return len(self.dependencies) == 0

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (no dependents)"""
        # NOTE(review): always False as written — a node has no back-reference
        # to its DAG, so it cannot know its dependents. Use DAG.leaf_nodes
        # instead; confirm whether this property is still needed.
        return False  # Will be set by DAG

    @property
    def execution_time(self) -> float:
        """Calculate execution time in seconds (0.0 until both timestamps set)"""
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return 0.0

    def mark_ready(self) -> None:
        """Mark node as ready to execute"""
        self.status = TaskNodeStatus.READY

    def mark_running(self, agent_id: str) -> None:
        """Mark node as running and record the executing agent and start time"""
        self.status = TaskNodeStatus.RUNNING
        self.assigned_agent_id = agent_id
        self.started_at = datetime.utcnow()

    def mark_completed(self, result: TaskResult) -> None:
        """Mark node as completed, storing the result and publishing output_data"""
        self.status = TaskNodeStatus.COMPLETED
        self.result = result
        self.completed_at = datetime.utcnow()
        self.progress = 1.0
        # Expose structured output so dependent nodes can consume it
        if result.output_data:
            self.output_data = result.output_data

    def mark_failed(self, error: str) -> None:
        """Mark node as failed with an error message.

        Note: progress is intentionally left at its last reported value.
        """
        self.status = TaskNodeStatus.FAILED
        self.result = TaskResult.fail(error)
        self.completed_at = datetime.utcnow()

    def mark_cancelled(self) -> None:
        """Mark node as cancelled"""
        self.status = TaskNodeStatus.CANCELLED
        self.completed_at = datetime.utcnow()

    def update_progress(self, progress: float, message: str = "") -> None:
        """Update execution progress (clamped to [0.0, 1.0]).

        An empty message leaves the previous progress_message in place.
        """
        self.progress = max(0.0, min(1.0, progress))
        if message:
            self.progress_message = message

    def can_execute(self, completed_nodes: Set[str]) -> bool:
        """
        Check if this node can execute (all dependencies completed)

        Note: also requires the node itself to be in READY status —
        PENDING nodes return False even with all dependencies satisfied.

        Args:
            completed_nodes: Set of completed node IDs
        """
        if self.status != TaskNodeStatus.READY:
            return False
        return all(dep_id in completed_nodes for dep_id in self.dependencies)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-friendly dictionary (timestamps as ISO strings)"""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "task_data": self.task_data,
            "dependencies": self.dependencies,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "assigned_agent_id": self.assigned_agent_id,
            "result": self.result.to_dict() if self.result else None,
            "output_data": self.output_data,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DAG:
    """
    Directed Acyclic Graph for task scheduling

    Manages the task workflow with dependency tracking and parallel
    execution support. Dependencies live on each TaskNode; edges are also
    recorded as (from_id, to_id) tuples for serialization.
    """
    id: str
    name: str = ""
    description: str = ""

    # Nodes keyed by node id; edges stored as (from_id, to_id) tuples
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[tuple] = field(default_factory=list)  # (from_id, to_id)

    # Metadata (naive UTC timestamps)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Cached ids of root nodes (no dependencies) and leaf nodes (no dependents)
    _root_nodes: Set[str] = field(default_factory=set)
    _leaf_nodes: Set[str] = field(default_factory=set)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def add_node(self, node: TaskNode) -> None:
        """
        Add a task node to the DAG.

        Args:
            node: TaskNode to add
        """
        self.nodes[node.id] = node
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def add_edge(self, from_id: str, to_id: str) -> None:
        """
        Add an edge (dependency) between nodes.

        Args:
            from_id: Source node ID (must complete first)
            to_id: Target node ID (depends on source)

        Raises:
            ValueError: If either endpoint is not a known node.
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found")

        # Add dependency to target node (idempotent)
        if from_id not in self.nodes[to_id].dependencies:
            self.nodes[to_id].dependencies.append(from_id)

        # Skip duplicates so repeated add_edge calls don't bloat the edge list
        if (from_id, to_id) not in self.edges:
            self.edges.append((from_id, to_id))
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def _update_root_leaf_cache(self) -> None:
        """Recompute cached root/leaf id sets (O(n^2); fine for small DAGs)."""
        self._root_nodes = set()
        self._leaf_nodes = set()

        for node_id in self.nodes:
            if len(self.nodes[node_id].dependencies) == 0:
                self._root_nodes.add(node_id)

            # A node is a leaf when no other node depends on it
            is_leaf = True
            for other in self.nodes.values():
                if node_id in other.dependencies:
                    is_leaf = False
                    break
            if is_leaf:
                self._leaf_nodes.add(node_id)

    @property
    def root_nodes(self) -> List[TaskNode]:
        """Get all root nodes (nodes with no dependencies)"""
        return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes]

    @property
    def leaf_nodes(self) -> List[TaskNode]:
        """Get all leaf nodes (nodes no one depends on)"""
        return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes]

    def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]:
        """
        Get all nodes that are ready to execute.

        A node is ready if it is in READY status and all its dependencies
        are in completed_nodes.

        Args:
            completed_nodes: Set of completed node IDs
        """
        return [
            node for node in self.nodes.values()
            if node.status == TaskNodeStatus.READY and node.can_execute(completed_nodes)
        ]

    def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]:
        """
        Get all non-terminal nodes that directly depend on a failed node.

        Args:
            failed_node_id: ID of the failed node
        """
        blocked = []
        for node in self.nodes.values():
            if failed_node_id in node.dependencies:
                if node.status not in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED):
                    blocked.append(node)
        return blocked

    def mark_node_completed(self, node_id: str, result: TaskResult) -> None:
        """Mark a node as completed with result (no-op for unknown ids)."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_completed(result)
            self.updated_at = datetime.utcnow()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Mark a node as failed and block its READY dependents."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_failed(error)
            self._propagate_failure(node_id)
            self.updated_at = datetime.utcnow()

    def _propagate_failure(self, failed_node_id: str) -> None:
        """Move direct READY dependents of a failed node to BLOCKED."""
        for node in self.nodes.values():
            if failed_node_id in node.dependencies:
                if node.status == TaskNodeStatus.READY:
                    node.status = TaskNodeStatus.BLOCKED

    @property
    def completed_count(self) -> int:
        """Count of completed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED)

    @property
    def failed_count(self) -> int:
        """Count of failed nodes"""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED)

    @property
    def total_count(self) -> int:
        """Total number of nodes"""
        return len(self.nodes)

    @property
    def progress(self) -> float:
        """Overall DAG progress (0.0 - 1.0): mean of per-node progress"""
        if not self.nodes:
            return 0.0
        return sum(n.progress for n in self.nodes.values()) / len(self.nodes)

    @property
    def is_complete(self) -> bool:
        """Check if DAG execution is complete (all nodes in a terminal state)"""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
            for n in self.nodes.values()
        )

    @property
    def is_success(self) -> bool:
        """Check if DAG completed successfully (complete with zero failures)"""
        return self.is_complete and self.failed_count == 0

    def get_execution_order(self) -> List[List[TaskNode]]:
        """
        Get nodes grouped by execution level (parallel-friendly order).

        Returns:
            List of node groups; each group's dependencies are fully
            satisfied by earlier groups, so a group can run in parallel.
        """
        levels: List[List[TaskNode]] = []
        completed: Set[str] = set()

        while len(completed) < len(self.nodes):
            # A node joins the current level once every dependency is placed.
            # We check dependencies directly rather than via
            # TaskNode.can_execute: can_execute also requires READY status,
            # and fresh nodes are PENDING, so using it here would terminate
            # immediately and return an empty ordering.
            current_level = [
                node for node in self.nodes.values()
                if node.id not in completed
                and all(dep in completed for dep in node.dependencies)
            ]

            if not current_level:
                break  # Circular dependency — remaining nodes can never run

            levels.append(current_level)
            completed.update(n.id for n in current_level)

        return levels

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"from": e[0], "to": e[1]} for e in self.edges],
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "completed_count": self.completed_count,
            "total_count": self.total_count,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """
        Create a DAG from a dictionary produced by to_dict.

        Args:
            data: Dictionary with DAG data
        """
        dag = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            name=data.get("name", ""),
            description=data.get("description", ""),
        )

        # Recreate nodes (statuses/results are not restored — fresh run state)
        for node_data in data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {}),
                dependencies=node_data.get("dependencies", []),
            )
            dag.add_node(node)

        # Recreate edges (add_edge tolerates pre-populated dependency lists)
        for edge in data.get("edges", []):
            dag.add_edge(edge["from"], edge["to"])

        return dag
|
||||||
|
|
@ -0,0 +1,360 @@
|
||||||
|
"""DAG Scheduler - orchestrates parallel task execution"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||||
|
from luxx.agents.supervisor import SupervisorAgent
|
||||||
|
from luxx.agents.worker import WorkerAgent
|
||||||
|
from luxx.tools.executor import ToolExecutor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DAGScheduler:
    """
    DAG Scheduler

    Orchestrates parallel execution of tasks based on DAG structure.
    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: Optional[ToolExecutor] = None
    ):
        """
        Initialize DAG Scheduler

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a default is created
                when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()

        # Execution state
        self._running_nodes: Set[str] = set()
        self._completed_nodes: Set[str] = set()
        self._node_results: Dict[str, TaskResult] = {}
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}

        # Strong references to in-flight asyncio tasks. The event loop only
        # keeps weak references, so without this set a node task could be
        # garbage-collected mid-execution.
        self._tasks: Set[asyncio.Task] = set()

        # Control flags
        self._cancelled = False
        self._cancel_lock = threading.Lock()

        # Progress callback
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates"""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Emit a progress update through the registered callback, if any"""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description

        Returns:
            Execution results
        """
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()

        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")

        # Mark root nodes as ready; downstream nodes are promoted to READY
        # by _unblock_nodes() as their dependencies complete.
        for node in self.dag.root_nodes:
            node.mark_ready()

        # Main execution loop
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()

            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()

            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break

            # Launch ready nodes, respecting max_workers. The slot is
            # reserved (id added to _running_nodes) BEFORE scheduling the
            # task: the previous per-candidate length check never changed
            # during the comprehension, so a burst could exceed the limit
            # and a not-yet-started node could be launched twice.
            free_slots = self.max_workers - len(self._running_nodes)
            for node in ready_nodes[:max(0, free_slots)]:
                self._running_nodes.add(node.id)
                node_task = asyncio.create_task(self._execute_node(node, context))
                self._tasks.add(node_task)
                node_task.add_done_callback(self._tasks.discard)

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)

        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")

        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """Check if every node has reached a terminal state"""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Check if execution was cancelled (thread-safe)"""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cancellation; the main loop reacts on its next pass"""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Get READY nodes that are not already running and whose
        dependencies are all completed"""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node

        Args:
            node: Node to execute
            context: Execution context
        """
        if self._is_cancelled():
            node.mark_cancelled()
            self._running_nodes.discard(node.id)  # release slot reserved by execute()
            return

        # Mark node as running (id may already be reserved by execute())
        self._running_nodes.add(node.id)
        node.mark_running(self.supervisor.agent.id)

        self._emit_progress(node.id, 0.0, f"Starting: {node.name}")

        # Collect parent outputs for this node
        parent_outputs = {}
        for dep_id in node.dependencies:
            if dep_id in self._node_results:
                parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}

        # Create worker for this task
        worker = self.worker_factory()

        # Define progress callback for this node
        def node_progress(progress: float, message: str = ""):
            node.update_progress(progress, message)
            self._emit_progress(node.id, progress, message)

        try:
            # Execute task
            result = await worker.execute_task(
                node,
                context,
                parent_outputs=parent_outputs,
                progress_callback=node_progress
            )

            # Store result
            self._node_results[node.id] = result

            if result.success:
                node.mark_completed(result)
                self._completed_nodes.add(node.id)
                self._emit_progress(node.id, 1.0, f"Completed: {node.name}")

                # Promote nodes whose dependencies are now all satisfied
                self._unblock_nodes()
            else:
                node.mark_failed(result.error or "Unknown error")
                self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")

                # Block dependent nodes
                self._block_dependent_nodes(node.id)

        except Exception as e:
            logger.error(f"Node {node.id} execution error: {e}")
            node.mark_failed(str(e))
            self._node_results[node.id] = TaskResult.fail(error=str(e))
            self._block_dependent_nodes(node.id)

        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Promote waiting nodes whose dependencies are now all completed.

        Covers both PENDING nodes (only root nodes are marked READY up
        front) and BLOCKED nodes. The previous implementation used
        TaskNode.can_execute, which itself requires READY status, so it
        never promoted anything and every multi-level DAG deadlocked after
        its root nodes finished.
        """
        for node in self.dag.nodes.values():
            if node.status in (TaskNodeStatus.PENDING, TaskNodeStatus.BLOCKED):
                if all(dep in self._completed_nodes for dep in node.dependencies):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block nodes that depend on a failed node"""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Mark all currently running nodes as cancelled"""
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

    def _build_result(self) -> Dict[str, Any]:
        """Build the final execution summary"""
        return {
            "success": self.dag.is_success,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build the summary returned when execution was cancelled"""
        return {
            "success": False,
            "cancelled": True,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerPool:
    """Thread-safe pool of DAG schedulers.

    Tracks at most ``max_concurrent`` :class:`DAGScheduler` instances, keyed
    by task id, for managing multiple concurrent DAG executions.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create and register a new scheduler

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            Created scheduler

        Raises:
            RuntimeError: when the pool is already at capacity.
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")

            created = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers,
            )
            self._schedulers[task_id] = created
        return created

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Return the scheduler registered under *task_id*, or None."""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Drop a scheduler from the pool; True when one was removed."""
        with self._lock:
            return self._schedulers.pop(task_id, None) is not None

    def cancel(self, task_id: str) -> bool:
        """Cancel the scheduler for *task_id*; True when one existed."""
        target = self.get(task_id)
        if target is None:
            return False
        target.cancel()
        return True

    @property
    def active_count(self) -> int:
        """Number of active schedulers"""
        with self._lock:
            return len(self._schedulers)
|
||||||
|
|
@ -0,0 +1,266 @@
|
||||||
|
"""Agent Registry - manages agent lifecycle and access"""
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
    """
    Agent Registry (Singleton)

    Thread-safe registry for managing all agents in the system.
    Provides agent creation, retrieval, and lifecycle management.

    ``_user_agents`` and ``_conversation_agents`` are secondary indexes into
    ``_agents``; every mutation keeps them consistent under ``_registry_lock``.
    """
    _instance: Optional["AgentRegistry"] = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking so concurrent first calls still yield
        # exactly one instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentRegistry() call; only populate state once.
        if self._initialized:
            return
        self._agents: Dict[str, Agent] = {}
        self._user_agents: Dict[int, Set[str]] = {}  # user_id -> set of agent_ids
        self._conversation_agents: Dict[str, Set[str]] = {}  # conversation_id -> set of agent_ids
        self._registry_lock = threading.Lock()
        self._initialized = True

    def _generate_agent_id(self) -> str:
        """Generate unique agent ID"""
        return f"agent_{uuid.uuid4().hex[:12]}"

    @staticmethod
    def _untrack(index: Dict, key, agent_id: str) -> None:
        """Remove *agent_id* from one index bucket, pruning empty buckets.

        Caller must hold ``_registry_lock``.
        """
        bucket = index.get(key)
        if bucket is not None:
            bucket.discard(agent_id)
            if not bucket:
                del index[key]

    def create_agent(
        self,
        config: AgentConfig,
        user_id: Optional[int] = None,
        conversation_id: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Agent:
        """
        Create and register a new agent

        Args:
            config: Agent configuration
            user_id: Associated user ID
            conversation_id: Associated conversation ID
            workspace: Agent's workspace path

        Returns:
            Created Agent instance
        """
        with self._registry_lock:
            agent_id = self._generate_agent_id()
            agent = Agent(
                id=agent_id,
                config=config,
                user_id=user_id,
                conversation_id=conversation_id,
                workspace=workspace,
            )
            self._agents[agent_id] = agent

            # BUG FIX: was `if user_id:` — a falsy but valid id (user 0)
            # silently skipped tracking. Compare against None explicitly.
            if user_id is not None:
                self._user_agents.setdefault(user_id, set()).add(agent_id)
            if conversation_id is not None:
                self._conversation_agents.setdefault(conversation_id, set()).add(agent_id)

            return agent

    def get(self, agent_id: str) -> Optional[Agent]:
        """Get agent by ID; returns None when not registered."""
        with self._registry_lock:
            return self._agents.get(agent_id)

    def list_user_agents(self, user_id: int) -> List[Agent]:
        """List all agents belonging to *user_id*."""
        with self._registry_lock:
            agent_ids = self._user_agents.get(user_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_conversation_agents(self, conversation_id: str) -> List[Agent]:
        """List all agents attached to *conversation_id*."""
        with self._registry_lock:
            agent_ids = self._conversation_agents.get(conversation_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_by_type(self, agent_type: AgentType) -> List[Agent]:
        """List all agents of a specific type."""
        with self._registry_lock:
            return [
                a for a in self._agents.values()
                if a.config.agent_type == agent_type
            ]

    def list_by_status(self, status: AgentStatus) -> List[Agent]:
        """List all agents currently in *status*."""
        with self._registry_lock:
            return [a for a in self._agents.values() if a.status == status]

    def update_status(self, agent_id: str, status: AgentStatus) -> bool:
        """Set an agent's status; returns False when the agent is unknown."""
        with self._registry_lock:
            agent = self._agents.get(agent_id)
            if agent is None:
                return False
            agent.status = status
            return True

    def remove(self, agent_id: str) -> bool:
        """
        Remove agent from registry, including its index entries.

        Args:
            agent_id: Agent ID

        Returns:
            True if removed, False if not found
        """
        with self._registry_lock:
            agent = self._agents.pop(agent_id, None)
            if agent is None:
                return False
            # BUG FIX: was `if agent.user_id and ...`, which skipped
            # untracking for user_id == 0 and left empty buckets behind.
            if agent.user_id is not None:
                self._untrack(self._user_agents, agent.user_id, agent_id)
            if agent.conversation_id is not None:
                self._untrack(self._conversation_agents, agent.conversation_id, agent_id)
            return True

    def remove_user_agents(self, user_id: int) -> int:
        """
        Remove all agents for a user.

        Args:
            user_id: User ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._user_agents.pop(user_id, set())
            count = 0
            for agent_id in agent_ids:
                agent = self._agents.pop(agent_id, None)
                if agent is None:
                    continue
                count += 1
                # BUG FIX: the original left the removed agent's id dangling
                # in the conversation index; purge it here.
                if agent.conversation_id is not None:
                    self._untrack(self._conversation_agents, agent.conversation_id, agent_id)
            return count

    def remove_conversation_agents(self, conversation_id: str) -> int:
        """
        Remove all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.pop(conversation_id, set())
            count = 0
            for agent_id in agent_ids:
                agent = self._agents.pop(agent_id, None)
                if agent is None:
                    continue
                count += 1
                # BUG FIX: mirror cleanup of the user index (was left stale).
                if agent.user_id is not None:
                    self._untrack(self._user_agents, agent.user_id, agent_id)
            return count

    def clear(self) -> None:
        """Clear all agents from registry"""
        with self._registry_lock:
            self._agents.clear()
            self._user_agents.clear()
            self._conversation_agents.clear()

    @property
    def agent_count(self) -> int:
        """Total number of agents"""
        with self._registry_lock:
            return len(self._agents)

    def get_stats(self) -> Dict:
        """Return agent counts, grouped by status value and by type value."""
        with self._registry_lock:
            status_counts: Dict[str, int] = {}
            type_counts: Dict[str, int] = {}
            for agent in self._agents.values():
                status_counts[agent.status.value] = status_counts.get(agent.status.value, 0) + 1
                type_counts[agent.config.agent_type.value] = type_counts.get(agent.config.agent_type.value, 0) + 1

            return {
                "total": len(self._agents),
                "by_status": status_counts,
                "by_type": type_counts,
            }
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
registry = AgentRegistry()
|
||||||
|
|
@ -0,0 +1,345 @@
|
||||||
|
"""Supervisor Agent - task decomposition and result integration"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Callable
|
||||||
|
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||||
|
from luxx.services.llm_client import llm_client
|
||||||
|
from luxx.tools.core import registry as tool_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SupervisorAgent:
    """
    Supervisor Agent

    Responsible for:
    - Task decomposition using LLM
    - Generating DAG (task graph)
    - Result integration from workers
    """

    # System prompt for task decomposition
    DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.

Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration

Output format for task decomposition:
{
    "task_name": "Overall task name",
    "task_description": "Description of what needs to be accomplished",
    "nodes": [
        {
            "id": "task_001",
            "name": "Task name",
            "description": "What this task does",
            "task_type": "code|shell|file|llm|generic",
            "task_data": {...}, # Task-specific parameters
            "dependencies": [] # IDs of tasks that must complete first
        }
    ]
}

Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        max_subtasks: int = 10
    ):
        """
        Initialize Supervisor Agent

        Args:
            agent: Agent instance (should be SUPERVISOR type)
            llm_client: LLM client instance; defaults to the shared module client
            max_subtasks: Maximum number of subtasks to generate
        """
        self.agent = agent
        # BUG FIX: the `llm_client` parameter shadows the module-level import,
        # so the original `llm_client or llm_client` could never fall back to
        # the shared client — passing None left self.llm_client as None.
        if llm_client is None:
            from luxx.services.llm_client import llm_client as default_llm_client
            llm_client = default_llm_client
        self.llm_client = llm_client
        self.max_subtasks = max_subtasks

        # Ensure agent has supervisor system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def decompose_task(
        self,
        task: str,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> DAG:
        """
        Decompose a task into subtasks using LLM

        Args:
            task: User's task description
            context: Execution context (workspace, user info, etc.)
                NOTE(review): currently unused by this method — confirm whether
                it should be folded into the prompt.
            progress_callback: Optional callback for progress updates

        Returns:
            DAG representing the task decomposition

        Raises:
            Exception: re-raises any LLM-call failure after marking the
            agent FAILED.
        """
        self.agent.status = AgentStatus.PLANNING

        if progress_callback:
            progress_callback(0.1, "Analyzing task...")

        # Build messages for LLM from the agent's sliding-window context.
        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"Decompose this task into subtasks:\n{task}"
        })

        if progress_callback:
            progress_callback(0.2, "Calling LLM for task decomposition...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            if progress_callback:
                progress_callback(0.5, "Processing decomposition...")

            # Parse LLM response to extract DAG
            dag = self._parse_dag_from_response(response.content, task)

            # Record the assistant turn in the agent's context window.
            self.agent.add_message("assistant", response.content)

            if progress_callback:
                progress_callback(0.9, "Task decomposition complete")

            self.agent.status = AgentStatus.IDLE
            return dag

        except Exception as e:
            logger.error(f"Task decomposition failed: {e}")
            self.agent.status = AgentStatus.FAILED
            raise

    def _parse_dag_from_response(self, content: str, original_task: str) -> DAG:
        """
        Parse LLM response to extract DAG structure

        Args:
            content: LLM response content
            original_task: Original task description

        Returns:
            DAG instance (falls back to a single-node LLM task when no JSON
            plan can be extracted from the response)
        """
        dag_data = self._extract_json(content)

        if not dag_data:
            # Fallback: treat the whole task as one LLM node rather than failing.
            logger.warning("Could not parse DAG from LLM response, creating simple DAG")
            dag = DAG(
                id=f"dag_{self.agent.id}",
                name=original_task[:50],
                description=original_task
            )
            node = TaskNode(
                id="task_001",
                name="Execute Task",
                description=original_task,
                task_type="llm",
                task_data={"prompt": original_task}
            )
            dag.add_node(node)
            return dag

        # Build DAG from parsed data
        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=dag_data.get("task_name", original_task[:50]),
            description=dag_data.get("task_description", original_task)
        )

        # Add all nodes first so edges can reference any of them.
        for node_data in dag_data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {})
            )
            dag.add_node(node)

        # Add edges based on dependencies; unknown dependency ids are skipped.
        for node_data in dag_data.get("nodes", []):
            node_id = node_data["id"]
            for dep_id in node_data.get("dependencies", []):
                if dep_id in dag.nodes:
                    dag.add_edge(dep_id, node_id)

        return dag

    def _extract_json(self, content: str) -> Optional[Dict]:
        """
        Extract JSON from LLM response

        Tries a fenced ```json block first, then any raw {...} span.

        Args:
            content: Raw LLM response

        Returns:
            Parsed JSON dict or None
        """
        import re

        # Look for ```json ... ``` blocks
        json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Look for raw JSON object
        json_match = re.search(r"\{.*\}", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        return None

    def integrate_results(self, dag: DAG) -> Dict[str, Any]:
        """
        Integrate results from all completed tasks

        Args:
            dag: DAG with completed tasks

        Returns:
            Integrated result summary (always reports success=True; callers
            should inspect failed_tasks for partial failures)
        """
        self.agent.status = AgentStatus.EXECUTING

        # Collect output data from every completed node that produced any.
        results = {}
        for node in dag.nodes.values():
            if node.status == TaskNodeStatus.COMPLETED and node.output_data:
                results[node.id] = node.output_data

        # Store aggregated results on the agent for later retrieval.
        self.agent.accumulated_result = results
        self.agent.status = AgentStatus.COMPLETED

        return {
            "success": True,
            "dag_id": dag.id,
            "total_tasks": dag.total_count,
            "completed_tasks": dag.completed_count,
            "failed_tasks": dag.failed_count,
            "results": results
        }

    async def review_and_refine(
        self,
        dag: DAG,
        task: str,
        progress_callback: Optional[Callable] = None
    ) -> Optional[DAG]:
        """
        Review DAG execution and refine if needed

        Args:
            dag: Current DAG state
            task: Original task
            progress_callback: Progress callback (currently unused here)

        Returns:
            Refined DAG or None if no refinement needed / refinement failed
        """
        if dag.is_success:
            return None  # No refinement needed

        # Check if there are failed tasks
        failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]

        if not failed_nodes:
            return None

        # Build context for refinement
        context = {
            "original_task": task,
            "failed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "error": n.result.error if n.result else "Unknown error"
                }
                for n in failed_nodes
            ],
            "completed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "output": n.output_data
                }
                for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
            ]
        }

        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"""Review the task execution and suggest refinements:

Task: {task}

Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}

Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}

If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical

Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
        })

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature
            )

            # Check if refinement was suggested
            refined_dag = self._parse_dag_from_response(response.content, task)

            # Only return if we got a valid refinement
            if refined_dag and refined_dag.nodes:
                return refined_dag

        except Exception as e:
            logger.error(f"DAG refinement failed: {e}")

        return None
|
||||||
|
|
@ -0,0 +1,401 @@
|
||||||
|
"""Worker Agent - executes specific tasks"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Callable
|
||||||
|
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
from luxx.agents.dag import TaskNode, TaskNodeStatus, TaskResult
|
||||||
|
from luxx.services.llm_client import llm_client
|
||||||
|
from luxx.tools.core import registry as tool_registry, ToolContext, CommandPermission
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerAgent:
|
||||||
|
"""
|
||||||
|
Worker Agent
|
||||||
|
|
||||||
|
Responsible for executing specific tasks using:
|
||||||
|
- LLM calls for reasoning tasks
|
||||||
|
- Tool execution for actionable tasks
|
||||||
|
|
||||||
|
Follows sliding window context management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# System prompt for worker tasks
|
||||||
|
DEFAULT_SYSTEM_PROMPT = """You are a Worker Agent that executes specific tasks efficiently.
|
||||||
|
|
||||||
|
Your responsibilities:
|
||||||
|
1. Execute tasks assigned to you by the Supervisor
|
||||||
|
2. Use appropriate tools when needed
|
||||||
|
3. Report results clearly with structured output_data for dependent tasks
|
||||||
|
4. Be concise and focused on the task at hand
|
||||||
|
|
||||||
|
Output format:
|
||||||
|
- Provide clear, structured results
|
||||||
|
- Include output_data for any data that dependent tasks might need
|
||||||
|
- If a tool fails, explain the error clearly
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent: Agent,
|
||||||
|
llm_client=None,
|
||||||
|
tool_executor=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Worker Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: Agent instance (should be WORKER type)
|
||||||
|
llm_client: LLM client instance
|
||||||
|
tool_executor: Tool executor instance
|
||||||
|
"""
|
||||||
|
self.agent = agent
|
||||||
|
self.llm_client = llm_client or llm_client
|
||||||
|
self.tool_executor = tool_executor
|
||||||
|
|
||||||
|
# Ensure agent has worker system prompt
|
||||||
|
if not self.agent.config.system_prompt:
|
||||||
|
self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
async def execute_task(
|
||||||
|
self,
|
||||||
|
task_node: TaskNode,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
parent_outputs: Dict[str, Dict[str, Any]] = None,
|
||||||
|
progress_callback: Optional[Callable] = None
|
||||||
|
) -> TaskResult:
|
||||||
|
"""
|
||||||
|
Execute a task node
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_node: Task node to execute
|
||||||
|
context: Execution context (workspace, user info, etc.)
|
||||||
|
parent_outputs: Output data from parent tasks (dependency results)
|
||||||
|
progress_callback: Optional callback for progress updates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TaskResult with execution outcome
|
||||||
|
"""
|
||||||
|
self.agent.status = AgentStatus.EXECUTING
|
||||||
|
self.agent.current_task_id = task_node.id
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.0, f"Starting task: {task_node.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Merge parent outputs into context
|
||||||
|
execution_context = self._prepare_context(context, parent_outputs)
|
||||||
|
|
||||||
|
# Execute based on task type
|
||||||
|
if task_node.task_type == "llm":
|
||||||
|
result = await self._execute_llm_task(task_node, execution_context, progress_callback)
|
||||||
|
elif task_node.task_type == "code":
|
||||||
|
result = await self._execute_code_task(task_node, execution_context, progress_callback)
|
||||||
|
elif task_node.task_type == "shell":
|
||||||
|
result = await self._execute_shell_task(task_node, execution_context, progress_callback)
|
||||||
|
elif task_node.task_type == "file":
|
||||||
|
result = await self._execute_file_task(task_node, execution_context, progress_callback)
|
||||||
|
else:
|
||||||
|
result = await self._execute_generic_task(task_node, execution_context, progress_callback)
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
result.execution_time = execution_time
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(1.0, f"Task complete: {task_node.name}")
|
||||||
|
|
||||||
|
self.agent.status = AgentStatus.IDLE
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task execution failed: {e}")
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
self.agent.status = AgentStatus.FAILED
|
||||||
|
return TaskResult.fail(error=str(e))
|
||||||
|
|
||||||
|
def _prepare_context(
|
||||||
|
self,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
parent_outputs: Dict[str, Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Prepare execution context by merging parent outputs
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Base context
|
||||||
|
parent_outputs: Output from parent tasks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged context
|
||||||
|
"""
|
||||||
|
execution_context = context.copy()
|
||||||
|
|
||||||
|
if parent_outputs:
|
||||||
|
# Merge parent outputs into context
|
||||||
|
merged = {}
|
||||||
|
for parent_id, outputs in parent_outputs.items():
|
||||||
|
merged.update(outputs)
|
||||||
|
execution_context["parent_outputs"] = parent_outputs
|
||||||
|
execution_context["merged_data"] = merged
|
||||||
|
|
||||||
|
# Add user permission level
|
||||||
|
if "user_permission_level" not in execution_context:
|
||||||
|
execution_context["user_permission_level"] = self.agent.effective_permission.value
|
||||||
|
|
||||||
|
return execution_context
|
||||||
|
|
||||||
|
async def _execute_llm_task(
|
||||||
|
self,
|
||||||
|
task_node: TaskNode,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
progress_callback: Optional[Callable] = None
|
||||||
|
) -> TaskResult:
|
||||||
|
"""Execute LLM reasoning task"""
|
||||||
|
task_data = task_node.task_data
|
||||||
|
|
||||||
|
# Build prompt
|
||||||
|
prompt = task_data.get("prompt", task_node.description)
|
||||||
|
system_prompt = task_data.get("system", self.agent.config.system_prompt)
|
||||||
|
|
||||||
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
|
# Add parent data if available
|
||||||
|
if "merged_data" in context:
|
||||||
|
merged = context["merged_data"]
|
||||||
|
context_info = "\n".join([f"{k}: {v}" for k, v in merged.items()])
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": f"Context from dependent tasks:\n{context_info}"
|
||||||
|
})
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(0.3, "Calling LLM...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.llm_client.sync_call(
|
||||||
|
model=self.agent.config.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.agent.config.temperature,
|
||||||
|
max_tokens=self.agent.config.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
return TaskResult.ok(
|
||||||
|
data=response.content,
|
||||||
|
output_data=task_node.task_data.get("output_template", {}).copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return TaskResult.fail(error=str(e))
|
||||||
|
|
||||||
|
async def _execute_code_task(
    self,
    task_node: TaskNode,
    context: Dict[str, Any],
    progress_callback: Optional[Callable] = None
) -> TaskResult:
    """Generate source code for the node with a single LLM call.

    Reads ``prompt``, ``language`` and ``requirements`` from the node's
    ``task_data`` and asks the model for an implementation.  A fixed low
    temperature is used so code output stays close to deterministic.

    Args:
        task_node: Node whose ``task_data`` describes the coding task.
        context: Execution context (unused here; kept for a uniform
            executor signature).
        progress_callback: Optional ``(fraction, message)`` hook.

    Returns:
        TaskResult: ok with the generated code (also mirrored under
        ``output_data["code"]``), fail on any LLM error.
    """
    spec = task_node.task_data

    goal = spec.get("prompt", task_node.description)
    language = spec.get("language", "python")
    requirements = spec.get("requirements", "")

    conversation = [
        {"role": "system", "content": f"You are a {language} programmer. Write clean, efficient code."},
        {"role": "user", "content": f"Task: {goal}\n\nRequirements: {requirements}"}
    ]

    if progress_callback:
        progress_callback(0.3, "Generating code...")

    try:
        reply = await self.llm_client.sync_call(
            model=self.agent.config.model,
            messages=conversation,
            temperature=0.2,  # deliberately low for code generation
            max_tokens=4096
        )
        return TaskResult.ok(
            data=reply.content,
            output_data={
                "code": reply.content,
                "language": language
            }
        )
    except Exception as exc:
        return TaskResult.fail(error=str(exc))
|
||||||
|
|
||||||
|
async def _execute_shell_task(
    self,
    task_node: TaskNode,
    context: Dict[str, Any],
    progress_callback: Optional[Callable] = None
) -> TaskResult:
    """Execute a shell command task via the ``shell_exec`` tool.

    Args:
        task_node: Node whose ``task_data["command"]`` holds the command.
        context: Execution context; ``workspace``, ``user_id``, ``username``
            and ``user_permission_level`` are forwarded to the tool context.
        progress_callback: Optional ``(fraction, message)`` hook.

    Returns:
        TaskResult: ok with the command output, fail when no command is
        given, the tool reports failure, or execution raises.
    """
    task_data = task_node.task_data

    command = task_data.get("command")
    if not command:
        return TaskResult.fail(error="No command specified")

    if progress_callback:
        progress_callback(0.3, f"Executing: {command[:50]}...")

    # Build tool context carrying identity and permission information.
    tool_ctx = ToolContext(
        workspace=context.get("workspace"),
        user_id=context.get("user_id"),
        username=context.get("username"),
        extra={"user_permission_level": context.get("user_permission_level", 1)}
    )

    try:
        # Execute shell command via the registered tool
        result = tool_registry.execute(
            "shell_exec",
            {"command": command},
            context=tool_ctx
        )

        if result.get("success"):
            # Robustness fix: "data" may be present but None, in which case
            # the original `result.get("data", {})` returned None and the
            # chained .get() raised AttributeError.  Also hoists the output
            # extraction so it is computed once.
            output = (result.get("data") or {}).get("output", "")
            return TaskResult.ok(
                data=output,
                output_data={"output": output}
            )
        else:
            return TaskResult.fail(error=result.get("error", "Shell execution failed"))
    except Exception as e:
        return TaskResult.fail(error=str(e))
|
||||||
|
|
||||||
|
async def _execute_file_task(
    self,
    task_node: TaskNode,
    context: Dict[str, Any],
    progress_callback: Optional[Callable] = None
) -> TaskResult:
    """Perform a file operation through the ``file_*`` tool family.

    The ``operation`` value in ``task_data`` selects the concrete tool
    (e.g. ``"write"`` -> ``file_write``); ``path`` and optional
    ``content`` are passed through as tool parameters.

    Args:
        task_node: Node supplying ``operation``, ``path`` and ``content``.
        context: Execution context; identity and permission fields are
            forwarded to the tool context.
        progress_callback: Optional ``(fraction, message)`` hook.

    Returns:
        TaskResult: ok with the tool's data payload, fail on missing
        parameters, tool failure, or an exception.
    """
    spec = task_node.task_data

    operation = spec.get("operation")
    file_path = spec.get("path")
    content = spec.get("content", "")

    # Guard clause: both pieces are mandatory to pick a tool and a target.
    if not operation or not file_path:
        return TaskResult.fail(error="Missing operation or path")

    if progress_callback:
        progress_callback(0.3, f"File operation: {operation} {file_path}")

    tool_ctx = ToolContext(
        workspace=context.get("workspace"),
        user_id=context.get("user_id"),
        username=context.get("username"),
        extra={"user_permission_level": context.get("user_permission_level", 1)}
    )

    try:
        # The operation name selects the concrete tool, e.g. "file_write".
        outcome = tool_registry.execute(
            f"file_{operation}",
            {"path": file_path, "content": content},
            context=tool_ctx
        )

        if outcome.get("success"):
            return TaskResult.ok(
                data=outcome.get("data"),
                output_data={"path": file_path, "operation": operation}
            )
        return TaskResult.fail(error=outcome.get("error", "File operation failed"))
    except Exception as exc:
        return TaskResult.fail(error=str(exc))
|
||||||
|
|
||||||
|
async def _execute_generic_task(
    self,
    task_node: TaskNode,
    context: Dict[str, Any],
    progress_callback: Optional[Callable] = None
) -> TaskResult:
    """Execute generic task using LLM with tools.

    Runs an agentic loop: call the LLM, execute any tool calls it issues,
    append the tool results to the conversation, and repeat until the model
    answers without tool calls or ``max_iterations`` rounds are exhausted.

    Args:
        task_node: Node whose ``task_data`` supplies ``prompt``, ``tools``
            (tool names) and an optional ``output_template``.
        context: Execution context forwarded to the tool executor.
        progress_callback: Optional ``(fraction, message)`` progress hook.

    Returns:
        TaskResult: ok with the final model reply; fail on any exception
        or when the iteration budget is exceeded.
    """
    task_data = task_node.task_data

    # Build prompt (task_data wins; node description is the fallback)
    prompt = task_data.get("prompt", task_node.description)
    tools = task_data.get("tools", [])

    messages = [
        {"role": "system", "content": self.agent.config.system_prompt},
        {"role": "user", "content": prompt}
    ]

    # Get tool definitions if specified
    # NOTE(review): tool_registry.get(t) is looked up twice per tool name
    # here; a single lookup per name would avoid the duplicate call.
    tool_defs = None
    if tools:
        tool_defs = [tool_registry.get(t).to_openai_format() for t in tools if tool_registry.get(t)]

    if progress_callback:
        progress_callback(0.2, "Processing task...")

    # Cap on LLM round-trips so a tool-calling loop cannot run forever.
    max_iterations = 5
    iteration = 0

    while iteration < max_iterations:
        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                tools=tool_defs,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            # Add assistant response
            # NOTE(review): only ``content`` is recorded here.  OpenAI-style
            # chat APIs require the assistant message to carry a
            # ``tool_calls`` field before "role": "tool" messages that
            # reference those call ids are appended — confirm that
            # llm_client.sync_call tolerates this shape.
            messages.append({"role": "assistant", "content": response.content})

            # Check for tool calls
            if response.tool_calls:
                if progress_callback:
                    progress_callback(0.5, f"Executing {len(response.tool_calls)} tools...")

                # Execute tools
                tool_results = self.tool_executor.process_tool_calls(
                    response.tool_calls,
                    context
                )

                # Add tool results (one "tool" message per tool_call_id)
                for tr in tool_results:
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tr["tool_call_id"],
                        "content": tr["content"]
                    })

                if progress_callback:
                    progress_callback(0.8, "Tools executed")
            else:
                # No tool calls, task complete
                return TaskResult.ok(
                    data=response.content,
                    output_data=task_data.get("output_template", {})
                )

        except Exception as e:
            return TaskResult.fail(error=str(e))

        iteration += 1

    return TaskResult.fail(error="Max iterations exceeded")
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from luxx.routes import auth, conversations, messages, tools, providers
|
from luxx.routes import auth, conversations, messages, tools, providers
|
||||||
|
from luxx.routes.agents_ws import router as agents_ws_router
|
||||||
|
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
@ -12,3 +13,4 @@ api_router.include_router(conversations.router)
|
||||||
api_router.include_router(messages.router)
|
api_router.include_router(messages.router)
|
||||||
api_router.include_router(tools.router)
|
api_router.include_router(tools.router)
|
||||||
api_router.include_router(providers.router)
|
api_router.include_router(providers.router)
|
||||||
|
api_router.include_router(agents_ws_router)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,385 @@
|
||||||
|
"""WebSocket routes for agent real-time communication"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, Optional, Set
|
||||||
|
from datetime import datetime
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||||
|
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||||
|
from luxx.agents.supervisor import SupervisorAgent
|
||||||
|
from luxx.agents.worker import WorkerAgent
|
||||||
|
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
|
||||||
|
from luxx.agents.registry import AgentRegistry
|
||||||
|
from luxx.services.llm_client import llm_client
|
||||||
|
from luxx.tools.executor import ToolExecutor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
    """
    WebSocket Connection Manager

    Manages WebSocket connections for real-time agent progress updates.

    Features:
    - Connection tracking by task_id
    - Heartbeat mechanism
    - Progress broadcasting

    NOTE: registration assumes the WebSocket handshake has already been
    completed by the caller (``await websocket.accept()``); see connect().
    """

    def __init__(self):
        # task_id -> set of websocket connections subscribed to that task
        self._connections: Dict[str, Set[WebSocket]] = {}
        # websocket -> task_id mapping (reverse index used by disconnect)
        self._ws_to_task: Dict[WebSocket, str] = {}
        # per-socket heartbeat tasks, cancelled on disconnect
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # lock guarding all three maps
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = 30

    def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Register an already-accepted WebSocket connection.

        Bug fix: the previous version called ``websocket.accept()`` here
        without awaiting it — ``accept()`` is a coroutine and this method is
        synchronous, so the handshake was never actually completed.
        Accepting is now the caller's responsibility
        (``await websocket.accept()``) before registering.

        Args:
            websocket: WebSocket connection (handshake already completed)
            task_id: Task ID to subscribe to
        """
        with self._lock:
            if task_id not in self._connections:
                self._connections[task_id] = set()
            self._connections[task_id].add(websocket)
            self._ws_to_task[websocket] = task_id

        logger.info(f"WebSocket connected for task {task_id}")

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and stop its heartbeat.

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id and task_id in self._connections:
                self._connections[task_id].discard(websocket)
                # Drop empty task buckets so the map does not grow forever
                if not self._connections[task_id]:
                    del self._connections[task_id]

            # Cancel heartbeat task, if one was started for this socket
            if websocket in self._heartbeat_tasks:
                self._heartbeat_tasks[websocket].cancel()
                del self._heartbeat_tasks[websocket]

        if task_id:
            logger.info(f"WebSocket disconnected for task {task_id}")

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send message to all connections subscribed to a task.

        Connections whose send fails are pruned afterwards.

        Args:
            task_id: Task ID
            message: Message to send (JSON-serializable dict)
        """
        # Copy the set under the lock so we never await while holding it
        with self._lock:
            connections = self._connections.get(task_id, set()).copy()

        if not connections:
            return

        dead_connections = set()

        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket: {e}")
                dead_connections.add(websocket)

        # Clean up dead connections
        for ws in dead_connections:
            self.disconnect(ws)

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast message to all connections.

        Failures are logged and ignored (best-effort delivery).

        Args:
            message: Message to broadcast
        """
        with self._lock:
            all_connections = list(self._ws_to_task.keys())

        for websocket in all_connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to broadcast: {e}")

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send message to a specific connection; failures are logged, not raised.

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"Failed to send personal message: {e}")

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a periodic heartbeat loop for a connection.

        The loop exits on the first failed send (i.e. when the peer has gone
        away).  Must be called from within a running event loop, since it
        schedules the loop with ``asyncio.create_task``.

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop():
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval
                    })
                except Exception:
                    break

        task = asyncio.create_task(heartbeat_loop())
        # The strong reference kept in this dict prevents the task from being
        # garbage-collected before disconnect() cancels it.
        with self._lock:
            self._heartbeat_tasks[websocket] = task

    @property
    def connection_count(self) -> int:
        """Total number of connections"""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Get number of connections for a task"""
        with self._lock:
            return len(self._connections.get(task_id, set()))
|
||||||
|
|
||||||
|
|
||||||
|
# Global connection manager
# Module-level singleton tracking WebSocket subscribers per task_id.
connection_manager = ConnectionManager()

# Global scheduler pool
# Bounds concurrent DAG schedulers across the whole process.
scheduler_pool = SchedulerPool(max_concurrent=10)

# Global tool executor
# Shared executor used to run tool calls issued by agents.
tool_executor = ToolExecutor()
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressEmitter:
    """
    Progress emitter that sends updates via WebSocket

    Wraps DAGScheduler progress callback to emit WebSocket messages.
    The sentinel node id ``"dag"`` emits a DAG-level ``dag_progress``
    message; any other id emits a ``node_progress`` message for that node.

    Must be invoked from within a running event loop, since sends are
    scheduled with ``asyncio.create_task``.
    """

    def __init__(self, task_id: str, connection_manager: ConnectionManager):
        self.task_id = task_id
        self.connection_manager = connection_manager
        # Bug fix: asyncio keeps only weak references to tasks, so the
        # fire-and-forget tasks created in __call__ could previously be
        # garbage-collected before the send completed.  Hold strong
        # references until each task finishes.
        self._pending: Set[asyncio.Task] = set()

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """Progress callback: schedule a progress message to subscribers.

        Args:
            node_id: Node id, or the literal "dag" for DAG-level progress.
            progress: Completion fraction in [0, 1].
            message: Optional human-readable status text.
        """
        # Both branches previously duplicated the whole send; build the
        # payload once and vary only the type and node_id field.
        if node_id == "dag":
            payload = {
                "type": "dag_progress",
                "data": {
                    "progress": progress,
                    "message": message
                }
            }
        else:
            payload = {
                "type": "node_progress",
                "data": {
                    "node_id": node_id,
                    "progress": progress,
                    "message": message
                }
            }

        task = asyncio.create_task(
            self.connection_manager.send_to_task(self.task_id, payload)
        )
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_start, node_start, node_progress,
      node_complete, node_error, dag_complete, pong
    """
    # Bug fix: actually complete the WebSocket handshake.
    # ConnectionManager.connect is synchronous and cannot await accept(),
    # so the connection was previously never accepted and every subsequent
    # send/receive failed.
    await websocket.accept()

    # Register connection for this task's broadcasts
    connection_manager.connect(websocket, task_id)

    # Send subscribed confirmation
    await connection_manager.send_personal(websocket, {
        "type": "subscribed",
        "task_id": task_id
    })

    # Start heartbeat
    connection_manager.start_heartbeat(websocket)

    # If this task's DAG is already running, push its current state
    scheduler = scheduler_pool.get(task_id)
    if scheduler:
        await connection_manager.send_personal(websocket, {
            "type": "dag_status",
            "data": scheduler.dag.to_dict()
        })

    try:
        while True:
            # Receive message
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON"
                })
                continue

            msg_type = message.get("type")

            if msg_type == "subscribe":
                # Already subscribed on connect; just re-acknowledge
                await connection_manager.send_personal(websocket, {
                    "type": "subscribed",
                    "task_id": task_id
                })

            elif msg_type == "get_status":
                # Re-fetch: the scheduler may have been created after connect
                scheduler = scheduler_pool.get(task_id)
                if scheduler:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": scheduler.dag.to_dict()
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": None
                    })

            elif msg_type == "ping":
                await connection_manager.send_personal(websocket, {
                    "type": "pong"
                })

            elif msg_type == "cancel_task":
                if scheduler_pool.cancel(task_id):
                    await connection_manager.send_personal(websocket, {
                        "type": "task_cancelled",
                        "task_id": task_id
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "error",
                        "message": "Task not found or already completed"
                    })

            else:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })

    except WebSocketDisconnect:
        connection_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        connection_manager.disconnect(websocket)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions to emit progress from scheduler
|
||||||
|
|
||||||
|
def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """Build a ProgressEmitter bound to the global connection manager."""
    emitter = ProgressEmitter(task_id, connection_manager)
    return emitter
|
||||||
|
|
||||||
|
|
||||||
|
async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """Notify subscribers of *task_id* that DAG execution has begun."""
    payload = {
        "type": "dag_start",
        "data": {"graph": dag.to_dict()}
    }
    await connection_manager.send_to_task(task_id, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """Notify subscribers of *task_id* that a node has started executing."""
    payload = {
        "type": "node_start",
        "data": {
            "node_id": node.id,
            "name": node.name,
            "status": node.status.value
        }
    }
    await connection_manager.send_to_task(task_id, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """Notify subscribers of *task_id* that a node finished, with its result."""
    result_dict = node.result.to_dict() if node.result else None
    payload = {
        "type": "node_complete",
        "data": {
            "node_id": node.id,
            "result": result_dict,
            "output_data": node.output_data
        }
    }
    await connection_manager.send_to_task(task_id, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """Notify subscribers of *task_id* that a node failed with *error*."""
    payload = {
        "type": "node_error",
        "data": {
            "node_id": node.id,
            "error": error
        }
    }
    await connection_manager.send_to_task(task_id, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """Notify subscribers of *task_id* that the whole DAG has finished."""
    payload = {
        "type": "dag_complete",
        "data": {
            "success": success,
            "results": results
        }
    }
    await connection_manager.send_to_task(task_id, payload)
|
||||||
Loading…
Reference in New Issue