feat: 增加agent 的后端

This commit is contained in:
ViperEkura 2026-04-15 21:36:50 +08:00
parent 22a4b8a4bb
commit 8089d94e78
9 changed files with 2354 additions and 0 deletions

16
luxx/agents/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
"""Multi-Agent system module"""
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.registry import AgentRegistry
__all__ = [
"Agent",
"AgentConfig",
"AgentType",
"AgentStatus",
"DAG",
"TaskNode",
"TaskNodeStatus",
"TaskResult",
"AgentRegistry",
]

161
luxx/agents/core.py Normal file
View File

@@ -0,0 +1,161 @@
"""Agent core models"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from enum import Enum
from datetime import datetime
from luxx.tools.core import CommandPermission
class AgentType(str, Enum):
    """Role of an agent in the multi-agent system (str-valued so members
    serialize naturally to JSON)."""
    SUPERVISOR = "supervisor"  # decomposes tasks and integrates results
    WORKER = "worker"          # executes individual task nodes
class AgentStatus(str, Enum):
    """Lifecycle status of an agent (str-valued for easy serialization)."""
    IDLE = "idle"            # no work assigned
    PLANNING = "planning"    # planning phase
    EXECUTING = "executing"  # actively executing
    WAITING = "waiting"      # waiting state
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    CANCELLED = "cancelled"  # stopped before completion
@dataclass
class AgentConfig:
    """Static configuration for an Agent.

    ``max_permission`` caps what the agent may run; the effective permission
    is further limited by the user (see ``Agent.set_user_permission``).
    """
    name: str
    agent_type: AgentType
    description: str = ""
    # Upper bound on this agent's command permission (project enum).
    max_permission: CommandPermission = CommandPermission.EXECUTE
    max_turns: int = 10  # Context window: sliding window size (user+assistant turns kept)
    model: str = "deepseek-chat"  # LLM model identifier
    temperature: float = 0.7
    max_tokens: int = 4096
    system_prompt: str = ""  # seeded into the context window when non-empty
    tools: List[str] = field(default_factory=list)  # Tool names available to this agent
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra data
@dataclass
class Agent:
    """
    Agent entity.

    Represents an AI agent with its configuration, runtime state, and
    conversation context.  The context window is kept as a sliding window of
    at most ``config.max_turns`` user/assistant exchanges, with the system
    prompt (when one is present) always pinned at index 0.
    """
    id: str
    config: AgentConfig
    status: AgentStatus = AgentStatus.IDLE
    user_id: Optional[int] = None
    conversation_id: Optional[str] = None
    workspace: Optional[str] = None
    # Runtime state
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    # Context management
    context_window: List[Dict[str, Any]] = field(default_factory=list)
    accumulated_result: Dict[str, Any] = field(default_factory=dict)
    # Progress tracking
    current_task_id: Optional[str] = None
    progress: float = 0.0  # 0.0 - 1.0
    # Permission (effective permission = min(user_permission, agent.max_permission))
    effective_permission: CommandPermission = field(default_factory=lambda: CommandPermission.EXECUTE)

    def __post_init__(self):
        """Seed the context window with the configured system prompt."""
        if self.config.system_prompt and not self.context_window:
            self.context_window = [
                {"role": "system", "content": self.config.system_prompt}
            ]

    def add_message(self, role: str, content: str) -> None:
        """
        Append a message and trim the context window.

        Args:
            role: Message role (user/assistant/system)
            content: Message content
        """
        self.context_window.append({"role": role, "content": content})
        self._trim_context()
        self.updated_at = datetime.utcnow()

    def _trim_context(self) -> None:
        """
        Trim the context window with a sliding-window strategy.

        Keeps the system prompt (only when one is actually present at
        index 0) plus the most recent ``max_turns`` user/assistant pairs.

        Fix: the previous version unconditionally treated index 0 as the
        system prompt, so without a configured system prompt the oldest chat
        message was pinned forever (``clear_context`` already checked the
        role; this now matches it).  Also handles ``max_turns == 0``, where
        the old ``remaining[-0:]`` slice kept everything.
        """
        has_system = bool(self.context_window) and self.context_window[0].get("role") == "system"
        max_items = (1 if has_system else 0) + self.config.max_turns * 2
        if len(self.context_window) <= max_items:
            return
        if has_system:
            tail = self.context_window[-(max_items - 1):] if max_items > 1 else []
            self.context_window = [self.context_window[0]] + tail
        else:
            self.context_window = self.context_window[-max_items:] if max_items > 0 else []

    def get_context(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the current context window."""
        return self.context_window.copy()

    def set_user_permission(self, user_permission: CommandPermission) -> None:
        """
        Set effective permission based on user and agent limits.

        Effective permission = min(user_permission, agent.max_permission)

        Args:
            user_permission: User's permission level
        """
        # NOTE(review): assumes CommandPermission is an ordered (Int-based)
        # enum so min() is meaningful -- confirm in luxx.tools.core.
        self.effective_permission = min(user_permission, self.config.max_permission)

    def store_result(self, key: str, value: Any) -> None:
        """
        Store a keyed result for the supervisor's result-based context.

        Args:
            key: Result key
            value: Result value
        """
        self.accumulated_result[key] = value
        self.updated_at = datetime.utcnow()

    def get_result(self, key: str) -> Optional[Any]:
        """Get stored result by key (None when absent)."""
        return self.accumulated_result.get(key)

    def clear_context(self) -> None:
        """Clear context but keep the system prompt, if present."""
        if self.context_window and self.context_window[0]["role"] == "system":
            self.context_window = [self.context_window[0]]
        else:
            self.context_window = []

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary."""
        return {
            "id": self.id,
            "name": self.config.name,
            "type": self.config.agent_type.value,
            "status": self.status.value,
            "user_id": self.user_id,
            "conversation_id": self.conversation_id,
            "workspace": self.workspace,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "current_task_id": self.current_task_id,
            "progress": self.progress,
            "effective_permission": self.effective_permission.name,
        }

418
luxx/agents/dag.py Normal file
View File

@@ -0,0 +1,418 @@
"""DAG (Directed Acyclic Graph) and TaskNode models"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from enum import Enum
from datetime import datetime
import uuid
class TaskNodeStatus(str, Enum):
    """Lifecycle status of a TaskNode (str-valued for JSON serialization)."""
    PENDING = "pending"      # Not yet started
    READY = "ready"          # Dependencies satisfied, can start
    RUNNING = "running"      # Currently executing
    COMPLETED = "completed"  # Successfully completed
    FAILED = "failed"        # Execution failed
    CANCELLED = "cancelled"  # Cancelled by user
    BLOCKED = "blocked"      # Blocked by failed dependency
@dataclass
class TaskResult:
    """Outcome of a single task execution.

    ``success`` flags the outcome, ``data``/``error`` carry the payload or
    failure message, ``output_data`` holds structured values consumed by the
    supervisor, and ``execution_time`` is wall time in seconds.
    """
    success: bool
    data: Any = None
    error: Optional[str] = None
    output_data: Optional[Dict[str, Any]] = None  # Structured output for supervisor
    execution_time: float = 0.0  # seconds

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return dict(
            success=self.success,
            data=self.data,
            error=self.error,
            output_data=self.output_data,
            execution_time=self.execution_time,
        )

    @classmethod
    def ok(cls, data: Any = None, output_data: Optional[Dict[str, Any]] = None,
           execution_time: float = 0.0) -> "TaskResult":
        """Build a successful result."""
        return cls(
            success=True,
            data=data,
            output_data=output_data,
            execution_time=execution_time,
        )

    @classmethod
    def fail(cls, error: str, data: Any = None) -> "TaskResult":
        """Build a failed result."""
        return cls(success=False, error=error, data=data)
