feat: 增加agent 的后端
This commit is contained in:
parent
22a4b8a4bb
commit
8089d94e78
|
|
@ -0,0 +1,16 @@
|
|||
"""Multi-Agent system module"""
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||
from luxx.agents.registry import AgentRegistry
|
||||
|
||||
# Public API of luxx.agents, re-exported from the core/dag/registry submodules.
__all__ = [
    "Agent",
    "AgentConfig",
    "AgentType",
    "AgentStatus",
    "DAG",
    "TaskNode",
    "TaskNodeStatus",
    "TaskResult",
    "AgentRegistry",
]
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
"""Agent core models"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
from luxx.tools.core import CommandPermission
|
||||
|
||||
|
||||
class AgentType(str, Enum):
    """Agent type enumeration.

    str-valued so members compare equal to, and serialize as, plain strings.
    """
    SUPERVISOR = "supervisor"  # coordinating agent that plans/dispatches work
    WORKER = "worker"          # agent that executes individual tasks
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
    """Agent status enumeration.

    str-valued lifecycle states; IDLE is the default for a new Agent.
    """
    IDLE = "idle"            # created, no work assigned
    PLANNING = "planning"
    EXECUTING = "executing"
    WAITING = "waiting"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
class AgentConfig:
    """Agent configuration.

    Static settings for an agent; runtime state lives on Agent itself.
    """
    name: str
    agent_type: AgentType
    description: str = ""
    # Permission ceiling for this agent; the effective permission becomes
    # min(user permission, max_permission) — see Agent.set_user_permission.
    max_permission: CommandPermission = CommandPermission.EXECUTE
    max_turns: int = 10  # Context window: sliding window size (user/assistant turns)
    model: str = "deepseek-chat"  # LLM model identifier
    temperature: float = 0.7
    max_tokens: int = 4096
    system_prompt: str = ""  # seeded into the context window when non-empty
    tools: List[str] = field(default_factory=list)  # Tool names available to this agent
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra settings
|
||||
|
||||
|
||||
@dataclass
class Agent:
    """
    Agent entity.

    Represents an AI agent with its configuration, state, and context.
    The context window is a list of {"role": ..., "content": ...} messages
    managed as a sliding window bounded by ``config.max_turns``.
    """
    id: str
    config: AgentConfig
    status: AgentStatus = AgentStatus.IDLE
    user_id: Optional[int] = None
    conversation_id: Optional[str] = None
    workspace: Optional[str] = None

    # Runtime state
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 and
    # returns a naive datetime; kept as-is so serialized timestamps stay
    # byte-compatible with existing consumers — confirm before migrating
    # to datetime.now(timezone.utc).
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Context management
    context_window: List[Dict[str, Any]] = field(default_factory=list)
    accumulated_result: Dict[str, Any] = field(default_factory=dict)

    # Progress tracking
    current_task_id: Optional[str] = None
    progress: float = 0.0  # 0.0 - 1.0

    # Permission (effective permission = min(user_permission, agent.max_permission))
    effective_permission: CommandPermission = field(default_factory=lambda: CommandPermission.EXECUTE)

    def __post_init__(self):
        """Seed the context window with the configured system prompt, if any."""
        if self.config.system_prompt and not self.context_window:
            self.context_window = [
                {"role": "system", "content": self.config.system_prompt}
            ]

    def add_message(self, role: str, content: str) -> None:
        """
        Add message to context window with sliding window management.

        Args:
            role: Message role (user/assistant/system)
            content: Message content
        """
        self.context_window.append({"role": role, "content": content})
        self._trim_context()
        self.updated_at = datetime.utcnow()

    def _trim_context(self) -> None:
        """
        Trim context window using a sliding window strategy.

        Keeps the system prompt (when one is actually present at index 0)
        plus the most recent ``max_turns`` user/assistant turns.

        Fixes over the previous version:
        - Index 0 is pinned only when it really holds a system prompt;
          previously an ordinary first message was frozen in place forever
          whenever no system prompt was configured.
        - ``max_turns == 0`` no longer keeps everything: ``lst[-0:]`` is the
          whole list, not an empty slice, so the keep-count is guarded.
        """
        max_items = 1 + (self.config.max_turns * 2)  # system + (max_turns * 2)
        if len(self.context_window) <= max_items:
            return
        if self.context_window[0].get("role") == "system":
            keep = max_items - 1  # slots left after the pinned system prompt
            recent = self.context_window[1:][-keep:] if keep > 0 else []
            self.context_window = [self.context_window[0]] + recent
        else:
            # No system prompt: plain sliding window over all messages.
            self.context_window = self.context_window[-max_items:]

    def get_context(self) -> List[Dict[str, Any]]:
        """Get a shallow copy of the current context window."""
        return self.context_window.copy()

    def set_user_permission(self, user_permission: CommandPermission) -> None:
        """
        Set effective permission based on user and agent limits.

        Effective permission = min(user_permission, agent.max_permission)

        Args:
            user_permission: User's permission level
        """
        # NOTE(review): relies on CommandPermission supporting ordered
        # comparison (min()) — confirm it is an ordered/IntEnum-style enum.
        self.effective_permission = min(user_permission, self.config.max_permission)

    def store_result(self, key: str, value: Any) -> None:
        """
        Store result for supervisor's result-based context management.

        Args:
            key: Result key
            value: Result value
        """
        self.accumulated_result[key] = value
        self.updated_at = datetime.utcnow()

    def get_result(self, key: str) -> Optional[Any]:
        """Get stored result by key (None when absent)."""
        return self.accumulated_result.get(key)

    def clear_context(self) -> None:
        """Clear context but keep the system prompt when present."""
        if self.context_window and self.context_window[0]["role"] == "system":
            system_prompt = self.context_window[0]
            self.context_window = [system_prompt]
        else:
            self.context_window = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization (context is not included)."""
        return {
            "id": self.id,
            "name": self.config.name,
            "type": self.config.agent_type.value,
            "status": self.status.value,
            "user_id": self.user_id,
            "conversation_id": self.conversation_id,
            "workspace": self.workspace,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "current_task_id": self.current_task_id,
            "progress": self.progress,
            "effective_permission": self.effective_permission.name,
        }
|
||||
|
|
@ -0,0 +1,418 @@
|
|||
"""DAG (Directed Acyclic Graph) and TaskNode models"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class TaskNodeStatus(str, Enum):
    """Task node status enumeration.

    str-valued so members serialize as plain strings. Transitions are driven
    externally via TaskNode.mark_* and the DAG/scheduler.
    """
    PENDING = "pending"      # Not yet started
    READY = "ready"          # Dependencies satisfied, can start
    RUNNING = "running"      # Currently executing
    COMPLETED = "completed"  # Successfully completed
    FAILED = "failed"        # Execution failed
    CANCELLED = "cancelled"  # Cancelled by user
    BLOCKED = "blocked"      # Blocked by failed dependency
|
||||
|
||||
|
||||
@dataclass
class TaskResult:
    """Outcome of a single task execution."""
    success: bool
    data: Any = None
    error: Optional[str] = None
    output_data: Optional[Dict[str, Any]] = None  # Structured output for supervisor
    execution_time: float = 0.0  # seconds

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this result into a plain dictionary."""
        return {
            field_name: getattr(self, field_name)
            for field_name in (
                "success",
                "data",
                "error",
                "output_data",
                "execution_time",
            )
        }

    @classmethod
    def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None,
           execution_time: float = 0.0) -> "TaskResult":
        """Build a successful result."""
        return cls(
            success=True,
            data=data,
            output_data=output_data,
            execution_time=execution_time,
        )

    @classmethod
    def fail(cls, error: str, data: Any = None) -> "TaskResult":
        """Build a failed result carrying *error*."""
        return cls(success=False, error=error, data=data)
