Luxx/luxx/agents/worker.py

402 lines
14 KiB
Python

"""Worker Agent - executes specific tasks"""
import json
import logging
import time
from typing import Any, Dict, List, Optional, Callable
from luxx.agents.core import Agent, AgentConfig, AgentType, AgentStatus
from luxx.agents.dag import TaskNode, TaskNodeStatus, TaskResult
from luxx.services.llm_client import llm_client
from luxx.tools.core import registry as tool_registry, ToolContext, CommandPermission
logger = logging.getLogger(__name__)
class WorkerAgent:
    """
    Worker Agent

    Executes a single task node on behalf of an ``Agent`` instance using:
    - LLM calls for reasoning and code-generation tasks
    - Tool execution (via the tool registry / tool executor) for shell,
      file and generic tool-using tasks

    NOTE(review): the previous docstring mentioned "sliding window context
    management"; nothing in this class implements it, so confirm whether
    that is handled elsewhere.
    """

    # System prompt for worker tasks
    DEFAULT_SYSTEM_PROMPT = """You are a Worker Agent that executes specific tasks efficiently.
Your responsibilities:
1. Execute tasks assigned to you by the Supervisor
2. Use appropriate tools when needed
3. Report results clearly with structured output_data for dependent tasks
4. Be concise and focused on the task at hand
Output format:
- Provide clear, structured results
- Include output_data for any data that dependent tasks might need
- If a tool fails, explain the error clearly
"""

    # Upper bound on LLM <-> tool round trips in a generic task
    MAX_TOOL_ITERATIONS = 5

    def __init__(
        self,
        agent: Agent,
        llm_client=None,
        tool_executor=None
    ):
        """
        Initialize Worker Agent

        Args:
            agent: Agent instance (should be WORKER type)
            llm_client: LLM client instance; when None the shared client
                from ``luxx.services.llm_client`` is used
            tool_executor: Tool executor instance (needed only by generic
                tasks whose LLM response contains tool calls)
        """
        self.agent = agent
        if llm_client is None:
            # The parameter shadows the module-level ``llm_client`` import,
            # so the default has to be re-imported under another name.
            from luxx.services.llm_client import llm_client as default_llm_client
            llm_client = default_llm_client
        self.llm_client = llm_client
        self.tool_executor = tool_executor
        # Ensure agent has worker system prompt
        if not self.agent.config.system_prompt:
            self.agent.config.system_prompt = self.DEFAULT_SYSTEM_PROMPT

    async def execute_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        parent_outputs: Optional[Dict[str, Dict[str, Any]]] = None,
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute a task node

        Dispatches on ``task_node.task_type`` ("llm", "code", "shell",
        "file"; anything else is treated as a generic task) and records the
        elapsed wall-clock time on the returned result.

        Args:
            task_node: Task node to execute
            context: Execution context (workspace, user info, etc.)
            parent_outputs: Output data from parent tasks (dependency results)
            progress_callback: Optional callback invoked as
                ``progress_callback(fraction, message)``

        Returns:
            TaskResult with execution outcome. Exceptions raised during
            execution are converted into a failed TaskResult.
        """
        self.agent.status = AgentStatus.EXECUTING
        self.agent.current_task_id = task_node.id
        start_time = time.time()
        if progress_callback:
            progress_callback(0.0, f"Starting task: {task_node.name}")
        try:
            # Merge parent outputs into context
            execution_context = self._prepare_context(context, parent_outputs)
            # Execute based on task type
            if task_node.task_type == "llm":
                handler = self._execute_llm_task
            elif task_node.task_type == "code":
                handler = self._execute_code_task
            elif task_node.task_type == "shell":
                handler = self._execute_shell_task
            elif task_node.task_type == "file":
                handler = self._execute_file_task
            else:
                handler = self._execute_generic_task
            result = await handler(task_node, execution_context, progress_callback)
            result.execution_time = time.time() - start_time
            if progress_callback:
                progress_callback(1.0, f"Task complete: {task_node.name}")
            # NOTE(review): the agent returns to IDLE even when the handler
            # produced a failed TaskResult; only raised exceptions mark the
            # agent FAILED. Confirm this is intended.
            self.agent.status = AgentStatus.IDLE
            return result
        except Exception as e:
            logger.exception("Task execution failed: %s", e)
            self.agent.status = AgentStatus.FAILED
            failure = TaskResult.fail(error=str(e))
            # Record the elapsed time on failures too, as on success.
            failure.execution_time = time.time() - start_time
            return failure

    def _prepare_context(
        self,
        context: Dict[str, Any],
        parent_outputs: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Prepare execution context by merging parent outputs

        The input ``context`` is shallow-copied, never mutated. When
        ``parent_outputs`` is non-empty the copy gains two keys:
        ``"parent_outputs"`` (the mapping as given) and ``"merged_data"``
        (all parent output dicts flattened into one; later parents win on
        key collisions, parents with no output are skipped).

        Args:
            context: Base context
            parent_outputs: Output from parent tasks, keyed by parent id

        Returns:
            Merged context, always containing ``"user_permission_level"``
        """
        execution_context = context.copy()
        if parent_outputs:
            merged: Dict[str, Any] = {}
            for outputs in parent_outputs.values():
                # A parent may have produced no output data at all.
                if outputs:
                    merged.update(outputs)
            execution_context["parent_outputs"] = parent_outputs
            execution_context["merged_data"] = merged
        # Add user permission level
        if "user_permission_level" not in execution_context:
            execution_context["user_permission_level"] = self.agent.effective_permission.value
        return execution_context

    def _build_tool_context(self, context: Dict[str, Any]) -> ToolContext:
        """Build the ToolContext passed to registry tools from the execution context."""
        return ToolContext(
            workspace=context.get("workspace"),
            user_id=context.get("user_id"),
            username=context.get("username"),
            extra={"user_permission_level": context.get("user_permission_level", 1)}
        )

    @staticmethod
    def _output_template(task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a fresh copy of the task's ``output_template`` (empty dict if absent)."""
        return dict(task_data.get("output_template") or {})

    async def _execute_llm_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute LLM reasoning task

        Sends the task prompt (``task_data["prompt"]`` or the node
        description) to the LLM, preceded by the system prompt and, when
        available, the merged output of parent tasks.

        Returns:
            TaskResult whose data is the LLM response content and whose
            output_data is a copy of ``task_data["output_template"]``
        """
        task_data = task_node.task_data
        # Build prompt
        prompt = task_data.get("prompt", task_node.description)
        system_prompt = task_data.get("system", self.agent.config.system_prompt)
        messages = [{"role": "system", "content": system_prompt}]
        # Add parent data if available (skip when there is nothing to show)
        merged = context.get("merged_data")
        if merged:
            context_info = "\n".join(f"{k}: {v}" for k, v in merged.items())
            messages.append({
                "role": "system",
                "content": f"Context from dependent tasks:\n{context_info}"
            })
        messages.append({"role": "user", "content": prompt})
        if progress_callback:
            progress_callback(0.3, "Calling LLM...")
        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=self.agent.config.temperature,
                max_tokens=self.agent.config.max_tokens
            )
            return TaskResult.ok(
                data=response.content,
                output_data=self._output_template(task_data)
            )
        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_code_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute code generation/writing task

        Asks the LLM to write code in ``task_data["language"]`` (default
        "python") for the task prompt and optional requirements. The
        execution ``context`` is not used by this task type.

        Returns:
            TaskResult whose data is the generated text and whose
            output_data carries ``"code"`` and ``"language"``
        """
        task_data = task_node.task_data
        # Build prompt for code generation
        prompt = task_data.get("prompt", task_node.description)
        language = task_data.get("language", "python")
        requirements = task_data.get("requirements", "")
        messages = [
            {"role": "system", "content": f"You are a {language} programmer. Write clean, efficient code."},
            {"role": "user", "content": f"Task: {prompt}\n\nRequirements: {requirements}"}
        ]
        if progress_callback:
            progress_callback(0.3, "Generating code...")
        try:
            response = await self.llm_client.sync_call(
                model=self.agent.config.model,
                messages=messages,
                temperature=0.2,  # Lower temp for code
                max_tokens=4096
            )
            return TaskResult.ok(
                data=response.content,
                output_data={
                    "code": response.content,
                    "language": language
                }
            )
        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_shell_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute shell command task

        Runs ``task_data["command"]`` through the registry's ``shell_exec``
        tool.

        Returns:
            TaskResult with the command output as data and under
            ``output_data["output"]``; a failed result when no command is
            given, the tool reports failure, or an exception is raised
        """
        task_data = task_node.task_data
        command = task_data.get("command")
        if not command:
            return TaskResult.fail(error="No command specified")
        if progress_callback:
            progress_callback(0.3, f"Executing: {command[:50]}...")
        # Build tool context
        tool_ctx = self._build_tool_context(context)
        try:
            # Execute shell command via tool
            result = tool_registry.execute(
                "shell_exec",
                {"command": command},
                context=tool_ctx
            )
            if not result.get("success"):
                return TaskResult.fail(error=result.get("error", "Shell execution failed"))
            # "data" may be present but None, so do not rely on the .get default.
            output = (result.get("data") or {}).get("output", "")
            return TaskResult.ok(data=output, output_data={"output": output})
        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_file_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute file operation task

        Calls the registry tool named ``file_<operation>`` with the task's
        ``path`` and ``content`` (``content`` defaults to an empty string).

        Returns:
            TaskResult with the tool's data and
            ``output_data = {"path": ..., "operation": ...}``; a failed
            result when operation/path is missing or the tool fails
        """
        task_data = task_node.task_data
        operation = task_data.get("operation")
        file_path = task_data.get("path")
        content = task_data.get("content", "")
        if not operation or not file_path:
            return TaskResult.fail(error="Missing operation or path")
        if progress_callback:
            progress_callback(0.3, f"File operation: {operation} {file_path}")
        tool_ctx = self._build_tool_context(context)
        try:
            result = tool_registry.execute(
                f"file_{operation}",
                {"path": file_path, "content": content},
                context=tool_ctx
            )
            if not result.get("success"):
                return TaskResult.fail(error=result.get("error", "File operation failed"))
            return TaskResult.ok(
                data=result.get("data"),
                output_data={"path": file_path, "operation": operation}
            )
        except Exception as e:
            return TaskResult.fail(error=str(e))

    async def _execute_generic_task(
        self,
        task_node: TaskNode,
        context: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> TaskResult:
        """
        Execute generic task using LLM with tools

        Runs an LLM/tool loop: the LLM is called with the tool definitions
        named in ``task_data["tools"]``; while it answers with tool calls,
        they are run through ``self.tool_executor`` and the results are fed
        back. The loop stops at the first response without tool calls, or
        after ``MAX_TOOL_ITERATIONS`` rounds.

        Returns:
            TaskResult with the final LLM content and a copy of
            ``task_data["output_template"]``; a failed result on any
            exception, on a missing tool executor, or when the iteration
            limit is reached
        """
        task_data = task_node.task_data
        # Build prompt
        prompt = task_data.get("prompt", task_node.description)
        requested_tools = task_data.get("tools", [])
        messages = [
            {"role": "system", "content": self.agent.config.system_prompt},
            {"role": "user", "content": prompt}
        ]
        # Get tool definitions if specified; unknown names are skipped and an
        # empty result collapses to None so no empty tool list is sent.
        tool_defs = None
        if requested_tools:
            resolved = (tool_registry.get(name) for name in requested_tools)
            tool_defs = [tool.to_openai_format() for tool in resolved if tool] or None
        if progress_callback:
            progress_callback(0.2, "Processing task...")
        for _ in range(self.MAX_TOOL_ITERATIONS):
            try:
                response = await self.llm_client.sync_call(
                    model=self.agent.config.model,
                    messages=messages,
                    tools=tool_defs,
                    temperature=self.agent.config.temperature,
                    max_tokens=self.agent.config.max_tokens
                )
                tool_calls = response.tool_calls
                # Add assistant response. Tool-role messages must answer an
                # assistant message that carries the matching tool_calls.
                # NOTE(review): assumes response.tool_calls is already in the
                # wire format the LLM client accepts - confirm.
                assistant_message = {"role": "assistant", "content": response.content}
                if tool_calls:
                    assistant_message["tool_calls"] = tool_calls
                messages.append(assistant_message)
                if not tool_calls:
                    # No tool calls, task complete
                    return TaskResult.ok(
                        data=response.content,
                        output_data=self._output_template(task_data)
                    )
                if self.tool_executor is None:
                    return TaskResult.fail(
                        error="LLM requested tool calls but no tool executor is configured"
                    )
                if progress_callback:
                    progress_callback(0.5, f"Executing {len(tool_calls)} tools...")
                # Execute tools
                tool_results = self.tool_executor.process_tool_calls(tool_calls, context)
                # Add tool results
                messages.extend(
                    {
                        "role": "tool",
                        "tool_call_id": tr["tool_call_id"],
                        "content": tr["content"]
                    }
                    for tr in tool_results
                )
                if progress_callback:
                    progress_callback(0.8, "Tools executed")
            except Exception as e:
                return TaskResult.fail(error=str(e))
        return TaskResult.fail(error="Max iterations exceeded")