@dataclass
class TaskNode:
    """
    A single executable task inside the workflow DAG.

    Carries the task definition, dependency list, execution status,
    progress, timestamps and the result/output produced by a worker.
    """
    id: str
    name: str
    description: str = ""
    # Task definition
    task_type: str = "generic"  # e.g., "code", "shell", "file", "llm"
    task_data: Dict[str, Any] = field(default_factory=dict)  # Task-specific parameters
    # IDs of nodes that must finish before this one may start
    dependencies: List[str] = field(default_factory=list)
    # Status tracking
    status: TaskNodeStatus = TaskNodeStatus.PENDING
    result: Optional[TaskResult] = None
    # Progress tracking
    progress: float = 0.0  # 0.0 - 1.0
    progress_message: str = ""
    # Execution info
    assigned_agent_id: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Output values consumed by dependent tasks
    output_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    @property
    def is_root(self) -> bool:
        """True when the node has no dependencies."""
        return not self.dependencies

    @property
    def is_leaf(self) -> bool:
        """Placeholder: always False here; real leaf detection lives in DAG."""
        return False

    @property
    def execution_time(self) -> float:
        """Wall-clock execution time in seconds (0.0 until both timestamps exist)."""
        if self.started_at is None or self.completed_at is None:
            return 0.0
        return (self.completed_at - self.started_at).total_seconds()

    def mark_ready(self) -> None:
        """Transition the node to READY."""
        self.status = TaskNodeStatus.READY

    def mark_running(self, agent_id: str) -> None:
        """Transition to RUNNING, recording the executing agent and start time."""
        self.status = TaskNodeStatus.RUNNING
        self.assigned_agent_id = agent_id
        self.started_at = datetime.utcnow()

    def mark_completed(self, result: TaskResult) -> None:
        """Transition to COMPLETED, storing the result and its output data."""
        self.status = TaskNodeStatus.COMPLETED
        self.result = result
        self.completed_at = datetime.utcnow()
        self.progress = 1.0
        if result.output_data:
            self.output_data = result.output_data

    def mark_failed(self, error: str) -> None:
        """Transition to FAILED with a failure result."""
        self.status = TaskNodeStatus.FAILED
        self.result = TaskResult.fail(error)
        self.completed_at = datetime.utcnow()

    def mark_cancelled(self) -> None:
        """Transition to CANCELLED."""
        self.status = TaskNodeStatus.CANCELLED
        self.completed_at = datetime.utcnow()

    def update_progress(self, progress: float, message: str = "") -> None:
        """Record execution progress, clamped to [0.0, 1.0]."""
        self.progress = min(1.0, max(0.0, progress))
        if message:
            self.progress_message = message

    def can_execute(self, completed_nodes: Set[str]) -> bool:
        """
        True when the node is READY and every dependency has completed.

        Args:
            completed_nodes: Set of completed node IDs
        """
        if self.status != TaskNodeStatus.READY:
            return False
        return set(self.dependencies).issubset(completed_nodes)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "task_data": self.task_data,
            "dependencies": self.dependencies,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "assigned_agent_id": self.assigned_agent_id,
            "result": self.result.to_dict() if self.result else None,
            "output_data": self.output_data,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
@dataclass
class DAG:
    """
    Directed Acyclic Graph for task scheduling.

    Manages the task workflow with dependency tracking and parallel
    execution support.  ``nodes`` maps node id to TaskNode; ``edges`` holds
    (from_id, to_id) tuples meaning "from must finish before to".
    """
    id: str
    name: str = ""
    description: str = ""
    # Nodes and edges
    nodes: Dict[str, TaskNode] = field(default_factory=dict)
    edges: List[tuple] = field(default_factory=list)  # (from_id, to_id)
    # Metadata
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    # Cached root/leaf node-id sets, rebuilt after every mutation
    _root_nodes: Set[str] = field(default_factory=set)
    _leaf_nodes: Set[str] = field(default_factory=set)

    def __post_init__(self):
        """Assign a short random id when none was provided."""
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def add_node(self, node: TaskNode) -> None:
        """
        Add a task node to the DAG.

        Args:
            node: TaskNode to add
        """
        self.nodes[node.id] = node
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def add_edge(self, from_id: str, to_id: str) -> None:
        """
        Add an edge (dependency) between existing nodes.

        Args:
            from_id: Source node ID (must complete first)
            to_id: Target node ID (depends on source)

        Raises:
            ValueError: if either node id is unknown
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found")
        if from_id not in self.nodes[to_id].dependencies:
            self.nodes[to_id].dependencies.append(from_id)
        # Fix: record the edge even when the dependency was already listed on
        # the node (e.g. nodes deserialized with dependencies pre-filled).
        # Previously the edge list stayed empty in that case, so
        # DAG.from_dict(dag.to_dict()) silently dropped every edge.
        if (from_id, to_id) not in self.edges:
            self.edges.append((from_id, to_id))
        self._update_root_leaf_cache()
        self.updated_at = datetime.utcnow()

    def _update_root_leaf_cache(self) -> None:
        """Rebuild the cached root and leaf node-id sets (O(V+E))."""
        self._root_nodes = set()
        self._leaf_nodes = set()
        # Every id appearing as a dependency has at least one dependent,
        # so it cannot be a leaf.
        depended_on: Set[str] = set()
        for node in self.nodes.values():
            depended_on.update(node.dependencies)
        for node_id, node in self.nodes.items():
            if not node.dependencies:
                self._root_nodes.add(node_id)
            if node_id not in depended_on:
                self._leaf_nodes.add(node_id)

    @property
    def root_nodes(self) -> List[TaskNode]:
        """All nodes with no dependencies."""
        return [self.nodes[nid] for nid in self._root_nodes if nid in self.nodes]

    @property
    def leaf_nodes(self) -> List[TaskNode]:
        """All nodes that no other node depends on."""
        return [self.nodes[nid] for nid in self._leaf_nodes if nid in self.nodes]

    def get_ready_nodes(self, completed_nodes: Set[str]) -> List[TaskNode]:
        """
        Nodes in READY state whose dependencies are all completed.

        Args:
            completed_nodes: Set of completed node IDs
        """
        return [
            node for node in self.nodes.values()
            if node.status == TaskNodeStatus.READY and node.can_execute(completed_nodes)
        ]

    def get_blocked_nodes(self, failed_node_id: str) -> List[TaskNode]:
        """
        Unfinished nodes that directly depend on a failed node.

        Args:
            failed_node_id: ID of the failed node
        """
        terminal = (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
        return [
            node for node in self.nodes.values()
            if failed_node_id in node.dependencies and node.status not in terminal
        ]

    def mark_node_completed(self, node_id: str, result: TaskResult) -> None:
        """Mark a node as completed with its result (no-op for unknown id)."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_completed(result)
            self.updated_at = datetime.utcnow()

    def mark_node_failed(self, node_id: str, error: str) -> None:
        """Mark a node as failed and block its READY dependents."""
        if node_id in self.nodes:
            self.nodes[node_id].mark_failed(error)
            self._propagate_failure(node_id)
            self.updated_at = datetime.utcnow()

    def _propagate_failure(self, failed_node_id: str) -> None:
        """Move READY dependents of a failed node into BLOCKED state."""
        for node in self.nodes.values():
            if failed_node_id in node.dependencies and node.status == TaskNodeStatus.READY:
                node.status = TaskNodeStatus.BLOCKED

    @property
    def completed_count(self) -> int:
        """Number of completed nodes."""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.COMPLETED)

    @property
    def failed_count(self) -> int:
        """Number of failed nodes."""
        return sum(1 for n in self.nodes.values() if n.status == TaskNodeStatus.FAILED)

    @property
    def total_count(self) -> int:
        """Total number of nodes."""
        return len(self.nodes)

    @property
    def progress(self) -> float:
        """Mean node progress in [0.0, 1.0]; 0.0 for an empty DAG."""
        if not self.nodes:
            return 0.0
        return sum(n.progress for n in self.nodes.values()) / len(self.nodes)

    @property
    def is_complete(self) -> bool:
        """True when every node reached a terminal state."""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED)
            for n in self.nodes.values()
        )

    @property
    def is_success(self) -> bool:
        """True when execution finished with no failures."""
        return self.is_complete and self.failed_count == 0

    def get_execution_order(self) -> List[List[TaskNode]]:
        """
        Group nodes by execution level (parallel-friendly order).

        Returns:
            List of node groups, where each group can run in parallel.

        Fix: level membership is now derived from the dependency structure
        alone.  The previous version called TaskNode.can_execute(), which
        also requires READY status, so planning on a freshly built DAG
        (all nodes still PENDING) always returned an empty list.
        """
        levels: List[List[TaskNode]] = []
        completed: Set[str] = set()
        while len(completed) < len(self.nodes):
            current_level = [
                node for node in self.nodes.values()
                if node.id not in completed
                and all(dep in completed for dep in node.dependencies)
            ]
            if not current_level:
                break  # Circular dependency or unsatisfiable nodes
            levels.append(current_level)
            completed.update(n.id for n in current_level)
        return levels

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "edges": [{"from": e[0], "to": e[1]} for e in self.edges],
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "progress": self.progress,
            "completed_count": self.completed_count,
            "total_count": self.total_count,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """
        Create a DAG from its dictionary form (inverse of to_dict).

        Args:
            data: Dictionary with DAG data
        """
        dag = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            name=data.get("name", ""),
            description=data.get("description", ""),
        )
        # Recreate nodes (dependencies come pre-populated from the dump)
        for node_data in data.get("nodes", []):
            dag.add_node(TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {}),
                dependencies=node_data.get("dependencies", []),
            ))
        # Recreate edges (add_edge dedupes against the node dependencies)
        for edge in data.get("edges", []):
            dag.add_edge(edge["from"], edge["to"])
        return dag