|
||||
|
||||
|
||||
@dataclass
class TaskNode:
    """
    Task node in the DAG.

    Represents a single executable task within the agent workflow.
    State transitions are driven externally (by the DAG/scheduler) through
    the mark_* methods: PENDING -> READY -> RUNNING -> COMPLETED/FAILED,
    with CANCELLED possible at any point.
    """
    id: str  # unique node id; auto-filled in __post_init__ when empty
    name: str
    description: str = ""

    # Task definition
    task_type: str = "generic"  # e.g., "code", "shell", "file", "llm"
    task_data: Dict[str, Any] = field(default_factory=dict)  # Task-specific parameters

    # Dependencies (node IDs that must complete before this node)
    dependencies: List[str] = field(default_factory=list)

    # Status tracking
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Optional[TaskResult] = None

    # Progress tracking
    progress: float = 0.0  # 0.0 - 1.0
    progress_message: str = ""

    # Execution info
    assigned_agent_id: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Output data (key-value pairs for dependent tasks)
    output_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            # NOTE(review): 8 hex chars of a UUID — collision probability grows
            # with node count; fine for small DAGs, confirm for large ones.
            self.id = str(uuid.uuid4())[:8]

    @property
    def is_root(self) -> bool:
        """Check if this is a root node (no dependencies)."""
        return len(self.dependencies) == 0

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (no dependents).

        NOTE(review): always returns False — the node holds no back-references
        to its dependents; DAG._leaf_nodes is the authoritative source.
        """
        return False  # Will be set by DAG

    @property
    def execution_time(self) -> float:
        """Calculate execution time in seconds (0.0 until both timestamps set)."""
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return 0.0

    def mark_ready(self) -> None:
        """Mark node as ready to execute."""
        self.status = TaskNodeStatus.READY

    def mark_running(self, agent_id: str) -> None:
        """Mark node as running and record the executing agent and start time."""
        self.status = TaskNodeStatus.RUNNING
        self.assigned_agent_id = agent_id
        # NOTE(review): datetime.utcnow() is deprecated since 3.12 (naive
        # timestamps); kept consistent with the rest of the module.
        self.started_at = datetime.utcnow()

    def mark_completed(self, result: TaskResult) -> None:
        """Mark node as completed, storing result and promoting its output_data."""
        self.status = TaskNodeStatus.COMPLETED
        self.result = result
        self.completed_at = datetime.utcnow()
        self.progress = 1.0
        if result.output_data:
            # Expose structured output to dependent tasks.
            self.output_data = result.output_data

    def mark_failed(self, error: str) -> None:
        """Mark node as failed, wrapping *error* into a failure TaskResult."""
        self.status = TaskNodeStatus.FAILED
        self.result = TaskResult.fail(error)
        self.completed_at = datetime.utcnow()

    def mark_cancelled(self) -> None:
        """Mark node as cancelled (no result is recorded)."""
        self.status = TaskNodeStatus.CANCELLED
        self.completed_at = datetime.utcnow()

    def update_progress(self, progress: float, message: str = "") -> None:
        """Update execution progress, clamping to [0.0, 1.0].

        An empty *message* leaves the previous progress_message untouched.
        """
        self.progress = max(0.0, min(1.0, progress))
        if message:
            self.progress_message = message

    def can_execute(self, completed_nodes: Set[str]) -> bool:
        """
        Check if this node can execute (READY and all dependencies completed).

        Args:
            completed_nodes: Set of completed node IDs
        """
        # Nodes that are PENDING/RUNNING/finished are never eligible here.
        if self.status != TaskNodeStatus.READY:
            return False
        return all(dep_id in completed_nodes for dep_id in self.dependencies)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "task_data": self.task_data,
            "dependencies": self.dependencies,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "assigned_agent_id": self.assigned_agent_id,
            "result": self.result.to_dict() if self.result else None,
            "output_data": self.output_data,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
|
||||
|
||||
|
||||
@dataclass
class DAG:
    """
    Directed Acyclic Graph for task scheduling.

    Manages the task workflow with dependency tracking and parallel execution
    support. Nodes are TaskNode instances keyed by id; edges are
    (from_id, to_id) pairs where the target depends on the source.
    """
    id: str
    name: str = ""
    description: str = ""

    # Nodes and edges
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[tuple] = field(default_factory=list)  # (from_id, to_id)

    # Metadata
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    # Root and leaf caches, rebuilt on every structural mutation
    _root_nodes: Set[str] = field(default_factory=set)
    _leaf_nodes: Set[str] = field(default_factory=set)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def add_node(self, node: TaskNode) -> None:
        """
        Add a task node to the DAG.

        Args:
            node: TaskNode to add
        """
        self.nodes[node.id] = node
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def add_edge(self, from_id: str, to_id: str) -> None:
        """
        Add an edge (dependency) between nodes.

        Args:
            from_id: Source node ID (must complete first)
            to_id: Target node ID (depends on source)

        Raises:
            ValueError: if either endpoint is not a registered node
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found")

        # Add dependency to target node (idempotent)
        if from_id not in self.nodes[to_id].dependencies:
            self.nodes[to_id].dependencies.append(from_id)

        self.edges.append((from_id, to_id))
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def _update_root_leaf_cache(self) -> None:
        """Rebuild the cached root and leaf node sets in a single pass.

        The previous version scanned all nodes for every node (O(n^2) per
        mutation); one pass over each node's dependency list is equivalent:
        roots have no dependencies, and any node that appears in someone's
        dependency list is not a leaf.
        """
        self._root_nodes = set()
        self._leaf_nodes = set(self.nodes)

        for node_id, node in self.nodes.items():
            if not node.dependencies:
                self._root_nodes.add(node_id)
            self._leaf_nodes.difference_update(node.dependencies)

    @property
    def root_nodes(self) -> List[TaskNode]:
        """Get all root nodes (nodes with no dependencies)."""
        return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes]

    @property
    def leaf_nodes(self) -> List[TaskNode]:
        """Get all leaf nodes (nodes no one depends on)."""
        return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes]

    def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]:
        """
        Get all READY nodes whose dependencies are all completed.

        Args:
            completed_nodes: Set of completed node IDs
        """
        return [
            node for node in self.nodes.values()
            if node.status == TaskNodeStatus.READY and node.can_execute(completed_nodes)
        ]

    def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]:
        """
        Get all unfinished nodes that directly depend on a failed node.

        Args:
            failed_node_id: ID of the failed node
        """
        blocked = []
        for node in self.nodes.values():
            if failed_node_id in node.dependencies:
                if node.status not in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED):
                    blocked.append(node)
        return blocked

    def mark_node_completed(self, node_id: str, result: TaskResult) -> None:
        """Mark a node as completed with result (no-op for unknown ids)."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_completed(result)
            self.updated_at = datetime.utcnow()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Mark a node as failed and block its direct dependents."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_failed(error)
            self._propagate_failure(node_id)
            self.updated_at = datetime.utcnow()

    def _propagate_failure(self, failed_node_id: str) -> None:
        """Block direct dependents of a failed node.

        Fix: PENDING dependents are blocked as well. Previously only READY
        dependents were, so a pending dependent of a failed node would wait
        forever for a dependency that can never complete.
        """
        for node in self.nodes.values():
            if failed_node_id in node.dependencies:
                if node.status in (TaskNodeStatus.PENDING, TaskNodeStatus.READY):
                    node.status = TaskNodeStatus.BLOCKED

    @property
    def completed_count(self) -> int:
        """Count of completed nodes."""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED)

    @property
    def failed_count(self) -> int:
        """Count of failed nodes."""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED)

    @property
    def total_count(self) -> int:
        """Total number of nodes."""
        return len(self.nodes)

    @property
    def progress(self) -> float:
        """Overall DAG progress (0.0 - 1.0), the mean of per-node progress."""
        if not self.nodes:
            return 0.0
        return sum(n.progress for n in self.nodes.values()) / len(self.nodes)

    @property
    def is_complete(self) -> bool:
        """Check if DAG execution is complete (every node reached a terminal state)."""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
            for n in self.nodes.values()
        )

    @property
    def is_success(self) -> bool:
        """Check if DAG completed with no failures."""
        return self.is_complete and self.failed_count == 0

    def get_execution_order(self) -> List[List[TaskNode]]:
        """
        Get nodes grouped by execution level (parallel-friendly order).

        Fix: readiness is derived from the dependency edges alone. The
        previous version delegated to TaskNode.can_execute, which requires
        READY status, so a freshly built DAG (all nodes PENDING) always
        produced an empty order.

        Returns:
            List of node groups, where each group can be executed in parallel
        """
        levels: List[List[TaskNode]] = []
        completed: Set[str] = set()

        while len(completed) < len(self.nodes):
            current_level = [
                node for node in self.nodes.values()
                if node.id not in completed
                and all(dep_id in completed for dep_id in node.dependencies)
            ]

            if not current_level:
                break  # Circular dependency - remaining nodes can never run

            levels.append(current_level)
            completed.update(n.id for n in current_level)

        return levels

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"from": e[0], "to": e[1]} for e in self.edges],
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "completed_count": self.completed_count,
            "total_count": self.total_count,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """
        Create DAG from dictionary.

        Only structure (nodes, dependencies, edges) is restored; runtime
        status, results, and timestamps are not round-tripped.

        Args:
            data: Dictionary with DAG data
        """
        dag = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            name=data.get("name", ""),
            description=data.get("description", ""),
        )

        # Recreate nodes
        for node_data in data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {}),
                dependencies=node_data.get("dependencies", []),
            )
            dag.add_node(node)

        # Recreate edges (add_edge is idempotent w.r.t. dependencies)
        for edge in data.get("edges", []):
            dag.add_edge(edge["from"], edge["to"])

        return dag
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
"""DAG Scheduler - orchestrates parallel task execution"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||
from luxx.agents.supervisor import SupervisorAgent
|
||||
from luxx.agents.worker import WorkerAgent
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DAGScheduler:
    """
    DAG Scheduler.

    Orchestrates parallel execution of tasks based on DAG structure.
    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support
    """

    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: ToolExecutor = None
    ):
        """
        Initialize DAG Scheduler.

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a fresh one is created
                when omitted)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()

        # Execution state
        self._running_nodes: Set[str] = set()    # node ids currently executing
        self._completed_nodes: Set[str] = set()  # node ids that succeeded
        self._node_results: Dict[str, TaskResult] = {}
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}

        # Control flags
        self._cancelled = False
        self._cancel_lock = threading.Lock()

        # Progress callback: called as (node_id, progress, message)
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates (None disables)."""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Emit progress update through the registered callback, if any."""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG.

        Resets all per-run state, marks root nodes READY, then polls:
        launching eligible nodes until every node reaches a terminal
        (or BLOCKED) state, cancellation is requested, or no progress
        is possible.

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description
                  (NOTE(review): currently unused in this method body)

        Returns:
            Execution results
        """
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()

        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")

        # Initialize nodes - mark root nodes as ready
        for node in self.dag.root_nodes:
            node.mark_ready()

        # Main execution loop (polling at 0.1s granularity)
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()

            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()

            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break

            # Launch ready nodes up to max_workers
            # NOTE(review): len(self._running_nodes) is evaluated while the
            # list is built, but _execute_node only registers a node once its
            # coroutine runs — a single pass can schedule more than
            # max_workers tasks. Confirm whether strict capping is required.
            nodes_to_launch = [
                n for n in ready_nodes
                if len(self._running_nodes) < self.max_workers
                and n.id not in self._running_nodes
            ]

            for node in nodes_to_launch:
                # NOTE(review): the Task returned by create_task is not
                # retained; asyncio holds only a weak reference to running
                # tasks, so keeping a reference set is the documented-safe
                # pattern — verify.
                asyncio.create_task(self._execute_node(node, context))

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)

        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")

        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """Check if execution is complete (BLOCKED counts as terminal here)."""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Check if execution was cancelled (thread-safe read)."""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cancellation; the execute() loop observes it on next poll."""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """Get READY nodes that are not running and whose deps are satisfied."""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node via a fresh worker agent.

        Collects the output_data of completed dependencies, runs the worker,
        then records the result and updates dependent-node states.

        Args:
            node: Node to execute
            context: Execution context
        """
        if self._is_cancelled():
            node.mark_cancelled()
            return

        # Mark node as running
        self._running_nodes.add(node.id)
        node.mark_running(self.supervisor.agent.id)

        self._emit_progress(node.id, 0.0, f"Starting: {node.name}")

        # Collect parent outputs for this node
        parent_outputs = {}
        for dep_id in node.dependencies:
            if dep_id in self._node_results:
                parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}

        # Create worker for this task
        worker = self.worker_factory()

        # Define progress callback for this node
        def node_progress(progress: float, message: str = ""):
            node.update_progress(progress, message)
            self._emit_progress(node.id, progress, message)

        try:
            # Execute task
            result = await worker.execute_task(
                node,
                context,
                parent_outputs=parent_outputs,
                progress_callback=node_progress
            )

            # Store result
            self._node_results[node.id] = result

            if result.success:
                node.mark_completed(result)
                self._completed_nodes.add(node.id)
                self._emit_progress(node.id, 1.0, f"Completed: {node.name}")

                # Check if any blocked nodes can now run
                self._unblock_nodes()
            else:
                node.mark_failed(result.error or "Unknown error")
                self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")

                # Block dependent nodes
                self._block_dependent_nodes(node.id)

        except Exception as e:
            # Any unexpected worker error is treated as a task failure.
            logger.error(f"Node {node.id} execution error: {e}")
            node.mark_failed(str(e))
            self._node_results[node.id] = TaskResult.fail(error=str(e))
            self._block_dependent_nodes(node.id)

        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Unblock nodes whose dependencies are now satisfied.

        NOTE(review): can_execute returns False for BLOCKED nodes (it
        requires READY status), so this appears to never fire — confirm
        against TaskNode.can_execute.
        """
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Block direct dependents of a failed node and emit progress for each."""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Mark all running nodes cancelled and clear the running set.

        NOTE(review): this only flips node state; the underlying asyncio
        tasks are not awaited or cancelled here.
        """
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

    def _build_result(self) -> Dict[str, Any]:
        """Build the final result summary for a normally-ended run."""
        return {
            "success": self.dag.is_success,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build the result summary for a cancelled run (adds "cancelled": True)."""
        return {
            "success": False,
            "cancelled": True,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }
|
||||
|
||||
|
||||
class SchedulerPool:
    """
    Registry of live DAG schedulers.

    Caps how many DAG executions may exist at once and offers lookup,
    removal and cancellation by task id. Every access to the underlying
    mapping is guarded by a single lock.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool.

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Build a scheduler for *dag* and register it under *task_id*.

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            The newly registered scheduler

        Raises:
            RuntimeError: when the pool is already at capacity
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")

            new_scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )
            self._schedulers[task_id] = new_scheduler

        return new_scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Look up the scheduler registered under *task_id*, if any."""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Drop the scheduler for *task_id*; return whether one existed."""
        with self._lock:
            return self._schedulers.pop(task_id, None) is not None

    def cancel(self, task_id: str) -> bool:
        """Request cancellation of the scheduler for *task_id*, if registered."""
        target = self.get(task_id)
        if target is None:
            return False
        target.cancel()
        return True

    @property
    def active_count(self) -> int:
        """Number of active schedulers."""
        with self._lock:
            return len(self._schedulers)
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
"""Agent Registry - manages agent lifecycle and access"""
|
||||
from typing import Dict, List, Optional, Set
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
|
||||
|
||||
class AgentRegistry:
    """
    Agent Registry (Singleton)

    Thread-safe registry for managing all agents in the system.
    Provides agent creation, retrieval, and lifecycle management.
    Agents are additionally indexed by owning user and by conversation.
    """
    _instance: Optional["AgentRegistry"] = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking so exactly one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentRegistry() call; only initialize once.
        if self._initialized:
            return
        self._agents: Dict[str, Agent] = {}
        self._user_agents: Dict[int, Set[str]] = {}  # user_id -> set of agent_ids
        self._conversation_agents: Dict[str, Set[str]] = {}  # conversation_id -> set of agent_ids
        self._registry_lock = threading.Lock()
        self._initialized = True

    def _generate_agent_id(self) -> str:
        """Generate a unique agent ID (12 hex chars of a UUID4)."""
        return f"agent_{uuid.uuid4().hex[:12]}"

    def create_agent(
        self,
        config: AgentConfig,
        user_id: Optional[int] = None,
        conversation_id: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Agent:
        """
        Create and register a new agent.

        Args:
            config: Agent configuration
            user_id: Associated user ID
            conversation_id: Associated conversation ID
            workspace: Agent's workspace path

        Returns:
            Created Agent instance
        """
        with self._registry_lock:
            agent_id = self._generate_agent_id()
            agent = Agent(
                id=agent_id,
                config=config,
                user_id=user_id,
                conversation_id=conversation_id,
                workspace=workspace,
            )

            self._agents[agent_id] = agent

            # Track by user. BUG FIX: compare against None instead of
            # truthiness so a legitimate user_id of 0 is still tracked.
            if user_id is not None:
                self._user_agents.setdefault(user_id, set()).add(agent_id)

            # Track by conversation (empty string is treated as "none").
            if conversation_id:
                self._conversation_agents.setdefault(conversation_id, set()).add(agent_id)

            return agent

    def get(self, agent_id: str) -> Optional[Agent]:
        """
        Get agent by ID.

        Args:
            agent_id: Agent ID

        Returns:
            Agent if found, None otherwise
        """
        with self._registry_lock:
            return self._agents.get(agent_id)

    def list_user_agents(self, user_id: int) -> List[Agent]:
        """
        List all agents for a user.

        Args:
            user_id: User ID

        Returns:
            List of the user's agents
        """
        with self._registry_lock:
            agent_ids = self._user_agents.get(user_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_conversation_agents(self, conversation_id: str) -> List[Agent]:
        """
        List all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            List of the conversation's agents
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.get(conversation_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_by_type(self, agent_type: AgentType) -> List[Agent]:
        """
        List all agents of a specific type.

        Args:
            agent_type: Type of agent to filter

        Returns:
            List of agents of the specified type
        """
        with self._registry_lock:
            return [
                a for a in self._agents.values()
                if a.config.agent_type == agent_type
            ]

    def list_by_status(self, status: AgentStatus) -> List[Agent]:
        """
        List all agents with a specific status.

        Args:
            status: Status to filter

        Returns:
            List of agents with the specified status
        """
        with self._registry_lock:
            return [a for a in self._agents.values() if a.status == status]

    def update_status(self, agent_id: str, status: AgentStatus) -> bool:
        """
        Update agent status.

        Args:
            agent_id: Agent ID
            status: New status

        Returns:
            True if updated, False if agent not found
        """
        with self._registry_lock:
            agent = self._agents.get(agent_id)
            if agent:
                agent.status = status
                return True
            return False

    def remove(self, agent_id: str) -> bool:
        """
        Remove agent from registry (and from both secondary indexes).

        Args:
            agent_id: Agent ID

        Returns:
            True if removed, False if not found
        """
        with self._registry_lock:
            agent = self._agents.pop(agent_id, None)
            if agent is None:
                return False

            # Remove from user tracking (None-check so user_id 0 is handled).
            if agent.user_id is not None and agent.user_id in self._user_agents:
                self._user_agents[agent.user_id].discard(agent_id)

            # Remove from conversation tracking
            if agent.conversation_id and agent.conversation_id in self._conversation_agents:
                self._conversation_agents[agent.conversation_id].discard(agent_id)

            return True

    def remove_user_agents(self, user_id: int) -> int:
        """
        Remove all agents for a user.

        Args:
            user_id: User ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._user_agents.pop(user_id, set())
            count = 0
            for agent_id in agent_ids:
                if agent_id in self._agents:
                    del self._agents[agent_id]
                    count += 1
            return count

    def remove_conversation_agents(self, conversation_id: str) -> int:
        """
        Remove all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.pop(conversation_id, set())
            count = 0
            for agent_id in agent_ids:
                if agent_id in self._agents:
                    del self._agents[agent_id]
                    count += 1
            return count

    def clear(self) -> None:
        """Clear all agents and both secondary indexes."""
        with self._registry_lock:
            self._agents.clear()
            self._user_agents.clear()
            self._conversation_agents.clear()

    @property
    def agent_count(self) -> int:
        """Total number of registered agents."""
        with self._registry_lock:
            return len(self._agents)

    def get_stats(self) -> Dict:
        """Get registry statistics: total plus per-status and per-type counts."""
        with self._registry_lock:
            status_counts = {}
            type_counts = {}
            for agent in self._agents.values():
                status_counts[agent.status.value] = status_counts.get(agent.status.value, 0) + 1
                type_counts[agent.config.agent_type.value] = type_counts.get(agent.config.agent_type.value, 0) + 1

            return {
                "total": len(self._agents),
                "by_status": status_counts,
                "by_type": type_counts,
            }
|
||||
|
||||
|
||||
# Global registry instance — module-level singleton shared by the application.
registry = AgentRegistry()
|
||||
|
|
@ -0,0 +1,345 @@
|
|||
"""Supervisor Agent - task decomposition and result integration"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||
from luxx.services.llm_client import llm_client
|
||||
from luxx.tools.core import registry as tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SupervisorAgent:
    """
    Supervisor Agent

    Responsible for:
    - Task decomposition using LLM
    - Generating DAG (task graph)
    - Result integration from workers
    """

    # System prompt for task decomposition
    DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.

Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration

Output format for task decomposition:
{
    "task_name": "Overall task name",
    "task_description": "Description of what needs to be accomplished",
    "nodes": [
        {
            "id": "task_001",
            "name": "Task name",
            "description": "What this task does",
            "task_type": "code|shell|file|llm|generic",
            "task_data": {...}, # Task-specific parameters
            "dependencies": [] # IDs of tasks that must complete first
        }
    ]
}

Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        max_subtasks: int = 10
    ):
        """
        Initialize Supervisor Agent.

        Args:
            agent: Agent instance (should be SUPERVISOR type)
            llm_client: LLM client instance; defaults to the shared module-level client
            max_subtasks: Maximum number of subtasks to generate
        """
        self.agent = agent
        # BUG FIX: the parameter shadows the module-level `llm_client` import,
        # so the original `llm_client or llm_client` always resolved to the
        # parameter and could never fall back to the shared client.
        if llm_client is None:
            from luxx.services.llm_client import llm_client as _shared_llm_client
            llm_client = _shared_llm_client
        self.llm_client = llm_client
        self.max_subtasks = max_subtasks

        # Ensure agent has supervisor system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def decompose_task(
        self,
        task: str,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> DAG:
        """
        Decompose a task into subtasks using the LLM.

        Args:
            task: User's task description
            context: Execution context (workspace, user info, etc.)
                NOTE(review): currently unused in this method — confirm intent.
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            DAG representing the task decomposition

        Raises:
            Exception: Re-raises any LLM-call failure after marking the
                agent FAILED.
        """
        self.agent.status = AgentStatus.PLANNING

        if progress_callback:
            progress_callback(0.1, "Analyzing task...")

        # Build messages for LLM from the agent's sliding-window context.
        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"Decompose this task into subtasks:\n{task}"
        })

        if progress_callback:
            progress_callback(0.2, "Calling LLM for task decomposition...")

        # Call LLM
        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            if progress_callback:
                progress_callback(0.5, "Processing decomposition...")

            # Parse LLM response to extract DAG
            dag = self._parse_dag_from_response(response.content, task)

            # Add assistant response to context
            self.agent.add_message("assistant", response.content)

            if progress_callback:
                progress_callback(0.9, "Task decomposition complete")

            self.agent.status = AgentStatus.IDLE
            return dag

        except Exception as e:
            logger.error(f"Task decomposition failed: {e}")
            self.agent.status = AgentStatus.FAILED
            raise

    def _parse_dag_from_response(self, content: str, original_task: str) -> DAG:
        """
        Parse LLM response to extract a DAG structure.

        Falls back to a single llm-type node wrapping the whole task when
        the response contains no parseable JSON.

        Args:
            content: LLM response content
            original_task: Original task description

        Returns:
            DAG instance
        """
        # Try to extract JSON from response
        dag_data = self._extract_json(content)

        if not dag_data:
            # Fallback: create a simple single-node DAG
            logger.warning("Could not parse DAG from LLM response, creating simple DAG")
            dag = DAG(
                id=f"dag_{self.agent.id}",
                name=original_task[:50],
                description=original_task
            )
            node = TaskNode(
                id="task_001",
                name="Execute Task",
                description=original_task,
                task_type="llm",
                task_data={"prompt": original_task}
            )
            dag.add_node(node)
            return dag

        # Build DAG from parsed data
        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=dag_data.get("task_name", original_task[:50]),
            description=dag_data.get("task_description", original_task)
        )

        # Add nodes first so every edge endpoint exists.
        for node_data in dag_data.get("nodes", []):
            node = TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {})
            )
            dag.add_node(node)

        # Add edges based on dependencies; silently skip unknown dep IDs.
        for node_data in dag_data.get("nodes", []):
            node_id = node_data["id"]
            for dep_id in node_data.get("dependencies", []):
                if dep_id in dag.nodes:
                    dag.add_edge(dep_id, node_id)

        return dag

    def _extract_json(self, content: str) -> Optional[Dict]:
        """
        Extract a JSON object from an LLM response.

        Tries a ```json fenced block first, then any brace-delimited span.

        Args:
            content: Raw LLM response

        Returns:
            Parsed JSON dict or None
        """
        import re

        # Look for ```json ... ``` blocks
        json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Look for raw JSON object
        json_match = re.search(r"\{.*\}", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        return None

    def integrate_results(self, dag: DAG) -> Dict[str, Any]:
        """
        Integrate results from all completed tasks.

        Args:
            dag: DAG with completed tasks

        Returns:
            Summary dict with per-node results and completion counts
        """
        self.agent.status = AgentStatus.EXECUTING

        # Collect all output data from completed nodes
        results = {}
        for node in dag.nodes.values():
            if node.status == TaskNodeStatus.COMPLETED and node.output_data:
                results[node.id] = node.output_data

        # Store aggregated results on the agent for later inspection.
        self.agent.accumulated_result = results
        self.agent.status = AgentStatus.COMPLETED

        return {
            "success": True,
            "dag_id": dag.id,
            "total_tasks": dag.total_count,
            "completed_tasks": dag.completed_count,
            "failed_tasks": dag.failed_count,
            "results": results
        }

    async def review_and_refine(
        self,
        dag: DAG,
        task: str,
        progress_callback: Optional[Callable] = None
    ) -> Optional[DAG]:
        """
        Review DAG execution and refine it if tasks failed.

        Args:
            dag: Current DAG state
            task: Original task
            progress_callback: Progress callback (currently unused here)

        Returns:
            Refined DAG, or None if no refinement is needed or refinement failed
        """
        if dag.is_success:
            return None  # No refinement needed

        # Check if there are failed tasks
        failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]

        if not failed_nodes:
            return None

        # Build context for refinement
        context = {
            "original_task": task,
            "failed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "error": n.result.error if n.result else "Unknown error"
                }
                for n in failed_nodes
            ],
            "completed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "output": n.output_data
                }
                for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
            ]
        }

        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"""Review the task execution and suggest refinements:

Task: {task}

Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}

Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}

If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical

Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
        })

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature
            )

            # Check if refinement was suggested
            refined_dag = self._parse_dag_from_response(response.content, task)

            # Only return if we got a valid refinement
            if refined_dag and refined_dag.nodes:
                return refined_dag

        except Exception as e:
            # Best-effort: a failed refinement is not fatal; caller gets None.
            logger.error(f"DAG refinement failed: {e}")

        return None
|
||||
|
|
@ -0,0 +1,401 @@
|
|||
"""Worker Agent - executes specific tasks"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
from luxx.agents.dag import TaskNode, TaskNodeStatus, TaskResult
|
||||
from luxx.services.llm_client import llm_client
|
||||
from luxx.tools.core import registry as tool_registry, ToolContext, CommandPermission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerAgent:
    """
    Worker Agent

    Responsible for executing specific tasks using:
    - LLM calls for reasoning tasks
    - Tool execution for actionable tasks

    Follows sliding window context management.
    """

    # System prompt for worker tasks
    DEFAULT_SYSTEM_PROMPT = """You are a Worker Agent that executes specific tasks efficiently.

Your responsibilities:
1. Execute tasks assigned to you by the Supervisor
2. Use appropriate tools when needed
3. Report results clearly with structured output_data for dependent tasks
4. Be concise and focused on the task at hand

Output format:
- Provide clear, structured results
- Include output_data for any data that dependent tasks might need
- If a tool fails, explain the error clearly
"""

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        tool_executor=None
    ):
        """
        Initialize Worker Agent.

        Args:
            agent: Agent instance (should be WORKER type)
            llm_client: LLM client instance; defaults to the shared module-level client
            tool_executor: Tool executor instance (required for tool-calling tasks)
        """
        self.agent = agent
        # BUG FIX: the parameter shadows the module-level `llm_client` import,
        # so the original `llm_client or llm_client` always resolved to the
        # parameter and could never fall back to the shared client.
        if llm_client is None:
            from luxx.services.llm_client import llm_client as _shared_llm_client
            llm_client = _shared_llm_client
        self.llm_client = llm_client
        self.tool_executor = tool_executor

        # Ensure agent has worker system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def execute_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        parent_outputs: Optional[Dict[str, Dict[str, Any]]] = None,
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute a task node, dispatching on its task_type.

        Args:
            task_node: Task node to execute
            context: Execution context (workspace, user info, etc.)
            parent_outputs: Output data from parent tasks (dependency results)
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            TaskResult with execution outcome (execution_time is always set)
        """
        self.agent.status = AgentStatus.EXECUTING
        self.agent.current_task_id = task_node.id

        start_time = time.time()

        if progress_callback:
            progress_callback(0.0, f"Starting task: {task_node.name}")

        try:
            # Merge parent outputs into context
            execution_context = self._prepare_context(context, parent_outputs)

            # Dispatch on task type; unknown types fall through to generic.
            if task_node.task_type == "llm":
                result = await self._execute_llm_task(task_node, execution_context, progress_callback)
            elif task_node.task_type == "code":
                result = await self._execute_code_task(task_node, execution_context, progress_callback)
            elif task_node.task_type == "shell":
                result = await self._execute_shell_task(task_node, execution_context, progress_callback)
            elif task_node.task_type == "file":
                result = await self._execute_file_task(task_node, execution_context, progress_callback)
            else:
                result = await self._execute_generic_task(task_node, execution_context, progress_callback)

            result.execution_time = time.time() - start_time

            if progress_callback:
                progress_callback(1.0, f"Task complete: {task_node.name}")

            self.agent.status = AgentStatus.IDLE
            return result

        except Exception as e:
            logger.error(f"Task execution failed: {e}")
            self.agent.status = AgentStatus.FAILED
            # BUG FIX: the original computed execution_time here but never
            # attached it to the failure result.
            failure = TaskResult.fail(error=str(e))
            failure.execution_time = time.time() - start_time
            return failure

    def _prepare_context(
        self,
        context: Dict[str, Any],
        parent_outputs: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Prepare execution context by merging parent outputs.

        Args:
            context: Base context
            parent_outputs: Output from parent tasks

        Returns:
            Merged context (the input dict is shallow-copied, not mutated)
        """
        execution_context = context.copy()

        if parent_outputs:
            # Flatten all parent outputs; later parents overwrite earlier keys.
            merged = {}
            for parent_id, outputs in parent_outputs.items():
                merged.update(outputs)
            execution_context["parent_outputs"] = parent_outputs
            execution_context["merged_data"] = merged

        # Add user permission level from the agent if the caller didn't set it.
        if "user_permission_level" not in execution_context:
            execution_context["user_permission_level"] = self.agent.effective_permission.value

        return execution_context

    async def _execute_llm_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """Execute an LLM reasoning task (no tools, single call)."""
        task_data = task_node.task_data

        # Build prompt
        prompt = task_data.get("prompt", task_node.description)
        system_prompt = task_data.get("system", self.agent.config.system_prompt)

        messages = [{"role": "system", "content": system_prompt}]

        # Add parent data if available
        if "merged_data" in context:
            merged = context["merged_data"]
            context_info = "\n".join([f"{k}: {v}" for k, v in merged.items()])
            messages.append({
                "role": "system",
                "content": f"Context from dependent tasks:\n{context_info}"
            })

        messages.append({"role": "user", "content": prompt})

        if progress_callback:
            progress_callback(0.3, "Calling LLM...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            return TaskResult.ok(
                data=response.content,
                # Copy so dependent tasks can't mutate the node's template.
                output_data=task_node.task_data.get("output_template", {}).copy()
            )

        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_code_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """Execute a code generation/writing task via the LLM."""
        task_data = task_node.task_data

        # Build prompt for code generation
        prompt = task_data.get("prompt", task_node.description)
        language = task_data.get("language", "python")
        requirements = task_data.get("requirements", "")

        messages = [
            {"role": "system", "content": f"You are a {language} programmer. Write clean, efficient code."},
            {"role": "user", "content": f"Task: {prompt}\n\nRequirements: {requirements}"}
        ]

        if progress_callback:
            progress_callback(0.3, "Generating code...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=0.2,  # Lower temp for code
                max_tokens=4096
            )

            return TaskResult.ok(
                data=response.content,
                output_data={
                    "code": response.content,
                    "language": language
                }
            )

        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_shell_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """Execute a shell command task through the shell_exec tool."""
        task_data = task_node.task_data

        command = task_data.get("command")
        if not command:
            return TaskResult.fail(error="No command specified")

        if progress_callback:
            progress_callback(0.3, f"Executing: {command[:50]}...")

        # Build tool context carrying workspace/user identity and permission.
        tool_ctx = ToolContext(
            workspace=context.get("workspace"),
            user_id=context.get("user_id"),
            username=context.get("username"),
            extra={"user_permission_level": context.get("user_permission_level", 1)}
        )

        try:
            # Execute shell command via tool
            result = tool_registry.execute(
                "shell_exec",
                {"command": command},
                context=tool_ctx
            )

            if result.get("success"):
                return TaskResult.ok(
                    data=result.get("data", {}).get("output", ""),
                    output_data={"output": result.get("data", {}).get("output", "")}
                )
            else:
                return TaskResult.fail(error=result.get("error", "Shell execution failed"))

        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_file_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """Execute a file operation task via the file_<operation> tool."""
        task_data = task_node.task_data

        operation = task_data.get("operation")
        file_path = task_data.get("path")
        content = task_data.get("content", "")

        if not operation or not file_path:
            return TaskResult.fail(error="Missing operation or path")

        if progress_callback:
            progress_callback(0.3, f"File operation: {operation} {file_path}")

        tool_ctx = ToolContext(
            workspace=context.get("workspace"),
            user_id=context.get("user_id"),
            username=context.get("username"),
            extra={"user_permission_level": context.get("user_permission_level", 1)}
        )

        try:
            # Tool name is derived from the operation, e.g. file_read/file_write.
            tool_name = f"file_{operation}"
            result = tool_registry.execute(
                tool_name,
                {"path": file_path, "content": content},
                context=tool_ctx
            )

            if result.get("success"):
                return TaskResult.ok(
                    data=result.get("data"),
                    output_data={"path": file_path, "operation": operation}
                )
            else:
                return TaskResult.fail(error=result.get("error", "File operation failed"))

        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_generic_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute a generic task using the LLM with optional tool calling.

        Loops LLM call -> tool execution up to a fixed number of iterations;
        the task completes on the first response with no tool calls.
        """
        task_data = task_node.task_data

        # Build prompt
        prompt = task_data.get("prompt", task_node.description)
        tools = task_data.get("tools", [])

        messages = [
            {"role": "system", "content": self.agent.config.system_prompt},
            {"role": "user", "content": prompt}
        ]

        # Get tool definitions if specified; unknown tool names are skipped.
        tool_defs = None
        if tools:
            tool_defs = [tool_registry.get(t).to_openai_format() for t in tools if tool_registry.get(t)]

        if progress_callback:
            progress_callback(0.2, "Processing task...")

        max_iterations = 5
        iteration = 0

        while iteration < max_iterations:
            try:
                response = await self.llm_client.sync_call(
                    model=self.agent.config.model,
                    messages=messages,
                    tools=tool_defs,
                    temperature=self.agent.config.temperature,
                    max_tokens=self.agent.config.max_tokens
                )

                # Add assistant response
                messages.append({"role": "assistant", "content": response.content})

                # Check for tool calls
                if response.tool_calls:
                    if progress_callback:
                        progress_callback(0.5, f"Executing {len(response.tool_calls)} tools...")

                    # Execute tools (raises AttributeError -> TaskResult.fail
                    # if no tool_executor was supplied at construction).
                    tool_results = self.tool_executor.process_tool_calls(
                        response.tool_calls,
                        context
                    )

                    # Feed tool results back into the conversation.
                    for tr in tool_results:
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tr["tool_call_id"],
                            "content": tr["content"]
                        })

                    if progress_callback:
                        progress_callback(0.8, "Tools executed")
                else:
                    # No tool calls, task complete
                    return TaskResult.ok(
                        data=response.content,
                        output_data=task_data.get("output_template", {})
                    )

            except Exception as e:
                return TaskResult.fail(error=str(e))

            iteration += 1

        return TaskResult.fail(error="Max iterations exceeded")
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from luxx.routes import auth, conversations, messages, tools, providers
|
||||
from luxx.routes.agents_ws import router as agents_ws_router
|
||||
|
||||
|
||||
api_router = APIRouter()
|
||||
|
|
@ -12,3 +13,4 @@ api_router.include_router(conversations.router)
|
|||
api_router.include_router(messages.router)
|
||||
api_router.include_router(tools.router)
|
||||
api_router.include_router(providers.router)
|
||||
api_router.include_router(agents_ws_router)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,385 @@
|
|||
"""WebSocket routes for agent real-time communication"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||
import uuid
|
||||
|
||||
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
|
||||
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
|
||||
from luxx.agents.supervisor import SupervisorAgent
|
||||
from luxx.agents.worker import WorkerAgent
|
||||
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
|
||||
from luxx.agents.registry import AgentRegistry
|
||||
from luxx.services.llm_client import llm_client
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ConnectionManager:
    """
    WebSocket Connection Manager

    Manages WebSocket connections for real-time agent progress updates.

    Features:
    - Connection tracking by task_id
    - Heartbeat mechanism
    - Progress broadcasting

    NOTE: callers must ``await websocket.accept()`` before registering a
    connection with :meth:`connect`. ``connect`` is synchronous and cannot
    perform the handshake itself (the previous version called
    ``websocket.accept()`` without awaiting it, so the coroutine never ran).
    """

    def __init__(self):
        # task_id -> set of websocket connections subscribed to that task
        self._connections: Dict[str, Set[WebSocket]] = {}
        # reverse mapping: websocket -> task_id it is subscribed to
        self._ws_to_task: Dict[WebSocket, str] = {}
        # per-connection heartbeat tasks, cancelled on disconnect
        self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
        # protects the three dicts above; never held across an await
        self._lock = threading.Lock()
        # heartbeat interval in seconds
        self._heartbeat_interval = 30

    def connect(self, websocket: WebSocket, task_id: str) -> None:
        """
        Register an already-accepted WebSocket connection.

        The caller is responsible for awaiting ``websocket.accept()`` first;
        this method only records the subscription.

        Args:
            websocket: WebSocket connection
            task_id: Task ID to subscribe to
        """
        with self._lock:
            self._connections.setdefault(task_id, set()).add(websocket)
            self._ws_to_task[websocket] = task_id

        logger.info(f"WebSocket connected for task {task_id}")

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket connection and stop its heartbeat.

        Safe to call repeatedly for the same websocket (idempotent).

        Args:
            websocket: WebSocket connection
        """
        with self._lock:
            task_id = self._ws_to_task.pop(websocket, None)
            if task_id and task_id in self._connections:
                self._connections[task_id].discard(websocket)
                if not self._connections[task_id]:
                    del self._connections[task_id]

            # Cancel heartbeat task
            if websocket in self._heartbeat_tasks:
                self._heartbeat_tasks[websocket].cancel()
                del self._heartbeat_tasks[websocket]

        if task_id:
            logger.info(f"WebSocket disconnected for task {task_id}")

    async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
        """
        Send message to all connections subscribed to a task.

        Connections that fail to receive are unregistered.

        Args:
            task_id: Task ID
            message: Message to send
        """
        # Snapshot under the lock; never await while holding it.
        with self._lock:
            connections = self._connections.get(task_id, set()).copy()

        if not connections:
            return

        dead_connections = set()

        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket: {e}")
                dead_connections.add(websocket)

        # Clean up dead connections (disconnect re-acquires the lock,
        # which we no longer hold here)
        for ws in dead_connections:
            self.disconnect(ws)

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """
        Broadcast message to all connections.

        Args:
            message: Message to broadcast
        """
        # Snapshot under the lock; the sends happen outside it so a slow or
        # blocked client cannot stall connect/disconnect on other threads.
        with self._lock:
            all_connections = list(self._ws_to_task.keys())

        for websocket in all_connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to broadcast: {e}")

    async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
        """
        Send message to a specific connection; failures are logged, not raised.

        Args:
            websocket: Target WebSocket
            message: Message to send
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"Failed to send personal message: {e}")

    def start_heartbeat(self, websocket: WebSocket) -> None:
        """
        Start a periodic heartbeat for a connection.

        Must be called from a running event loop (uses asyncio.create_task).
        The task is cancelled by :meth:`disconnect`.

        Args:
            websocket: WebSocket connection
        """
        async def heartbeat_loop():
            while True:
                await asyncio.sleep(self._heartbeat_interval)
                try:
                    await websocket.send_json({
                        "type": "heartbeat",
                        "interval": self._heartbeat_interval
                    })
                except Exception:
                    # Send failed: peer is gone, stop heartbeating.
                    # (CancelledError is a BaseException and still propagates.)
                    break

        task = asyncio.create_task(heartbeat_loop())
        with self._lock:
            self._heartbeat_tasks[websocket] = task

    @property
    def connection_count(self) -> int:
        """Total number of connections"""
        with self._lock:
            return len(self._ws_to_task)

    def get_task_connections(self, task_id: str) -> int:
        """Get number of connections for a task"""
        with self._lock:
            return len(self._connections.get(task_id, set()))
