"""LLM Provider routes"""
import asyncio  # noqa: F401 - kept; not used in the visible part of this module
from typing import Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from luxx.database import get_db, SessionLocal
from luxx.models import User, LLMProvider
from luxx.routes.auth import get_current_user
from luxx.utils.helpers import success_response

router = APIRouter(prefix="/providers", tags=["LLM Providers"])


class ProviderCreate(BaseModel):
    """Request body for creating a provider."""

    name: str
    provider_type: str = "openai"
    base_url: str
    api_key: str
    default_model: str = "gpt-4"
    is_default: bool = False


class ProviderUpdate(BaseModel):
    """Request body for a partial provider update; unset fields are left alone."""

    name: Optional[str] = None
    provider_type: Optional[str] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    default_model: Optional[str] = None
    is_default: Optional[bool] = None
    enabled: Optional[bool] = None


def _get_owned_provider(db, provider_id: int, user_id) -> LLMProvider:
    """Return the provider with ``provider_id`` that belongs to ``user_id``.

    Args:
        db: An open session created by ``SessionLocal``.
        provider_id: Primary key of the provider to load.
        user_id: Id of the user that must own the provider.

    Raises:
        HTTPException: 404 when no matching provider exists for that user.
    """
    provider = (
        db.query(LLMProvider)
        .filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == user_id,
        )
        .first()
    )
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    return provider


@router.get("/", response_model=dict)
def list_providers(
    current_user: User = Depends(get_current_user)
):
    """Get user's LLM providers.

    Returns the current user's providers, default provider(s) first and then
    newest first, together with their count.
    """
    db = SessionLocal()
    try:
        providers = (
            db.query(LLMProvider)
            .filter(LLMProvider.user_id == current_user.id)
            .order_by(
                LLMProvider.is_default.desc(),
                LLMProvider.created_at.desc(),
            )
            .all()
        )
        return success_response(data={
            "providers": [p.to_dict() for p in providers],
            "total": len(providers),
        })
    finally:
        db.close()


@router.post("/", response_model=dict)
def create_provider(
    provider: ProviderCreate,
    current_user: User = Depends(get_current_user)
):
    """Create a new LLM provider.

    Raises:
        HTTPException: 400 with the error text when the insert fails.
    """
    db = SessionLocal()
    try:
        # If this is set as default, unset other defaults
        if provider.is_default:
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id
            ).update({"is_default": False})

        db_provider = LLMProvider(
            user_id=current_user.id,
            name=provider.name,
            provider_type=provider.provider_type,
            base_url=provider.base_url,
            api_key=provider.api_key,
            default_model=provider.default_model,
            is_default=provider.is_default,
        )
        db.add(db_provider)
        db.commit()
        db.refresh(db_provider)
        return success_response(data=db_provider.to_dict(include_key=True))
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        db.close()


@router.get("/{provider_id}", response_model=dict)
def get_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Get provider details.

    Raises:
        HTTPException: 404 when the provider does not exist for this user.
    """
    db = SessionLocal()
    try:
        provider = _get_owned_provider(db, provider_id, current_user.id)
        return success_response(data=provider.to_dict(include_key=True))
    finally:
        db.close()


@router.put("/{provider_id}", response_model=dict)
def update_provider(
    provider_id: int,
    update: ProviderUpdate,
    current_user: User = Depends(get_current_user)
):
    """Update provider.

    Only fields present in the request body are written. An empty-string
    ``api_key`` is ignored so the stored key is kept.

    Raises:
        HTTPException: 404 when the provider does not exist for this user,
            400 with the error text when the update fails.
    """
    db = SessionLocal()
    try:
        provider = _get_owned_provider(db, provider_id, current_user.id)

        # If setting as default, unset others
        if update.is_default:
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id,
                LLMProvider.id != provider_id,
            ).update({"is_default": False})

        # Update fields
        update_data = update.dict(exclude_unset=True)

        # Keep existing API key if the new one is empty
        if update_data.get("api_key") == "":
            update_data.pop("api_key")

        for key, value in update_data.items():
            setattr(provider, key, value)

        db.commit()
        db.refresh(provider)
        return success_response(data=provider.to_dict(include_key=True))
    except HTTPException:
        # Let the 404 through instead of rewrapping it as a 400.
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        db.close()


@router.delete("/{provider_id}", response_model=dict)
def delete_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Delete provider.

    Raises:
        HTTPException: 404 when the provider does not exist for this user.
    """
    db = SessionLocal()
    try:
        provider = _get_owned_provider(db, provider_id, current_user.id)
        db.delete(provider)
        db.commit()
        return success_response(message="Provider deleted")
    finally:
        db.close()


@router.post("/{provider_id}/test", response_model=dict)
def test_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Test provider connection.

    Sends a minimal chat-completion request to ``<base_url>/chat/completions``
    with the provider's API key and default model, and reports the outcome in
    the response payload. Connection problems are reported as
    ``success: False`` rather than raised.

    Raises:
        HTTPException: 404 when the provider does not exist for this user.
    """
    # Copy what the request needs and release the session before the network
    # call, so it is not held open for up to the full HTTP timeout.
    db = SessionLocal()
    try:
        provider = _get_owned_provider(db, provider_id, current_user.id)
        base_url = provider.base_url
        api_key = provider.api_key
        default_model = provider.default_model
    finally:
        db.close()

    # Test the connection
    try:
        # Strip trailing slashes so "https://host/v1/" does not become
        # "https://host/v1//chat/completions".
        url = f"{(base_url or '').rstrip('/')}/chat/completions"

        # This handler is synchronous, so use the synchronous client directly
        # instead of driving an AsyncClient through a new event loop.
        with httpx.Client(timeout=10.0) as client:
            response = client.post(
                url,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": default_model,
                    "messages": [{"role": "user", "content": "Hi"}],
                    "max_tokens": 10,
                },
            )

        status_code = response.status_code
        if status_code == 200:
            return success_response(data={
                "success": True,
                "message": "连接成功",
                "status_code": status_code,
            })

        body_text = response.text
        return success_response(data={
            "success": False,
            "message": f"HTTP {status_code}",
            "status_code": status_code,
            "response_body": body_text[:500] if body_text else None,
        })
    except Exception as e:
        # Deliberately broad: any failure (DNS, TLS, timeout, bad URL, ...) is
        # reported to the caller as a failed test instead of a server error.
        return success_response(data={
            "success": False,
            "message": f"连接失败: {str(e)}",
            "error_type": type(e).__name__,
        })