View File

@@ -0,0 +1,360 @@
"""DAG Scheduler - orchestrates parallel task execution"""
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor
import threading
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.tools.executor import ToolExecutor
logger = logging.getLogger(__name__)
class DAGScheduler:
    """
    DAG Scheduler

    Orchestrates parallel execution of tasks based on DAG structure.

    Features:
    - Parallel execution with max_workers limit
    - Dependency-aware scheduling
    - Real-time progress tracking
    - Cancellation support
    """
    def __init__(
        self,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable[[], WorkerAgent],
        max_workers: int = 3,
        tool_executor: Optional[ToolExecutor] = None
    ):
        """
        Initialize DAG Scheduler

        Args:
            dag: DAG to execute
            supervisor: Supervisor agent instance
            worker_factory: Factory function to create worker agents
                (called once per task node)
            max_workers: Maximum parallel workers
            tool_executor: Tool executor instance (a default ToolExecutor
                is created when None)
        """
        self.dag = dag
        self.supervisor = supervisor
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.tool_executor = tool_executor or ToolExecutor()
        # Execution state: node ids currently running / already completed,
        # plus per-node TaskResult objects.
        self._running_nodes: Set[str] = set()
        self._completed_nodes: Set[str] = set()
        self._node_results: Dict[str, TaskResult] = {}
        # NOTE(review): _parent_outputs is cleared in execute() but never
        # written; parent outputs are gathered locally in _execute_node.
        # Looks vestigial -- confirm before removing.
        self._parent_outputs: Dict[str, Dict[str, Any]] = {}
        # Control flags; _cancelled may be set from another thread via
        # cancel(), hence the lock.
        self._cancelled = False
        self._cancel_lock = threading.Lock()
        # Progress callback, invoked as callback(node_id, progress, message)
        self._progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Optional[Callable]) -> None:
        """Set progress callback for real-time updates (None disables it)."""
        self._progress_callback = callback

    def _emit_progress(self, node_id: str, progress: float, message: str = "") -> None:
        """Forward a progress update to the registered callback, if any."""
        if self._progress_callback:
            self._progress_callback(node_id, progress, message)

    async def execute(
        self,
        context: Dict[str, Any],
        task: str
    ) -> Dict[str, Any]:
        """
        Execute the DAG until every node reaches a terminal state.

        Args:
            context: Execution context (workspace, user info, etc.)
            task: Original task description
                # NOTE(review): unused in this method; presumably consumed
                # elsewhere -- confirm.

        Returns:
            Execution results (shape of _build_result / _build_cancelled_result)
        """
        # Reset per-run state so the scheduler instance can be reused.
        self._cancelled = False
        self._running_nodes.clear()
        self._completed_nodes.clear()
        self._node_results.clear()
        self._parent_outputs.clear()
        # Emit DAG start
        self._emit_progress("dag", 0.0, "Starting DAG execution")
        # Initialize nodes - mark root nodes as ready
        for node in self.dag.root_nodes:
            node.mark_ready()
        # Main execution loop: poll until all nodes are terminal or blocked.
        while not self._is_execution_complete():
            if self._is_cancelled():
                await self._cancel_running_nodes()
                return self._build_cancelled_result()
            # Get ready nodes that can be executed
            ready_nodes = self._get_ready_nodes()
            if not ready_nodes and not self._running_nodes:
                # No ready nodes and nothing running - deadlock or all blocked
                logger.warning("No ready nodes and no running nodes - possible deadlock")
                break
            # Launch ready nodes up to max_workers.
            # NOTE(review): len(_running_nodes) is sampled while building the
            # list, before the spawned coroutines register themselves, so one
            # pass can launch more than max_workers tasks -- confirm whether
            # the limit must be strict.
            nodes_to_launch = [
                n for n in ready_nodes
                if len(self._running_nodes) < self.max_workers
                and n.id not in self._running_nodes
            ]
            for node in nodes_to_launch:
                # NOTE(review): the created task is neither retained nor
                # awaited; asyncio keeps only weak references to tasks, so a
                # strong reference should be held to be safe -- confirm.
                asyncio.create_task(self._execute_node(node, context))
            # Small delay to prevent busy waiting
            await asyncio.sleep(0.1)
        # Emit DAG completion
        success = self.dag.is_success
        self._emit_progress("dag", 1.0, "DAG execution complete" if success else "DAG execution failed")
        return self._build_result()

    def _is_execution_complete(self) -> bool:
        """True when every node is terminal (BLOCKED counts as terminal here)."""
        return all(
            n.status in (TaskNodeStatus.COMPLETED, TaskNodeStatus.FAILED, TaskNodeStatus.CANCELLED, TaskNodeStatus.BLOCKED)
            for n in self.dag.nodes.values()
        )

    def _is_cancelled(self) -> bool:
        """Thread-safely read the cancellation flag."""
        with self._cancel_lock:
            return self._cancelled

    def cancel(self) -> None:
        """Request cancellation; the execute() loop reacts on its next pass."""
        with self._cancel_lock:
            self._cancelled = True
        self._emit_progress("dag", 0.0, "Execution cancelled")

    def _get_ready_nodes(self) -> List[TaskNode]:
        """READY nodes, not already running, whose dependencies completed."""
        return [
            n for n in self.dag.nodes.values()
            if n.status == TaskNodeStatus.READY
            and n.id not in self._running_nodes
            and n.can_execute(self._completed_nodes)
        ]

    async def _execute_node(self, node: TaskNode, context: Dict[str, Any]) -> None:
        """
        Execute a single node via a freshly created worker.

        Args:
            node: Node to execute
            context: Execution context
        """
        if self._is_cancelled():
            node.mark_cancelled()
            return
        # Mark node as running (attributed to the supervisor's agent id)
        self._running_nodes.add(node.id)
        node.mark_running(self.supervisor.agent.id)
        self._emit_progress(node.id, 0.0, f"Starting: {node.name}")
        # Collect parent outputs for this node (dep id -> its output_data)
        parent_outputs = {}
        for dep_id in node.dependencies:
            if dep_id in self._node_results:
                parent_outputs[dep_id] = self._node_results[dep_id].output_data or {}
        # Create worker for this task
        worker = self.worker_factory()
        # Per-node progress callback: updates the node and re-emits globally
        def node_progress(progress: float, message: str = ""):
            node.update_progress(progress, message)
            self._emit_progress(node.id, progress, message)
        try:
            # Execute task
            result = await worker.execute_task(
                node,
                context,
                parent_outputs=parent_outputs,
                progress_callback=node_progress
            )
            # Store result
            self._node_results[node.id] = result
            if result.success:
                node.mark_completed(result)
                self._completed_nodes.add(node.id)
                self._emit_progress(node.id, 1.0, f"Completed: {node.name}")
                # Check if any blocked nodes can now run
                self._unblock_nodes()
            else:
                node.mark_failed(result.error or "Unknown error")
                self._emit_progress(node.id, 0.0, f"Failed: {node.name} - {result.error}")
                # Block dependent nodes
                self._block_dependent_nodes(node.id)
        except Exception as e:
            # Worker raised: record a synthetic failure result and block deps
            logger.error(f"Node {node.id} execution error: {e}")
            node.mark_failed(str(e))
            self._node_results[node.id] = TaskResult.fail(error=str(e))
            self._block_dependent_nodes(node.id)
        finally:
            self._running_nodes.discard(node.id)

    def _unblock_nodes(self) -> None:
        """Move BLOCKED nodes back to READY when their deps are now satisfied."""
        for node in self.dag.nodes.values():
            if node.status == TaskNodeStatus.BLOCKED:
                if node.can_execute(self._completed_nodes):
                    node.mark_ready()
                    self._emit_progress(node.id, 0.0, f"Unblocked: {node.name}")

    def _block_dependent_nodes(self, failed_node_id: str) -> None:
        """Mark all unfinished dependents of a failed node as BLOCKED."""
        blocked = self.dag.get_blocked_nodes(failed_node_id)
        for node in blocked:
            node.status = TaskNodeStatus.BLOCKED
            self._emit_progress(node.id, 0.0, f"Blocked due to: {failed_node_id}")

    async def _cancel_running_nodes(self) -> None:
        """Mark all currently running nodes cancelled and clear the set."""
        # NOTE(review): this only flips node status; the underlying worker
        # coroutines are not interrupted -- confirm that is acceptable.
        for node_id in self._running_nodes:
            if node_id in self.dag.nodes:
                self.dag.nodes[node_id].mark_cancelled()
        self._running_nodes.clear()

    def _build_result(self) -> Dict[str, Any]:
        """Build the final result summary for a completed run."""
        return {
            "success": self.dag.is_success,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }

    def _build_cancelled_result(self) -> Dict[str, Any]:
        """Build the result summary for a cancelled run."""
        return {
            "success": False,
            "cancelled": True,
            "dag_id": self.dag.id,
            "total_tasks": self.dag.total_count,
            "completed_tasks": self.dag.completed_count,
            "failed_tasks": self.dag.failed_count,
            "progress": self.dag.progress,
            "results": {
                node_id: result.to_dict()
                for node_id, result in self._node_results.items()
            }
        }