|
||||
|
||||
|
||||
# Global connection manager — module-level singleton shared by all
# websocket routes; tracks live subscriptions per task_id.
connection_manager = ConnectionManager()

# Global scheduler pool — runs at most 10 DAG schedulers concurrently.
scheduler_pool = SchedulerPool(max_concurrent=10)

# Global tool executor — shared ToolExecutor instance used to run the
# tool calls issued by agents.
tool_executor = ToolExecutor()
|
||||
|
||||
|
||||
class ProgressEmitter:
    """
    Progress emitter that sends updates via WebSocket

    Wraps DAGScheduler progress callback to emit WebSocket messages.

    Instances are callable as ``(node_id, progress, message)``. The special
    node_id ``"dag"`` marks DAG-level progress; any other id is node-level.
    Must be invoked from a running asyncio event loop, since delivery is
    scheduled with ``asyncio.create_task``.
    """

    def __init__(self, task_id: str, connection_manager: "ConnectionManager"):
        self.task_id = task_id
        self.connection_manager = connection_manager
        # Strong references to in-flight send tasks: asyncio only keeps a
        # weak reference to tasks, so an unreferenced task may be garbage
        # collected before it runs.
        self._pending: Set[asyncio.Task] = set()

    def __call__(self, node_id: str, progress: float, message: str = "") -> None:
        """Progress callback: fan the update out to the task's subscribers."""
        if node_id == "dag":
            # DAG-level progress
            payload = {
                "type": "dag_progress",
                "data": {
                    "progress": progress,
                    "message": message
                }
            }
        else:
            # Node-level progress
            payload = {
                "type": "node_progress",
                "data": {
                    "node_id": node_id,
                    "progress": progress,
                    "message": message
                }
            }

        # Single scheduling point instead of duplicating create_task per branch
        task = asyncio.create_task(
            self.connection_manager.send_to_task(self.task_id, payload)
        )
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)
|
||||
|
||||
|
||||
@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_start, node_start, node_progress,
      node_complete, node_error, dag_complete, pong
    """
    # Complete the WebSocket handshake. accept() is a coroutine and must be
    # awaited here -- the synchronous ConnectionManager.connect cannot do
    # the handshake itself (its unawaited accept() call never runs).
    await websocket.accept()

    # Register the connection for task-scoped broadcasts
    connection_manager.connect(websocket, task_id)

    # Send subscribed confirmation
    await connection_manager.send_personal(websocket, {
        "type": "subscribed",
        "task_id": task_id
    })

    # Start heartbeat
    connection_manager.start_heartbeat(websocket)

    # If the task is already running, push the current DAG state immediately
    scheduler = scheduler_pool.get(task_id)
    if scheduler:
        await connection_manager.send_personal(websocket, {
            "type": "dag_status",
            "data": scheduler.dag.to_dict()
        })

    try:
        while True:
            # Receive message
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON"
                })
                continue

            msg_type = message.get("type")

            if msg_type == "subscribe":
                # Already subscribed on connect; just acknowledge again
                await connection_manager.send_personal(websocket, {
                    "type": "subscribed",
                    "task_id": task_id
                })

            elif msg_type == "get_status":
                # Re-fetch: the scheduler may have appeared/gone since connect
                scheduler = scheduler_pool.get(task_id)
                if scheduler:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": scheduler.dag.to_dict()
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": None
                    })

            elif msg_type == "ping":
                await connection_manager.send_personal(websocket, {
                    "type": "pong"
                })

            elif msg_type == "cancel_task":
                if scheduler_pool.cancel(task_id):
                    await connection_manager.send_personal(websocket, {
                        "type": "task_cancelled",
                        "task_id": task_id
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "error",
                        "message": "Task not found or already completed"
                    })

            else:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })

    except WebSocketDisconnect:
        # Normal client close
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Always unregister, whatever path exits the loop
        connection_manager.disconnect(websocket)
|
||||
|
||||
|
||||
# Helper functions to emit progress from scheduler
|
||||
|
||||
def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """Build a ProgressEmitter bound to *task_id* and the module-wide
    connection manager."""
    emitter = ProgressEmitter(task_id, connection_manager)
    return emitter
|
||||
|
||||
|
||||
async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """Notify all subscribers of *task_id* that DAG execution has begun."""
    event = {"type": "dag_start", "data": {"graph": dag.to_dict()}}
    await connection_manager.send_to_task(task_id, event)
|
||||
|
||||
|
||||
async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """Notify subscribers of *task_id* that *node* has started executing."""
    details = {
        "node_id": node.id,
        "name": node.name,
        "status": node.status.value,
    }
    await connection_manager.send_to_task(task_id, {"type": "node_start", "data": details})
|
||||
|
||||
|
||||
async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """Notify subscribers that *node* finished, including its result payload."""
    result_dict = node.result.to_dict() if node.result else None
    details = {
        "node_id": node.id,
        "result": result_dict,
        "output_data": node.output_data,
    }
    await connection_manager.send_to_task(task_id, {"type": "node_complete", "data": details})
|
||||
|
||||
|
||||
async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """Notify subscribers that *node* failed with the given *error* text."""
    await connection_manager.send_to_task(
        task_id,
        {"type": "node_error", "data": {"node_id": node.id, "error": error}},
    )
|
||||
|
||||
|
||||
async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """Notify subscribers that the whole DAG finished (success flag + results)."""
    summary = {"success": success, "results": results}
    await connection_manager.send_to_task(task_id, {"type": "dag_complete", "data": summary})
|
||||
Loading…
Reference in New Issue