# 296 lines, 9.0 KiB, Python (file-viewer metadata, not part of the module)
"""Agent routes - REST API for agent task management"""
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
from fastapi import APIRouter, Depends, WebSocket
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from luxx.database import get_db
|
|
from luxx.models import User
|
|
from luxx.routes.auth import get_current_user
|
|
from luxx.utils.helpers import success_response, error_response, generate_id
|
|
from luxx.agents.core import AgentConfig, AgentType, AgentStatus
|
|
from luxx.agents.registry import AgentRegistry
|
|
from luxx.agents.supervisor import SupervisorAgent
|
|
from luxx.agents.worker import WorkerAgent
|
|
from luxx.agents.dag_scheduler import SchedulerPool, DAGScheduler
|
|
from luxx.services.llm_client import LLMClient
|
|
from luxx.tools.executor import ToolExecutor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/agents", tags=["Agents"])
|
|
|
|
|
|
# ============ Request/Response Models ============
|
|
|
|
class AgentCreateRequest(BaseModel):
    """Create agent task request"""
    # ID of the conversation the agent task is created under;
    # must belong to the authenticated user (checked in create_agent_task)
    conversation_id: str
    # Natural-language task description to be decomposed into a DAG
    task: str
    # Optional free-form execution options (currently unused by the route)
    options: Optional[dict] = None
|
|
|
|
|
|
class AgentTaskResponse(BaseModel):
    """Agent task response"""
    # Generated task identifier, used as the WebSocket subscription key
    task_id: str
    # Current lifecycle status (e.g. "planning" right after creation)
    status: str
    # Conversation the task was created under
    conversation_id: str
|
|
|
|
|
|
# ============ Global Instances ============
|
|
|
|
# Scheduler pool for managing concurrent DAG executions
|
|
scheduler_pool = SchedulerPool(max_concurrent=10)
|
|
|
|
# Tool executor
|
|
tool_executor = ToolExecutor()
|
|
|
|
|
|
# ============ Routes ============
|
|
|
|
# Strong references to fire-and-forget background tasks. asyncio keeps only
# a weak reference to scheduled tasks, so without this set a running DAG
# execution could be garbage-collected mid-flight (see asyncio.create_task
# documentation).
_background_tasks: set = set()


