Luxx/luxx/routes/agents.py

296 lines
9.0 KiB
Python

"""Agent routes - REST API for agent task management"""
import asyncio
import logging
from typing import Optional
from fastapi import APIRouter, Depends, WebSocket
from pydantic import BaseModel
from sqlalchemy.orm import Session
from luxx.database import get_db
from luxx.models import User
from luxx.routes.auth import get_current_user
from luxx.utils.helpers import success_response, error_response, generate_id
from luxx.agents.core import AgentConfig, AgentType, AgentStatus
from luxx.agents.registry import AgentRegistry
from luxx.agents.supervisor import SupervisorAgent
from luxx.agents.worker import WorkerAgent
from luxx.agents.dag_scheduler import SchedulerPool, DAGScheduler
from luxx.services.llm_client import LLMClient
from luxx.tools.executor import ToolExecutor
# Router that carries every agent endpoint in this module; all paths are
# mounted under the "/agents" prefix and grouped under the "Agents" tag.
router = APIRouter(tags=["Agents"], prefix="/agents")
# Module-scoped logger.
logger = logging.getLogger(__name__)
# ============ Request/Response Models ============
class AgentCreateRequest(BaseModel):
    """Request body accepted by ``POST /agents/request``."""

    # ID of the conversation the task is attached to; create_agent_task only
    # accepts conversations owned by the calling user.
    conversation_id: str
    # Free-text task that the supervisor agent decomposes into a DAG.
    task: str
    # Optional extra settings; currently not read by create_agent_task.
    options: Optional[dict] = None
class AgentTaskResponse(BaseModel):
    """Shape of the data describing a newly created agent task.

    NOTE(review): the routes in this module declare ``response_model=dict``
    and never reference this model; confirm whether it is used elsewhere.
    """

    # Identifier produced by generate_id("task").
    task_id: str
    # Lifecycle label reported to the client (e.g. "planning").
    status: str
    # Conversation the task belongs to.
    conversation_id: str
# ============ Global instances ============
# Upper bound passed to the scheduler pool for concurrent DAG executions.
_MAX_CONCURRENT_DAGS = 10
# Process-wide pool of DAG schedulers; tasks are registered, looked up,
# cancelled and removed by task_id.
scheduler_pool = SchedulerPool(max_concurrent=_MAX_CONCURRENT_DAGS)
# Single tool executor shared by every worker agent built in this module.
tool_executor = ToolExecutor()
# ============ Routes ============
# Strong references to in-flight background DAG runs. asyncio keeps only weak
# references to tasks, so a task created with asyncio.create_task() and not
# stored anywhere may be garbage-collected before it finishes.
_background_tasks: "set[asyncio.Task]" = set()


def _resolve_llm_settings(db: Session, conversation) -> tuple:
    """Return ``(api_key, api_url, model)`` to use for a conversation's LLM calls.

    The conversation's linked ``LLMProvider`` is used when it exists and has an
    API key. Otherwise the global ``config.llm_api_key`` and
    ``config.llm_api_url`` are used together with the ``"deepseek-chat"`` model.
    The returned ``api_key`` may still be falsy if neither source configures one.
    """
    from luxx.config import config
    from luxx.models import LLMProvider

    api_key = None
    api_url = None
    model = None

    logger.info(f"Conversation provider_id: {conversation.provider_id}")
    if conversation.provider_id:
        provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
        logger.info(f"Provider found: {provider}")
        if provider:
            api_key = provider.api_key
            api_url = provider.base_url
            model = provider.default_model
            logger.info(f"Provider config - api_key: {'set' if api_key else 'None'}, url: {api_url}, model: {model}")

    # Fall back to the global configuration when no provider key is available.
    # URL and model are replaced as well, not only the key.
    if not api_key:
        api_key = config.llm_api_key
        api_url = config.llm_api_url
        model = "deepseek-chat"

    return api_key, api_url, model