class SchedulerPool:
    """
    Pool of DAG schedulers for managing multiple concurrent DAG executions.

    All access to the underlying mapping is serialized with a lock so the
    pool can safely be shared across threads.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Initialize scheduler pool.

        Args:
            max_concurrent: Maximum concurrent DAG executions
        """
        self.max_concurrent = max_concurrent
        self._schedulers: Dict[str, DAGScheduler] = {}
        self._lock = threading.Lock()

    def create_scheduler(
        self,
        task_id: str,
        dag: DAG,
        supervisor: SupervisorAgent,
        worker_factory: Callable,
        max_workers: int = 3
    ) -> DAGScheduler:
        """
        Create a scheduler and register it under *task_id*.

        Args:
            task_id: Unique task identifier
            dag: DAG to execute
            supervisor: Supervisor agent
            worker_factory: Worker factory function
            max_workers: Max parallel workers

        Returns:
            The newly created scheduler

        Raises:
            RuntimeError: when the pool is already at capacity
        """
        with self._lock:
            if len(self._schedulers) >= self.max_concurrent:
                raise RuntimeError("Maximum concurrent schedulers reached")
            new_scheduler = DAGScheduler(
                dag=dag,
                supervisor=supervisor,
                worker_factory=worker_factory,
                max_workers=max_workers
            )
            self._schedulers[task_id] = new_scheduler
            return new_scheduler

    def get(self, task_id: str) -> Optional[DAGScheduler]:
        """Look up a scheduler by task ID (None when absent)."""
        with self._lock:
            return self._schedulers.get(task_id)

    def remove(self, task_id: str) -> bool:
        """Unregister a scheduler; True when one was actually removed."""
        with self._lock:
            return self._schedulers.pop(task_id, None) is not None

    def cancel(self, task_id: str) -> bool:
        """Request cancellation of a registered scheduler; True on success."""
        scheduler = self.get(task_id)
        if scheduler is None:
            return False
        scheduler.cancel()
        return True

    @property
    def active_count(self) -> int:
        """Number of registered schedulers."""
        with self._lock:
            return len(self._schedulers)

266
luxx/agents/registry.py Normal file
View File

