# Luxx/luxx/agents/supervisor.py
"""Supervisor Agent - task decomposition and result integration"""
import json
import logging
import re
from typing import Any, Callable, Dict, List, Optional

from luxx.agents.core import Agent, AgentConfig, AgentStatus, AgentType
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.services.llm_client import llm_client
from luxx.tools.core import registry as tool_registry
logger = logging.getLogger(__name__)
class SupervisorAgent:
    """
    Supervisor Agent

    Responsible for:
    - Task decomposition using LLM
    - Generating DAG (task graph)
    - Result integration from workers
    """

    # System prompt for task decomposition
    DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.
Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration
Output format for task decomposition:
{
"task_name": "Overall task name",
"task_description": "Description of what needs to be accomplished",
"nodes": [
{
"id": "task_001",
"name": "Task name",
"description": "What this task does",
"task_type": "code|shell|file|llm|generic",
"task_data": {...}, # Task-specific parameters
"dependencies": [] # IDs of tasks that must complete first
}
]
}
Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""

    def __init__(
        self,
        agent: "Agent",
        llm_client=None,
        max_subtasks: int = 10
    ):
        """
        Initialize Supervisor Agent.

        Args:
            agent: Agent instance (should be SUPERVISOR type)
            llm_client: LLM client instance; defaults to the shared
                module-level client when None
            max_subtasks: Maximum number of subtasks to generate
        """
        self.agent = agent
        # BUG FIX: the `llm_client` parameter shadows the module-level import,
        # so the old `llm_client or llm_client` expression OR'd the parameter
        # with itself and left self.llm_client as None when no client was
        # supplied. Re-import the shared client under a distinct name instead.
        if llm_client is None:
            from luxx.services.llm_client import llm_client as shared_llm_client
            llm_client = shared_llm_client
        self.llm_client = llm_client
        self.max_subtasks = max_subtasks
        # Ensure agent has supervisor system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def decompose_task(
        self,
        task: str,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> "DAG":
        """
        Decompose a task into subtasks using LLM.

        Args:
            task: User's task description
            context: Execution context (workspace, user info, etc.);
                currently not forwarded to the LLM
            progress_callback: Optional callback(fraction, message) for
                progress updates

        Returns:
            DAG representing the task decomposition

        Raises:
            Exception: re-raises any LLM call failure after marking the
                agent FAILED
        """
        self.agent.status = AgentStatus.PLANNING
        if progress_callback:
            progress_callback(0.1, "Analyzing task...")

        # Build messages for LLM from the agent's running context
        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"Decompose this task into subtasks:\n{task}"
        })

        if progress_callback:
            progress_callback(0.2, "Calling LLM for task decomposition...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )
            if progress_callback:
                progress_callback(0.5, "Processing decomposition...")

            # Parse LLM response to extract DAG (falls back to a single-node
            # DAG when no JSON plan can be recovered)
            dag = self._parse_dag_from_response(response.content, task)

            # Record the assistant turn so later calls see the plan
            self.agent.add_message("assistant", response.content)

            if progress_callback:
                progress_callback(0.9, "Task decomposition complete")
            self.agent.status = AgentStatus.IDLE
            return dag
        except Exception as e:
            logger.error(f"Task decomposition failed: {e}")
            self.agent.status = AgentStatus.FAILED
            raise

    def _build_fallback_dag(self, original_task: str) -> "DAG":
        """
        Build a single-node DAG that hands the whole task to the LLM as-is.

        Used when the LLM response contains no parseable decomposition.
        """
        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=original_task[:50],
            description=original_task
        )
        dag.add_node(TaskNode(
            id="task_001",
            name="Execute Task",
            description=original_task,
            task_type="llm",
            task_data={"prompt": original_task}
        ))
        return dag

    def _parse_dag_from_response(self, content: str, original_task: str) -> "DAG":
        """
        Parse LLM response to extract DAG structure.

        Args:
            content: LLM response content
            original_task: Original task description (used for fallback and
                as default name/description)

        Returns:
            DAG instance; never None — a single-node fallback DAG is built
            when no JSON can be extracted
        """
        dag_data = self._extract_json(content)
        if not dag_data:
            logger.warning("Could not parse DAG from LLM response, creating simple DAG")
            return self._build_fallback_dag(original_task)

        # Build DAG from parsed data
        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=dag_data.get("task_name", original_task[:50]),
            description=dag_data.get("task_description", original_task)
        )

        node_specs = dag_data.get("nodes", [])
        # max_subtasks was previously accepted but never checked; surface
        # oversized plans instead of silently accepting them.
        if len(node_specs) > self.max_subtasks:
            logger.warning(
                "LLM produced %d subtasks, exceeding max_subtasks=%d",
                len(node_specs), self.max_subtasks
            )

        # First pass: create all nodes so dependency edges can be validated
        for node_data in node_specs:
            dag.add_node(TaskNode(
                id=node_data["id"],
                name=node_data["name"],
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {})
            ))

        # Second pass: add edges; dependencies on unknown node ids are dropped
        for node_data in node_specs:
            node_id = node_data["id"]
            for dep_id in node_data.get("dependencies", []):
                if dep_id in dag.nodes:
                    dag.add_edge(dep_id, node_id)
        return dag

    def _extract_json(self, content: str) -> Optional[Dict]:
        """
        Extract a JSON object from an LLM response.

        Tries a fenced ```json block first, then any raw {...} span.

        Args:
            content: Raw LLM response

        Returns:
            Parsed JSON dict, or None when nothing parseable is found
        """
        # Look for ```json ... ``` blocks; the lazy group is anchored by the
        # closing fence, so nested braces are still captured in full.
        fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if fenced:
            try:
                return json.loads(fenced.group(1))
            except json.JSONDecodeError:
                pass
        # Fall back to the widest raw JSON object span
        raw = re.search(r"\{.*\}", content, re.DOTALL)
        if raw:
            try:
                return json.loads(raw.group(0))
            except json.JSONDecodeError:
                pass
        return None

    def integrate_results(self, dag: "DAG") -> Dict[str, Any]:
        """
        Integrate results from all completed tasks.

        Args:
            dag: DAG with completed tasks

        Returns:
            Integrated result summary (success flag, task counts, and the
            per-node output data of completed nodes)
        """
        self.agent.status = AgentStatus.EXECUTING

        # Collect all output data from completed nodes
        results = {}
        for node in dag.nodes.values():
            if node.status == TaskNodeStatus.COMPLETED and node.output_data:
                results[node.id] = node.output_data

        # Store aggregated results on the agent for later reference
        self.agent.accumulated_result = results
        self.agent.status = AgentStatus.COMPLETED
        return {
            "success": True,
            "dag_id": dag.id,
            "total_tasks": dag.total_count,
            "completed_tasks": dag.completed_count,
            "failed_tasks": dag.failed_count,
            "results": results
        }

    async def review_and_refine(
        self,
        dag: "DAG",
        task: str,
        progress_callback: Optional[Callable] = None
    ) -> Optional["DAG"]:
        """
        Review DAG execution and refine if needed.

        Args:
            dag: Current DAG state
            task: Original task
            progress_callback: Progress callback (currently unused here)

        Returns:
            Refined DAG, or None if no refinement is needed/possible
        """
        if dag.is_success:
            return None  # No refinement needed

        # Only refine when something actually failed
        failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]
        if not failed_nodes:
            return None

        # Build context for refinement
        context = {
            "original_task": task,
            "failed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "error": n.result.error if n.result else "Unknown error"
                }
                for n in failed_nodes
            ],
            "completed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "output": n.output_data
                }
                for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
            ]
        }

        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"""Review the task execution and suggest refinements:
Task: {task}
Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}
Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}
If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical
Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
        })

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature
            )
            # BUG FIX: _parse_dag_from_response always falls back to a
            # non-empty single-node DAG, so the old `refined_dag.nodes` guard
            # could never reject a prose "no refinement / should fail" reply.
            # Only build a refined DAG when the response actually contains a
            # JSON plan.
            if self._extract_json(response.content):
                refined_dag = self._parse_dag_from_response(response.content, task)
                if refined_dag.nodes:
                    return refined_dag
        except Exception as e:
            logger.error(f"DAG refinement failed: {e}")
        return None