"""Supervisor Agent - task decomposition and result integration"""
import json
import logging
import re
from typing import Any, Callable, Dict, List, Optional

from luxx.agents.core import Agent, AgentConfig, AgentStatus, AgentType
from luxx.agents.dag import DAG, TaskNode, TaskNodeStatus, TaskResult
from luxx.services.llm_client import llm_client
from luxx.tools.core import registry as tool_registry
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SupervisorAgent:
    """
    Supervisor Agent

    Responsible for:
    - Task decomposition using LLM
    - Generating DAG (task graph)
    - Result integration from workers
    """

    # System prompt for task decomposition
    DEFAULT_SYSTEM_PROMPT = """You are a Supervisor Agent that decomposes complex tasks into executable subtasks.

Your responsibilities:
1. Analyze the user's task and break it down into smaller, manageable subtasks
2. Create a DAG (Directed Acyclic Graph) where nodes are subtasks and edges represent dependencies
3. Each subtask should be specific and actionable
4. Consider parallel execution opportunities - tasks without dependencies can run concurrently
5. Store key results from subtasks for final integration

Output format for task decomposition:
{
    "task_name": "Overall task name",
    "task_description": "Description of what needs to be accomplished",
    "nodes": [
        {
            "id": "task_001",
            "name": "Task name",
            "description": "What this task does",
            "task_type": "code|shell|file|llm|generic",
            "task_data": {...}, # Task-specific parameters
            "dependencies": [] # IDs of tasks that must complete first
        }
    ]
}

Guidelines:
- Keep tasks focused and atomic
- Use meaningful task IDs (e.g., task_001, task_002)
- Mark parallelizable tasks with no dependencies
- Maximum 10 subtasks for a single decomposition
- Include only the output_data that matters for dependent tasks or final result
"""

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        max_subtasks: int = 10
    ):
        """
        Initialize Supervisor Agent

        Args:
            agent: Agent instance (should be SUPERVISOR type)
            llm_client: LLM client instance; defaults to the shared
                module-level client when omitted
            max_subtasks: Maximum number of subtasks to generate
        """
        self.agent = agent
        # BUG FIX: the `llm_client` parameter shadows the module-level import
        # of the same name, so the original `llm_client or llm_client` could
        # never reach the shared default (None stayed None). A function-local
        # import reaches past the shadowed name.
        if llm_client is None:
            from luxx.services.llm_client import llm_client
        self.llm_client = llm_client
        self.max_subtasks = max_subtasks

        # Ensure agent has supervisor system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def decompose_task(
        self,
        task: str,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> DAG:
        """
        Decompose a task into subtasks using LLM

        Args:
            task: User's task description
            context: Execution context (workspace, user info, etc.)
                NOTE(review): currently unused in this method; kept for
                interface compatibility with callers
            progress_callback: Optional callback for progress updates,
                called as progress_callback(fraction, message)

        Returns:
            DAG representing the task decomposition

        Raises:
            Exception: re-raises any LLM client failure after marking the
                agent FAILED
        """
        self.agent.status = AgentStatus.PLANNING

        if progress_callback:
            progress_callback(0.1, "Analyzing task...")

        # Build messages for LLM from the agent's accumulated context
        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"Decompose this task into subtasks:\n{task}"
        })

        if progress_callback:
            progress_callback(0.2, "Calling LLM for task decomposition...")

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )

            if progress_callback:
                progress_callback(0.5, "Processing decomposition...")

            # Parse LLM response to extract DAG
            dag = self._parse_dag_from_response(response.content, task)

            # Record the assistant turn so later calls (e.g. refinement)
            # see the decomposition in context
            self.agent.add_message("assistant", response.content)

            if progress_callback:
                progress_callback(0.9, "Task decomposition complete")

            self.agent.status = AgentStatus.IDLE
            return dag

        except Exception as e:
            logger.error(f"Task decomposition failed: {e}")
            self.agent.status = AgentStatus.FAILED
            raise

    def _parse_dag_from_response(self, content: str, original_task: str) -> DAG:
        """
        Parse LLM response to extract DAG structure

        Falls back to a single "llm"-type node wrapping the original task
        when no JSON plan can be extracted, so callers always receive a
        runnable DAG.

        Args:
            content: LLM response content
            original_task: Original task description

        Returns:
            DAG instance
        """
        dag_data = self._extract_json(content)

        if not dag_data:
            # Fallback: create a simple single-node DAG
            logger.warning("Could not parse DAG from LLM response, creating simple DAG")
            dag = DAG(
                id=f"dag_{self.agent.id}",
                name=original_task[:50],
                description=original_task
            )
            dag.add_node(TaskNode(
                id="task_001",
                name="Execute Task",
                description=original_task,
                task_type="llm",
                task_data={"prompt": original_task}
            ))
            return dag

        # Build DAG from parsed data
        dag = DAG(
            id=f"dag_{self.agent.id}",
            name=dag_data.get("task_name", original_task[:50]),
            description=dag_data.get("task_description", original_task)
        )

        node_specs = dag_data.get("nodes", [])
        # FIX: max_subtasks was accepted in __init__ but never enforced.
        if len(node_specs) > self.max_subtasks:
            logger.warning(
                "LLM produced %d subtasks, truncating to max_subtasks=%d",
                len(node_specs), self.max_subtasks
            )
            node_specs = node_specs[:self.max_subtasks]

        # Add nodes; tolerate malformed entries instead of raising KeyError
        for node_data in node_specs:
            if "id" not in node_data:
                logger.warning("Skipping subtask without an 'id': %r", node_data)
                continue
            dag.add_node(TaskNode(
                id=node_data["id"],
                name=node_data.get("name", node_data["id"]),
                description=node_data.get("description", ""),
                task_type=node_data.get("task_type", "generic"),
                task_data=node_data.get("task_data", {})
            ))

        # Add edges based on dependencies; unknown IDs are silently ignored
        for node_data in node_specs:
            node_id = node_data.get("id")
            if node_id not in dag.nodes:
                continue
            for dep_id in node_data.get("dependencies", []):
                if dep_id in dag.nodes:
                    dag.add_edge(dep_id, node_id)

        return dag

    def _extract_json(self, content: str) -> Optional[Dict]:
        """
        Extract JSON from LLM response

        Tries a fenced ```json block first, then the widest raw {...} span.

        Args:
            content: Raw LLM response

        Returns:
            Parsed JSON dict or None
        """
        # Look for ```json ... ``` blocks first (preferred, unambiguous)
        json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Fall back to the widest raw JSON object in the text
        json_match = re.search(r"\{.*\}", content, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        return None

    def integrate_results(self, dag: DAG) -> Dict[str, Any]:
        """
        Integrate results from all completed tasks

        Args:
            dag: DAG with completed tasks

        Returns:
            Integrated result dict with success flag, task counters and the
            per-node output data, keyed by node id
        """
        self.agent.status = AgentStatus.EXECUTING

        # Collect output data from completed nodes only; nodes without
        # output_data are deliberately omitted
        results = {}
        for node in dag.nodes.values():
            if node.status == TaskNodeStatus.COMPLETED and node.output_data:
                results[node.id] = node.output_data

        # Store aggregated results on the agent for downstream consumers
        self.agent.accumulated_result = results
        self.agent.status = AgentStatus.COMPLETED

        return {
            "success": True,
            "dag_id": dag.id,
            "total_tasks": dag.total_count,
            "completed_tasks": dag.completed_count,
            "failed_tasks": dag.failed_count,
            "results": results
        }

    async def review_and_refine(
        self,
        dag: DAG,
        task: str,
        progress_callback: Optional[Callable] = None
    ) -> Optional[DAG]:
        """
        Review DAG execution and refine if needed

        Args:
            dag: Current DAG state
            task: Original task
            progress_callback: Progress callback (currently unused here)

        Returns:
            Refined DAG, or None if no refinement is needed or possible
        """
        if dag.is_success:
            return None  # No refinement needed

        # Check if there are failed tasks
        failed_nodes = [n for n in dag.nodes.values() if n.status == TaskNodeStatus.FAILED]
        if not failed_nodes:
            return None

        # Build context for refinement
        context = {
            "original_task": task,
            "failed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "error": n.result.error if n.result else "Unknown error"
                }
                for n in failed_nodes
            ],
            "completed_tasks": [
                {
                    "id": n.id,
                    "name": n.name,
                    "output": n.output_data
                }
                for n in dag.nodes.values() if n.status == TaskNodeStatus.COMPLETED
            ]
        }

        messages = self.agent.get_context()
        messages.append({
            "role": "user",
            "content": f"""Review the task execution and suggest refinements:

Task: {task}

Failed tasks: {json.dumps(context['failed_tasks'], indent=2)}

Completed tasks: {json.dumps(context['completed_tasks'], indent=2)}

If a task failed, you can:
1. Break it into smaller tasks
2. Change the approach
3. Skip it if not critical

Provide a refined subtask plan if needed, or indicate if the overall task should fail."""
        })

        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature
            )

            # BUG FIX: _parse_dag_from_response fabricates a single-node
            # fallback DAG when no JSON is present, so the old
            # `if refined_dag and refined_dag.nodes` guard was always true
            # and this method could never decline a refinement. Only build
            # a refined DAG when the reply actually contains a parseable plan.
            if self._extract_json(response.content) is None:
                return None

            refined_dag = self._parse_dag_from_response(response.content, task)
            if refined_dag.nodes:
                return refined_dag

        except Exception as e:
            logger.error(f"DAG refinement failed: {e}")

        return None
|