@@ -0,0 +1,266 @@
"""Agent Registry - manages agent lifecycle and access"""
from typing import Dict, List, Optional, Set
import threading
import uuid
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
class AgentRegistry:
    """
    Agent Registry (Singleton)

    Thread-safe registry for managing all agents in the system.
    Provides agent creation, retrieval, and lifecycle management.
    Secondary indexes map user ids and conversation ids to agent-id sets.
    """
    _instance: Optional["AgentRegistry"] = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking so only one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentRegistry() call; initialize only once.
        if self._initialized:
            return
        self._agents: Dict[str, Agent] = {}
        self._user_agents: Dict[int, Set[str]] = {}  # user_id -> set of agent_ids
        self._conversation_agents: Dict[str, Set[str]] = {}  # conversation_id -> set of agent_ids
        self._registry_lock = threading.Lock()
        self._initialized = True

    def _generate_agent_id(self) -> str:
        """Generate a unique agent ID."""
        return f"agent_{uuid.uuid4().hex[:12]}"

    def create_agent(
        self,
        config: AgentConfig,
        user_id: Optional[int] = None,
        conversation_id: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Agent:
        """
        Create and register a new agent.

        Args:
            config: Agent configuration
            user_id: Associated user ID
            conversation_id: Associated conversation ID
            workspace: Agent's workspace path

        Returns:
            Created Agent instance
        """
        with self._registry_lock:
            agent_id = self._generate_agent_id()
            agent = Agent(
                id=agent_id,
                config=config,
                user_id=user_id,
                conversation_id=conversation_id,
                workspace=workspace,
            )
            self._agents[agent_id] = agent
            # Fix: test against None instead of truthiness so user_id == 0
            # (and an empty-string conversation id) are still indexed.
            if user_id is not None:
                self._user_agents.setdefault(user_id, set()).add(agent_id)
            if conversation_id is not None:
                self._conversation_agents.setdefault(conversation_id, set()).add(agent_id)
            return agent

    def get(self, agent_id: str) -> Optional[Agent]:
        """
        Get agent by ID.

        Args:
            agent_id: Agent ID

        Returns:
            Agent if found, None otherwise
        """
        with self._registry_lock:
            return self._agents.get(agent_id)

    def list_user_agents(self, user_id: int) -> List[Agent]:
        """
        List all agents for a user.

        Args:
            user_id: User ID

        Returns:
            List of user's agents
        """
        with self._registry_lock:
            agent_ids = self._user_agents.get(user_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_conversation_agents(self, conversation_id: str) -> List[Agent]:
        """
        List all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            List of conversation's agents
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.get(conversation_id, set())
            return [self._agents[aid] for aid in agent_ids if aid in self._agents]

    def list_by_type(self, agent_type: AgentType) -> List[Agent]:
        """
        List all agents of a specific type.

        Args:
            agent_type: Type of agent to filter

        Returns:
            List of agents of the specified type
        """
        with self._registry_lock:
            return [
                a for a in self._agents.values()
                if a.config.agent_type == agent_type
            ]

    def list_by_status(self, status: AgentStatus) -> List[Agent]:
        """
        List all agents with a specific status.

        Args:
            status: Status to filter

        Returns:
            List of agents with the specified status
        """
        with self._registry_lock:
            return [a for a in self._agents.values() if a.status == status]

    def update_status(self, agent_id: str, status: AgentStatus) -> bool:
        """
        Update agent status.

        Args:
            agent_id: Agent ID
            status: New status

        Returns:
            True if updated, False if agent not found
        """
        with self._registry_lock:
            agent = self._agents.get(agent_id)
            if agent:
                agent.status = status
                return True
            return False

    def _discard_from_indexes(self, agent, agent_id: str) -> None:
        """Drop agent_id from the user and conversation indexes (lock held)."""
        if agent.user_id is not None:
            ids = self._user_agents.get(agent.user_id)
            if ids is not None:
                ids.discard(agent_id)
        if agent.conversation_id is not None:
            ids = self._conversation_agents.get(agent.conversation_id)
            if ids is not None:
                ids.discard(agent_id)

    def remove(self, agent_id: str) -> bool:
        """
        Remove an agent from the registry and both secondary indexes.

        Args:
            agent_id: Agent ID

        Returns:
            True if removed, False if not found
        """
        with self._registry_lock:
            agent = self._agents.pop(agent_id, None)
            if agent is None:
                return False
            self._discard_from_indexes(agent, agent_id)
            return True

    def remove_user_agents(self, user_id: int) -> int:
        """
        Remove all agents for a user.

        Args:
            user_id: User ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._user_agents.pop(user_id, set())
            count = 0
            for agent_id in agent_ids:
                agent = self._agents.pop(agent_id, None)
                if agent is None:
                    continue
                count += 1
                # Fix: previously the agent ids stayed behind in the
                # conversation index, leaking stale references.
                if agent.conversation_id is not None:
                    ids = self._conversation_agents.get(agent.conversation_id)
                    if ids is not None:
                        ids.discard(agent_id)
            return count

    def remove_conversation_agents(self, conversation_id: str) -> int:
        """
        Remove all agents in a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Number of agents removed
        """
        with self._registry_lock:
            agent_ids = self._conversation_agents.pop(conversation_id, set())
            count = 0
            for agent_id in agent_ids:
                agent = self._agents.pop(agent_id, None)
                if agent is None:
                    continue
                count += 1
                # Fix: also scrub the user index (mirror of the fix above).
                if agent.user_id is not None:
                    ids = self._user_agents.get(agent.user_id)
                    if ids is not None:
                        ids.discard(agent_id)
            return count

    def clear(self) -> None:
        """Clear all agents and indexes from the registry."""
        with self._registry_lock:
            self._agents.clear()
            self._user_agents.clear()
            self._conversation_agents.clear()

    @property
    def agent_count(self) -> int:
        """Total number of agents."""
        with self._registry_lock:
            return len(self._agents)

    def get_stats(self) -> Dict:
        """Get registry statistics (totals plus per-status/per-type counts)."""
        with self._registry_lock:
            status_counts = {}
            type_counts = {}
            for agent in self._agents.values():
                status_counts[agent.status.value] = status_counts.get(agent.status.value, 0) + 1
                type_counts[agent.config.agent_type.value] = type_counts.get(agent.config.agent_type.value, 0) + 1
            return {
                "total": len(self._agents),
                "by_status": status_counts,
                "by_type": type_counts,
            }
# Global registry instance
registry = AgentRegistry()

345
luxx/agents/supervisor.py Normal file
View File

@@ -0,0 +1,345 @@
"""Supervisor Agent - task decomposition and result integration"""
import json
import logging
from typing import Any, Dict, List, Optional, Callable
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.services.llm_client import llm_client
from luxx.tools.core import registry as tool_registry
logger = logging.getLogger(__name__)
class SupervisorAgent:
"""
Supervisor Agent
Responsible for:
- Task decomposition using LLM
- Generating DAG (task graph)
- Result integration from workers
"""
# System prompt for task decomposition
DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.
Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration
Output format for task decomposition:
{
"task_name": "Overall task name",
"task_description": "Description of what needs to be accomplished",
"nodes": [
{
"id": "task_001",
"name": "Task name",
"description": "What this task does",
"task_type": "code|shell|file|llm|generic",
"task_data": {...}, # Task-specific parameters
"dependencies": [] # IDs of tasks that must complete first
}
]
}
Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""
def __init__(
self,
agent: Agent,
llm_client=None,
max_subtasks: int = 10
):
"""
Initialize Supervisor Agent
Args:
agent: Agent instance (should be SUPERVISOR type)
llm_client: LLM client instance
max_subtasks: Maximum number of subtasks to generate
"""
self.agent = agent
self.llm_client = llm_client or llm_client
self.max_subtasks = max_subtasks
# Ensure agent has supervisor system prompt
if not self.agent.config.system_prompt:
self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT
async def decompose_task(
self,
task: str,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> DAG:
"""
Decompose a task into subtasks using LLM
Args:
task: User's task description
context: Execution context (workspace, user info, etc.)
progress_callback: Optional callback for progress updates
Returns:
DAG representing the task decomposition
"""
self.agent.status = AgentStatus.PLANNING
if progress_callback:
progress_callback(0.1, "Analyzing task...")
# Build messages for LLM
messages = self.agent.get_context()
messages.append({
"role": "user",
"content": f"Decompose this task into subtasks:\n{task}"
})
if progress_callback:
progress_callback(0.2, "Calling LLM for task decomposition...")
# Call LLM
try:
response = await self.llm_client.sync_call(
model=self.agent.config.model,
messages=messages,
temperature=self.agent.config.temperature,
max_tokens=self.agent.config.max_tokens
)
if progress_callback:
progress_callback(0.5, "Processing decomposition...")
# Parse LLM response to extract DAG
dag = self._parse_dag_from_response(response.content, task)
# Add assistant response to context
self.agent.add_message("assistant", response.content)
if progress_callback:
progress_callback(0.9, "Task decomposition complete")
self.agent.status = AgentStatus.IDLE
return dag
except Exception as e:
logger.error(f"Task decomposition failed: {e}")
self.agent.status = AgentStatus.FAILED
raise
def _parse_dag_from_response(self, content: str, original_task: str) -> DAG:
"""
Parse LLM response to extract DAG structure
Args:
content: LLM response content
original_task: Original task description
Returns:
DAG instance
"""
# Try to extract JSON from response
dag_data = self._extract_json(content)
if not dag_data:
# Fallback: create a simple single-node DAG
logger.warning("Could not parse DAG from LLM response, creating simple DAG")
dag = DAG(
id=f"dag_{self.agent.id}",
name=original_task[:50],
description=original_task
)
node = TaskNode(
id="task_001",
name="Execute Task",
description=original_task,
task_type="llm",
task_data={"prompt": original_task}
)
dag.add_node(node)
return dag
# Build DAG from parsed data
dag = DAG(
id=f"dag_{self.agent.id}",
name=dag_data.get("task_name", original_task[:50]),
description=dag_data.get("task_description", original_task)
)
# Add nodes
for node_data in dag_data.get("nodes", []):
node = TaskNode(
id=node_data["id"],
name=node_data["name"],
description=node_data.get("description", ""),
task_type=node_data.get("task_type", "generic"),
task_data=node_data.get("task_data", {})
)
dag.add_node(node)
# Add edges based on dependencies
for node_data in dag_data.get("nodes", []):
node_id = node_data["id"]
for dep_id in node_data.get("dependencies", []):
if dep_id in dag.nodes:
dag.add_edge(dep_id, node_id)
return dag
def _extract_json(self, content: str) -> Optional[Dict]:
"""
Extract JSON from LLM response
Args:
content: Raw LLM response
Returns:
Parsed JSON dict or None
"""
# Try to find JSON in markdown code blocks
import re
# Look for ```json ... ``` blocks
json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group(1))
except json.JSONDecodeError:
pass
# Look for raw JSON object
json_match = re.search(r"\{.*\}", content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group(0))
except json.JSONDecodeError:
pass
return None
def integrate_results(self, dag: DAG) -> Dict[str, Any]:
"""
Integrate results from all completed tasks
Args:
dag: DAG with completed tasks
Returns:
Integrated result
"""
self.agent.status = AgentStatus.EXECUTING
# Collect all output data from completed nodes
results = {}
for node in dag.nodes.values():
if node.status == TaskNodeStatus.COMPLETED and node.output_data:
results[node.id] = node.output_data
# Store aggregated results
self.agent.accumulated_result = results
self.agent.status = AgentStatus.COMPLETED
return {
"success": True,
"dag_id": dag.id,
"total_tasks": dag.total_count,
"completed_tasks": dag.completed_count,
"failed_tasks": dag.failed_count,
"results": results
}
async def review_and_refine(
self,
dag: DAG,
task: str,
progress_callback: Optional[Callable] = None
) -> Optional[DAG]:
"""
Review DAG execution and refine if needed
Args:
dag: Current DAG state
task: Original task
progress_callback: Progress callback
Returns:
Refined DAG or None if no refinement needed
"""
if dag.is_success:
return None # No refinement needed
# Check if there are failed tasks
failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]
if not failed_nodes:
return None
# Build context for refinement
context = {
"original_task": task,
"failed_tasks": [
{
"id": n.id,
"name": n.name,
"error": n.result.error if n.result else "Unknown error"
}
for n in failed_nodes
],
"completed_tasks": [
{
"id": n.id,
"name": n.name,
"output": n.output_data
}
for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
]
}
messages = self.agent.get_context()
messages.append({
"role": "user",
"content": f"""Review the task execution and suggest refinements:
Task: {task}
Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}
Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}
If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical
Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
})
try:
response = await self.llm_client.sync_call(
model=self.agent.config.model,
messages=messages,
temperature=self.agent.config.temperature
)
# Check if refinement was suggested
refined_dag = self._parse_dag_from_response(response.content, task)
# Only return if we got a valid refinement
if refined_dag and refined_dag.nodes:
return refined_dag
except Exception as e:
logger.error(f"DAG refinement failed: {e}")
return None

401
luxx/agents/worker.py Normal file
View File

@ -0,0 +1,401 @@
"""Worker Agent - executes specific tasks"""
import json
import logging
import time
from typing import Any, Dict, List, Optional, Callable
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import TaskNode, TaskNodeStatus, TaskResult
from luxx.services.llm_client import llm_client
from luxx.tools.core import registry as tool_registry, ToolContext, CommandPermission
logger = logging.getLogger(__name__)
class WorkerAgent:
"""
Worker Agent
Responsible for executing specific tasks using:
- LLM calls for reasoning tasks
- Tool execution for actionable tasks
Follows sliding window context management.
"""
# System prompt for worker tasks
DEFAULT_SYSTEM_PROMPT = """You are a Worker Agent that executes specific tasks efficiently.
Your responsibilities:
1. Execute tasks assigned to you by the Supervisor
2. Use appropriate tools when needed
3. Report results clearly with structured output_data for dependent tasks
4. Be concise and focused on the task at hand
Output format:
- Provide clear, structured results
- Include output_data for any data that dependent tasks might need
- If a tool fails, explain the error clearly
"""
def __init__(
self,
agent: Agent,
llm_client=None,
tool_executor=None
):
"""
Initialize Worker Agent
Args:
agent: Agent instance (should be WORKER type)
llm_client: LLM client instance
tool_executor: Tool executor instance
"""
self.agent = agent
self.llm_client = llm_client or llm_client
self.tool_executor = tool_executor
# Ensure agent has worker system prompt
if not self.agent.config.system_prompt:
self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT
async def execute_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
parent_outputs: Dict[str, Dict[str, Any]] = None,
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""
Execute a task node
Args:
task_node: Task node to execute
context: Execution context (workspace, user info, etc.)
parent_outputs: Output data from parent tasks (dependency results)
progress_callback: Optional callback for progress updates
Returns:
TaskResult with execution outcome
"""
self.agent.status = AgentStatus.EXECUTING
self.agent.current_task_id = task_node.id
start_time = time.time()
if progress_callback:
progress_callback(0.0, f"Starting task: {task_node.name}")
try:
# Merge parent outputs into context
execution_context = self._prepare_context(context, parent_outputs)
# Execute based on task type
if task_node.task_type == "llm":
result = await self._execute_llm_task(task_node, execution_context, progress_callback)
elif task_node.task_type == "code":
result = await self._execute_code_task(task_node, execution_context, progress_callback)
elif task_node.task_type == "shell":
result = await self._execute_shell_task(task_node, execution_context, progress_callback)
elif task_node.task_type == "file":
result = await self._execute_file_task(task_node, execution_context, progress_callback)
else:
result = await self._execute_generic_task(task_node, execution_context, progress_callback)
execution_time = time.time() - start_time
result.execution_time = execution_time
if progress_callback:
progress_callback(1.0, f"Task complete: {task_node.name}")
self.agent.status = AgentStatus.IDLE
return result
except Exception as e:
logger.error(f"Task execution failed: {e}")
execution_time = time.time() - start_time
self.agent.status = AgentStatus.FAILED
return TaskResult.fail(error=str(e))
def _prepare_context(
self,
context: Dict[str, Any],
parent_outputs: Dict[str, Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Prepare execution context by merging parent outputs
Args:
context: Base context
parent_outputs: Output from parent tasks
Returns:
Merged context
"""
execution_context = context.copy()
if parent_outputs:
# Merge parent outputs into context
merged = {}
for parent_id, outputs in parent_outputs.items():
merged.update(outputs)
execution_context["parent_outputs"] = parent_outputs
execution_context["merged_data"] = merged
# Add user permission level
if "user_permission_level" not in execution_context:
execution_context["user_permission_level"] = self.agent.effective_permission.value
return execution_context
async def _execute_llm_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""Execute LLM reasoning task"""
task_data = task_node.task_data
# Build prompt
prompt = task_data.get("prompt", task_node.description)
system_prompt = task_data.get("system", self.agent.config.system_prompt)
messages = [{"role": "system", "content": system_prompt}]
# Add parent data if available
if "merged_data" in context:
merged = context["merged_data"]
context_info = "\n".join([f"{k}: {v}" for k, v in merged.items()])
messages.append({
"role": "system",
"content": f"Context from dependent tasks:\n{context_info}"
})
messages.append({"role": "user", "content": prompt})
if progress_callback:
progress_callback(0.3, "Calling LLM...")
try:
response = await self.llm_client.sync_call(
model=self.agent.config.model,
messages=messages,
temperature=self.agent.config.temperature,
max_tokens=self.agent.config.max_tokens
)
return TaskResult.ok(
data=response.content,
output_data=task_node.task_data.get("output_template", {}).copy()
)
except Exception as e:
return TaskResult.fail(error=str(e))
async def _execute_code_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""Execute code generation/writing task"""
task_data = task_node.task_data
# Build prompt for code generation
prompt = task_data.get("prompt", task_node.description)
language = task_data.get("language", "python")
requirements = task_data.get("requirements", "")
messages = [
{"role": "system", "content": f"You are a {language} programmer. Write clean, efficient code."},
{"role": "user", "content": f"Task: {prompt}\n\nRequirements: {requirements}"}
]
if progress_callback:
progress_callback(0.3, "Generating code...")
try:
response = await self.llm_client.sync_call(
model=self.agent.config.model,
messages=messages,
temperature=0.2, # Lower temp for code
max_tokens=4096
)
return TaskResult.ok(
data=response.content,
output_data={
"code": response.content,
"language": language
}
)
except Exception as e:
return TaskResult.fail(error=str(e))
async def _execute_shell_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""Execute shell command task"""
task_data = task_node.task_data
command = task_data.get("command")
if not command:
return TaskResult.fail(error="No command specified")
if progress_callback:
progress_callback(0.3, f"Executing: {command[:50]}...")
# Build tool context
tool_ctx = ToolContext(
workspace=context.get("workspace"),
user_id=context.get("user_id"),
username=context.get("username"),
extra={"user_permission_level": context.get("user_permission_level", 1)}
)
try:
# Execute shell command via tool
result = tool_registry.execute(
"shell_exec",
{"command": command},
context=tool_ctx
)
if result.get("success"):
return TaskResult.ok(
data=result.get("data", {}).get("output", ""),
output_data={"output": result.get("data", {}).get("output", "")}
)
else:
return TaskResult.fail(error=result.get("error", "Shell execution failed"))
except Exception as e:
return TaskResult.fail(error=str(e))
async def _execute_file_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""Execute file operation task"""
task_data = task_node.task_data
operation = task_data.get("operation")
file_path = task_data.get("path")
content = task_data.get("content", "")
if not operation or not file_path:
return TaskResult.fail(error="Missing operation or path")
if progress_callback:
progress_callback(0.3, f"File operation: {operation} {file_path}")
tool_ctx = ToolContext(
workspace=context.get("workspace"),
user_id=context.get("user_id"),
username=context.get("username"),
extra={"user_permission_level": context.get("user_permission_level", 1)}
)
try:
tool_name = f"file_{operation}"
result = tool_registry.execute(
tool_name,
{"path": file_path, "content": content},
context=tool_ctx
)
if result.get("success"):
return TaskResult.ok(
data=result.get("data"),
output_data={"path": file_path, "operation": operation}
)
else:
return TaskResult.fail(error=result.get("error", "File operation failed"))
except Exception as e:
return TaskResult.fail(error=str(e))
async def _execute_generic_task(
self,
task_node: TaskNode,
context: Dict[str, Any],
progress_callback: Optional[Callable] = None
) -> TaskResult:
"""Execute generic task using LLM with tools"""
task_data = task_node.task_data
# Build prompt
prompt = task_data.get("prompt", task_node.description)
tools = task_data.get("tools", [])
messages = [
{"role": "system", "content": self.agent.config.system_prompt},
{"role": "user", "content": prompt}
]
# Get tool definitions if specified
tool_defs = None
if tools:
tool_defs = [tool_registry.get(t).to_openai_format() for t in tools if tool_registry.get(t)]
if progress_callback:
progress_callback(0.2, "Processing task...")
max_iterations = 5
iteration = 0
while iteration < max_iterations:
try:
response = await self.llm_client.sync_call(
model=self.agent.config.model,
messages=messages,
tools=tool_defs,
temperature=self.agent.config.temperature,
max_tokens=self.agent.config.max_tokens
)
# Add assistant response
messages.append({"role": "assistant", "content": response.content})
# Check for tool calls
if response.tool_calls:
if progress_callback:
progress_callback(0.5, f"Executing {len(response.tool_calls)} tools...")
# Execute tools
tool_results = self.tool_executor.process_tool_calls(
response.tool_calls,
context
)
# Add tool results
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr["tool_call_id"],
"content": tr["content"]
})
if progress_callback:
progress_callback(0.8, "Tools executed")
else:
# No tool calls, task complete
return TaskResult.ok(
data=response.content,
output_data=task_data.get("output_template", {})
)
except Exception as e:
return TaskResult.fail(error=str(e))
iteration += 1
return TaskResult.fail(error="Max iterations exceeded")

View File

@ -2,6 +2,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from luxx.routes import auth, conversations, messages, tools, providers from luxx.routes import auth, conversations, messages, tools, providers
from luxx.routes.agents_ws import router as agents_ws_router
api_router = APIRouter() api_router = APIRouter()
@ -12,3 +13,4 @@ api_router.include_router(conversations.router)
api_router.include_router(messages.router) api_router.include_router(messages.router)
api_router.include_router(tools.router) api_router.include_router(tools.router)
api_router.include_router(providers.router) api_router.include_router(providers.router)
api_router.include_router(agents_ws_router)

385
luxx/routes/agents_ws.py Normal file
View File

@ -0,0 +1,385 @@
"""WebSocket routes for agent real-time communication"""
import asyncio
import json
import logging
import threading
from typing import Any, Dict, Optional, Set
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
import uuid
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.agents.dag_scheduler import DAGScheduler, SchedulerPool
from luxx.agents.registry import AgentRegistry
from luxx.services.llm_client import llm_client
from luxx.tools.executor import ToolExecutor
logger = logging.getLogger(__name__)
router = APIRouter()
class ConnectionManager:
"""
WebSocket Connection Manager
Manages WebSocket connections for real-time agent progress updates.
Features:
- Connection tracking by task_id
- Heartbeat mechanism
- Progress broadcasting
"""
def __init__(self):
# task_id -> set of websocket connections
self._connections: Dict[str, Set[WebSocket]] = {}
# websocket -> task_id mapping
self._ws_to_task: Dict[WebSocket, str] = {}
# heartbeat tasks
self._heartbeat_tasks: Dict[WebSocket, asyncio.Task] = {}
# lock for thread safety
self._lock = threading.Lock()
# heartbeat interval in seconds
self._heartbeat_interval = 30
def connect(self, websocket: WebSocket, task_id: str) -> None:
"""
Accept and register a WebSocket connection
Args:
websocket: WebSocket connection
task_id: Task ID to subscribe to
"""
websocket.accept()
with self._lock:
if task_id not in self._connections:
self._connections[task_id] = set()
self._connections[task_id].add(websocket)
self._ws_to_task[websocket] = task_id
logger.info(f"WebSocket connected for task {task_id}")
def disconnect(self, websocket: WebSocket) -> None:
"""
Unregister a WebSocket connection
Args:
websocket: WebSocket connection
"""
with self._lock:
task_id = self._ws_to_task.pop(websocket, None)
if task_id and task_id in self._connections:
self._connections[task_id].discard(websocket)
if not self._connections[task_id]:
del self._connections[task_id]
# Cancel heartbeat task
if websocket in self._heartbeat_tasks:
self._heartbeat_tasks[websocket].cancel()
del self._heartbeat_tasks[websocket]
if task_id:
logger.info(f"WebSocket disconnected for task {task_id}")
async def send_to_task(self, task_id: str, message: Dict[str, Any]) -> None:
"""
Send message to all connections subscribed to a task
Args:
task_id: Task ID
message: Message to send
"""
with self._lock:
connections = self._connections.get(task_id, set()).copy()
if not connections:
return
dead_connections = set()
for websocket in connections:
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"Failed to send to websocket: {e}")
dead_connections.add(websocket)
# Clean up dead connections
for ws in dead_connections:
self.disconnect(ws)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""
Broadcast message to all connections
Args:
message: Message to broadcast
"""
with self._lock:
all_connections = list(self._ws_to_task.keys())
for websocket in all_connections:
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"Failed to broadcast: {e}")
async def send_personal(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
"""
Send message to a specific connection
Args:
websocket: Target WebSocket
message: Message to send
"""
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"Failed to send personal message: {e}")
def start_heartbeat(self, websocket: WebSocket) -> None:
"""
Start heartbeat for a connection
Args:
websocket: WebSocket connection
"""
async def heartbeat_loop():
while True:
await asyncio.sleep(self._heartbeat_interval)
try:
await websocket.send_json({
"type": "heartbeat",
"interval": self._heartbeat_interval
})
except Exception:
break
task = asyncio.create_task(heartbeat_loop())
with self._lock:
self._heartbeat_tasks[websocket] = task
@property
def connection_count(self) -> int:
"""Total number of connections"""
with self._lock:
return len(self._ws_to_task)
def get_task_connections(self, task_id: str) -> int:
"""Get number of connections for a task"""
with self._lock:
return len(self._connections.get(task_id, set()))
# Process-wide singletons shared by the WebSocket routes in this module.
# Global connection manager (tracks live WebSocket subscriptions per task)
connection_manager = ConnectionManager()
# Global scheduler pool (max_concurrent presumably caps simultaneously running DAGs)
scheduler_pool = SchedulerPool(max_concurrent=10)
# Global tool executor used when worker agents issue tool calls
tool_executor = ToolExecutor()
class ProgressEmitter:
"""
Progress emitter that sends updates via WebSocket
Wraps DAGScheduler progress callback to emit WebSocket messages.
"""
def __init__(self, task_id: str, connection_manager: ConnectionManager):
self.task_id = task_id
self.connection_manager = connection_manager
def __call__(self, node_id: str, progress: float, message: str = "") -> None:
"""Progress callback"""
if node_id == "dag":
# DAG-level progress
asyncio.create_task(self.connection_manager.send_to_task(
self.task_id,
{
"type": "dag_progress",
"data": {
"progress": progress,
"message": message
}
}
))
else:
# Node-level progress
asyncio.create_task(self.connection_manager.send_to_task(
self.task_id,
{
"type": "node_progress",
"data": {
"node_id": node_id,
"progress": progress,
"message": message
}
}
))
@router.websocket("/ws/dag/{task_id}")
async def dag_websocket(websocket: WebSocket, task_id: str):
    """
    WebSocket endpoint for DAG progress updates

    Protocol:
    - Client sends: subscribe, get_status, ping, cancel_task
    - Server sends: subscribed, heartbeat, dag_start, node_start, node_progress,
      node_complete, node_error, dag_complete, pong
    """
    # BUG FIX: connect() performs the (async) WebSocket accept handshake and
    # was previously called without await, so the connection was never
    # accepted and every subsequent send/receive failed.
    await connection_manager.connect(websocket, task_id)
    # Send subscribed confirmation
    await connection_manager.send_personal(websocket, {
        "type": "subscribed",
        "task_id": task_id
    })
    # Start heartbeat
    connection_manager.start_heartbeat(websocket)
    # If the task is already running, push its current DAG state immediately
    scheduler = scheduler_pool.get(task_id)
    if scheduler:
        await connection_manager.send_personal(websocket, {
            "type": "dag_status",
            "data": scheduler.dag.to_dict()
        })
    try:
        while True:
            # Receive message
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": "Invalid JSON"
                })
                continue
            msg_type = message.get("type")
            if msg_type == "subscribe":
                # Already subscribed on connect
                await connection_manager.send_personal(websocket, {
                    "type": "subscribed",
                    "task_id": task_id
                })
            elif msg_type == "get_status":
                # Re-fetch: the scheduler may have been created after connect
                scheduler = scheduler_pool.get(task_id)
                if scheduler:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": scheduler.dag.to_dict()
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "dag_status",
                        "data": None
                    })
            elif msg_type == "ping":
                await connection_manager.send_personal(websocket, {
                    "type": "pong"
                })
            elif msg_type == "cancel_task":
                if scheduler_pool.cancel(task_id):
                    await connection_manager.send_personal(websocket, {
                        "type": "task_cancelled",
                        "task_id": task_id
                    })
                else:
                    await connection_manager.send_personal(websocket, {
                        "type": "error",
                        "message": "Task not found or already completed"
                    })
            else:
                await connection_manager.send_personal(websocket, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })
    except WebSocketDisconnect:
        connection_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        connection_manager.disconnect(websocket)