@router.post("/request", response_model=dict)
async def create_agent_task(
    data: AgentCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Create a new agent task.

    Steps:
    1. Look up the caller's conversation and resolve LLM settings for it.
    2. Create a Supervisor agent.
    3. Decompose the task into a DAG (awaited before responding).
    4. Register a DAG scheduler and start execution in the background.
    5. Return the task_id for WebSocket subscription.

    Errors are returned via ``error_response``:
    - 404 if the conversation does not exist or is not owned by the caller.
    - 500 if no LLM API key is configured or task creation fails.
    """
    from luxx.models import Conversation

    # Only conversations owned by the current user are accepted.
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()
    if not conversation:
        return error_response("Conversation not found", 404)

    llm_api_key, llm_api_url, llm_model = _resolve_llm_settings(db, conversation)
    if not llm_api_key:
        return error_response(
            "LLM API key not configured. Please set up a provider in settings or set DEEPSEEK_API_KEY environment variable.",
            500
        )

    task_id = generate_id("task")
    try:
        llm_client = LLMClient(
            api_key=llm_api_key,
            api_url=llm_api_url,
            model=llm_model
        )

        agent_registry = AgentRegistry()
        supervisor_config = AgentConfig(
            name=f"supervisor_{task_id}",
            agent_type=AgentType.SUPERVISOR,
            description=f"Supervisor for task {task_id}",
            model=llm_model,  # provider's default model, or the config fallback
            max_turns=10
        )
        supervisor_agent = agent_registry.create_agent(
            config=supervisor_config,
            user_id=current_user.id,
            conversation_id=data.conversation_id
        )
        supervisor = SupervisorAgent(
            agent=supervisor_agent,
            llm_client=llm_client
        )

        context = {
            "user_id": current_user.id,
            "username": current_user.username,
            "conversation_id": data.conversation_id
        }
        dag = await supervisor.decompose_task(data.task, context)

        def worker_factory():
            """Build a fresh WorkerAgent with its own LLM client for this task."""
            worker_llm_client = LLMClient(
                api_key=llm_api_key,
                api_url=llm_api_url,
                model=llm_model
            )
            # NOTE(review): every worker for this task gets the same name;
            # confirm AgentRegistry tolerates duplicate names.
            worker_config = AgentConfig(
                name=f"worker_{task_id}",
                agent_type=AgentType.WORKER,
                description=f"Worker for task {task_id}",
                model=llm_model,
                max_turns=5
            )
            worker_agent = agent_registry.create_agent(
                config=worker_config,
                user_id=current_user.id,
                conversation_id=data.conversation_id
            )
            return WorkerAgent(
                agent=worker_agent,
                llm_client=worker_llm_client,
                tool_executor=tool_executor
            )

        scheduler = scheduler_pool.create_scheduler(
            task_id=task_id,
            dag=dag,
            supervisor=supervisor,
            worker_factory=worker_factory,
            max_workers=3
        )

        # Fire-and-forget execution. Keep a strong reference until the task
        # finishes so it cannot be garbage-collected mid-run.
        background = asyncio.create_task(_execute_dag_background(
            task_id=task_id,
            scheduler=scheduler,
            context=context,
            task=data.task,
            supervisor_agent=supervisor_agent
        ))
        _background_tasks.add(background)
        background.add_done_callback(_background_tasks.discard)

        return success_response(data={
            "task_id": task_id,
            "status": "planning",
            "conversation_id": data.conversation_id
        }, message="Agent task created successfully")
    except Exception as e:
        logger.exception("Failed to create agent task: %s", e)
        return error_response(f"Failed to create task: {str(e)}", 500)
async def _emit_dag_failure(task_id: str, error: str) -> None:
    """Best-effort notification that a task's DAG failed before producing a result.

    Reuses the ``emit_dag_complete`` event with ``success=False``. Any error
    while emitting is logged and swallowed, because this runs inside a
    fire-and-forget background task where nobody would observe it.
    """
    try:
        from luxx.routes.agents_ws import emit_dag_complete

        # NOTE(review): payload mirrors the "success" key of a scheduler result
        # and adds "error"; confirm WebSocket clients accept this shape.
        await emit_dag_complete(task_id, False, {"success": False, "error": error})
    except Exception:
        logger.exception("Failed to emit DAG failure event for task %s", task_id)


async def _execute_dag_background(
    task_id: str,
    scheduler: DAGScheduler,
    context: dict,
    task: str,
    supervisor_agent
):
    """Execute a task's DAG in the background and publish the outcome.

    On a normal run:
    - Set ``supervisor_agent.status`` from ``result["success"]``.
    - Emit a node-complete event for each successful node found in
      ``scheduler.dag.nodes``.
    - Emit a final dag-complete event.

    If ``scheduler.execute`` itself raises:
    - Mark the supervisor FAILED.
    - Emit a failure dag-complete event, so subscribers are not left waiting.

    Errors while publishing results are logged and do not alter the status.
    """
    try:
        result = await scheduler.execute(context, task)
    except Exception as e:
        logger.exception("DAG execution failed for task %s: %s", task_id, e)
        supervisor_agent.status = AgentStatus.FAILED
        await _emit_dag_failure(task_id, str(e))
        return

    try:
        supervisor_agent.status = AgentStatus.COMPLETED if result["success"] else AgentStatus.FAILED

        # Imported lazily, as in the original code.
        from luxx.routes.agents_ws import emit_dag_complete, emit_node_complete

        # Emit node-complete events only for nodes that succeeded and still
        # exist in the scheduler's DAG.
        for node_id, node_result in result.get("results", {}).items():
            if not node_result.get("success"):
                continue
            node = scheduler.dag.nodes.get(node_id)
            if node:
                await emit_node_complete(task_id, node)

        await emit_dag_complete(task_id, result["success"], result)
    except Exception as e:
        logger.exception("Failed to publish DAG results for task %s: %s", task_id, e)
@router.get("/task/{task_id}", response_model=dict)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Return the DAG snapshot and progress of a task held in ``scheduler_pool``.

    Responds with ``error_response("Task not found", 404)`` when the pool has
    no scheduler for ``task_id``.

    NOTE(review): ``status`` is always reported as "executing", even if the
    DAG has finished. Also, ``current_user`` is resolved but the task is not
    checked against it, so any caller who knows a task_id can read it.
    """
    scheduler = scheduler_pool.get(task_id)
    if scheduler:
        dag = scheduler.dag
        payload = {
            "task_id": task_id,
            "status": "executing",
            "dag": dag.to_dict(),
            "progress": dag.progress,
        }
        return success_response(data=payload)
    return error_response("Task not found", 404)
@router.post("/task/{task_id}/cancel", response_model=dict)
async def cancel_task(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Ask ``scheduler_pool`` to cancel a running task.

    Responds with 404 when the pool reports nothing was cancelled.

    NOTE(review): ownership is not checked; any caller who knows a task_id
    can cancel it.
    """
    cancelled = scheduler_pool.cancel(task_id)
    if not cancelled:
        return error_response("Task not found or already completed", 404)
    return success_response(message="Task cancelled")
@router.delete("/task/{task_id}", response_model=dict)
async def delete_task(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Remove a task's scheduler from ``scheduler_pool``.

    Responds with 404 when the pool reports nothing was removed.

    NOTE(review): ownership is not checked; any caller who knows a task_id
    can delete it.
    """
    removed = scheduler_pool.remove(task_id)
    if not removed:
        return error_response("Task not found", 404)
    return success_response(message="Task deleted")
@router.get("/tasks", response_model=dict)
async def list_tasks(
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user)
):
    """List the current user's agent tasks (placeholder).

    Always returns an empty page; ``page`` and ``page_size`` are echoed back
    unchanged. Tasks live only in the in-memory ``scheduler_pool``, and this
    module does not record which user created them.
    """
    # TODO: persist task metadata (owner, status, timestamps) in the database
    # and query it here.
    tasks: list = []
    return success_response(data=dict(
        items=tasks,
        total=len(tasks),
        page=page,
        page_size=page_size,
    ))