@router.post("/request", response_model=dict)
async def create_agent_task(
    data: AgentCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Create a new agent task.

    This will:
    1. Create a Supervisor agent
    2. Decompose the task into a DAG
    3. Start DAG execution in the background
    4. Return task_id for WebSocket subscription

    Returns:
        Success payload with ``task_id``, ``status`` ("planning") and
        ``conversation_id``; or an error response (404 when the
        conversation is unknown / not owned by the user, 500 when no LLM
        API key is configured or setup fails).
    """
    from luxx.config import config
    from luxx.models import Conversation, LLMProvider

    # Get conversation to find provider; scoped by user_id so a user
    # cannot attach tasks to someone else's conversation.
    conversation = db.query(Conversation).filter(
        Conversation.id == data.conversation_id,
        Conversation.user_id == current_user.id
    ).first()

    if not conversation:
        return error_response("Conversation not found", 404)

    # Determine LLM client configuration
    llm_api_key = None
    llm_api_url = None
    llm_model = None

    logger.info("Conversation provider_id: %s", conversation.provider_id)

    if conversation.provider_id:
        provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
        logger.info("Provider found: %s", provider)
        if provider:
            llm_api_key = provider.api_key
            llm_api_url = provider.base_url
            llm_model = provider.default_model
            logger.info(
                "Provider config - api_key: %s, url: %s, model: %s",
                "set" if llm_api_key else "None", llm_api_url, llm_model
            )

    # Fallback to global config if the provider yielded no API key
    if not llm_api_key:
        llm_api_key = config.llm_api_key
        llm_api_url = config.llm_api_url
        llm_model = "deepseek-chat"

    # Still nothing configured anywhere -> cannot proceed
    if not llm_api_key:
        return error_response(
            "LLM API key not configured. Please set up a provider in settings or set DEEPSEEK_API_KEY environment variable.",
            500
        )

    task_id = generate_id("task")

    try:
        # Create LLM client with the resolved configuration
        llm_client = LLMClient(
            api_key=llm_api_key,
            api_url=llm_api_url,
            model=llm_model
        )

        # Registry shared by the supervisor and all workers of this task
        agent_registry = AgentRegistry()

        supervisor_config = AgentConfig(
            name=f"supervisor_{task_id}",
            agent_type=AgentType.SUPERVISOR,
            description=f"Supervisor for task {task_id}",
            model=llm_model,  # same model resolved from provider/config above
            max_turns=10
        )

        supervisor_agent = agent_registry.create_agent(
            config=supervisor_config,
            user_id=current_user.id,
            conversation_id=data.conversation_id
        )

        # Create supervisor instance
        supervisor = SupervisorAgent(
            agent=supervisor_agent,
            llm_client=llm_client
        )

        # Context threaded through task decomposition and DAG execution
        context = {
            "user_id": current_user.id,
            "username": current_user.username,
            "conversation_id": data.conversation_id
        }

        # Decompose task into DAG
        dag = await supervisor.decompose_task(data.task, context)

        # Worker factory: each worker gets its own LLM client so workers
        # can run concurrently without sharing client state.
        def worker_factory():
            worker_llm_client = LLMClient(
                api_key=llm_api_key,
                api_url=llm_api_url,
                model=llm_model
            )
            worker_config = AgentConfig(
                name=f"worker_{task_id}",
                agent_type=AgentType.WORKER,
                description=f"Worker for task {task_id}",
                model=llm_model,
                max_turns=5
            )
            worker_agent = agent_registry.create_agent(
                config=worker_config,
                user_id=current_user.id,
                conversation_id=data.conversation_id
            )
            return WorkerAgent(
                agent=worker_agent,
                llm_client=worker_llm_client,
                tool_executor=tool_executor
            )

        # Create scheduler
        scheduler = scheduler_pool.create_scheduler(
            task_id=task_id,
            dag=dag,
            supervisor=supervisor,
            worker_factory=worker_factory,
            max_workers=3
        )

        # Start execution in background.  Keep a strong reference to the
        # task until it finishes, otherwise asyncio may garbage-collect it
        # while it is still running.
        background_task = asyncio.create_task(_execute_dag_background(
            task_id=task_id,
            scheduler=scheduler,
            context=context,
            task=data.task,
            supervisor_agent=supervisor_agent
        ))
        _background_tasks.add(background_task)
        background_task.add_done_callback(_background_tasks.discard)

        return success_response(data={
            "task_id": task_id,
            "status": "planning",
            "conversation_id": data.conversation_id
        }, message="Agent task created successfully")

    except Exception as e:
        # logger.exception records the traceback (logger.error did not)
        logger.exception(f"Failed to create agent task: {e}")
        return error_response(f"Failed to create task: {str(e)}", 500)
|
|
|
|
|
|
async def _execute_dag_background(
    task_id: str,
    scheduler: DAGScheduler,
    context: dict,
    task: str,
    supervisor_agent
):
    """Execute DAG in background and handle completion.

    Args:
        task_id: Identifier returned to the client at task creation time.
        scheduler: DAGScheduler driving node execution for this task.
        context: Execution context (user id, username, conversation id).
        task: Original natural-language task string.
        supervisor_agent: Agent record whose status is updated on finish.

    On success or expected failure, emits per-node and DAG-level completion
    events over WebSocket.  On an unexpected exception, marks the supervisor
    FAILED and best-effort emits a failure event so subscribed clients are
    not left waiting forever.
    """
    try:
        result = await scheduler.execute(context, task)

        # Update supervisor status from the scheduler's reported outcome
        supervisor_agent.status = AgentStatus.COMPLETED if result["success"] else AgentStatus.FAILED

        # Emit completion events via WebSocket (imported lazily to avoid
        # a circular import at module load time)
        from luxx.routes.agents_ws import emit_dag_complete, emit_node_complete

        # Emit node complete events for each successfully completed node
        for node_id, node_result in result.get("results", {}).items():
            if node_result.get("success"):
                # Find the node in the DAG
                node = scheduler.dag.nodes.get(node_id)
                if node:
                    await emit_node_complete(task_id, node)

        # Emit DAG complete event
        await emit_dag_complete(task_id, result["success"], result)

    except Exception as e:
        # Record the full traceback, not just the message
        logger.exception(f"DAG execution failed for task {task_id}: {e}")
        # Mark the task as failed and notify subscribers; previously the
        # exception was swallowed and WebSocket clients hung indefinitely.
        supervisor_agent.status = AgentStatus.FAILED
        try:
            from luxx.routes.agents_ws import emit_dag_complete
            await emit_dag_complete(task_id, False, {"success": False, "error": str(e)})
        except Exception:
            logger.exception("Failed to emit failure event for task %s", task_id)
|
|
|
|
|
|
@router.get("/task/{task_id}", response_model=dict)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Return the current status, DAG snapshot and progress for a task."""
    scheduler = scheduler_pool.get(task_id)
    if not scheduler:
        return error_response("Task not found", 404)

    payload = {
        "task_id": task_id,
        "status": "executing",
        "dag": scheduler.dag.to_dict(),
        "progress": scheduler.dag.progress,
    }
    return success_response(data=payload)
|
|
|
|
|
|
@router.post("/task/{task_id}/cancel", response_model=dict)
async def cancel_task(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Cancel a running task; 404 if unknown or already completed."""
    cancelled = scheduler_pool.cancel(task_id)
    if not cancelled:
        return error_response("Task not found or already completed", 404)
    return success_response(message="Task cancelled")
|
|
|
|
|
|
@router.delete("/task/{task_id}", response_model=dict)
async def delete_task(
    task_id: str,
    current_user: User = Depends(get_current_user)
):
    """Remove a task from the scheduler pool; 404 if unknown."""
    removed = scheduler_pool.remove(task_id)
    if not removed:
        return error_response("Task not found", 404)
    return success_response(message="Task deleted")
|
|
|
|
|
|
@router.get("/tasks", response_model=dict)
async def list_tasks(
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user)
):
    """List the user's agent tasks (placeholder: always empty for now)."""
    # NOTE: In a real implementation task metadata would live in the
    # database; active schedulers are held in memory only, so there is
    # nothing persistent to page through yet.
    tasks: list = []

    return success_response(data={
        "items": tasks,
        "total": len(tasks),
        "page": page,
        "page_size": page_size
    })
|