# Helper functions to emit progress from scheduler
def create_progress_emitter(task_id: str) -> ProgressEmitter:
    """Build a ProgressEmitter bound to *task_id* and the global manager."""
    emitter = ProgressEmitter(task_id, connection_manager)
    return emitter
async def emit_dag_start(task_id: str, dag: DAG) -> None:
    """Notify subscribers of *task_id* that DAG execution has begun."""
    payload = {"type": "dag_start", "data": {"graph": dag.to_dict()}}
    await connection_manager.send_to_task(task_id, payload)
async def emit_node_start(task_id: str, node: TaskNode) -> None:
    """Notify subscribers of *task_id* that *node* has started executing."""
    payload = {
        "type": "node_start",
        "data": {
            "node_id": node.id,
            "name": node.name,
            "status": node.status.value,
        },
    }
    await connection_manager.send_to_task(task_id, payload)
async def emit_node_complete(task_id: str, node: TaskNode) -> None:
    """Notify subscribers that *node* finished, including its result payload."""
    result_dict = node.result.to_dict() if node.result else None
    payload = {
        "type": "node_complete",
        "data": {
            "node_id": node.id,
            "result": result_dict,
            "output_data": node.output_data,
        },
    }
    await connection_manager.send_to_task(task_id, payload)
async def emit_node_error(task_id: str, node: TaskNode, error: str) -> None:
    """Notify subscribers that *node* failed with *error*."""
    payload = {
        "type": "node_error",
        "data": {"node_id": node.id, "error": error},
    }
    await connection_manager.send_to_task(task_id, payload)
async def emit_dag_complete(task_id: str, success: bool, results: Dict) -> None:
    """Notify subscribers that the whole DAG finished, with final results."""
    payload = {
        "type": "dag_complete",
        "data": {"success": success, "results": results},
    }
    await connection_manager.send_to_task(task_id, payload)