# Luxx/luxx/routes/providers.py
#
# 248 lines
# 7.8 KiB
# Python

"""LLM Provider routes"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from luxx.database import get_db, SessionLocal
from luxx.models import User, LLMProvider
from luxx.routes.auth import get_current_user
from luxx.utils.helpers import success_response
import httpx
import asyncio
# Router for the LLM provider endpoints; routes registered on it are served
# under the "/providers" prefix and tagged "LLM Providers".
router = APIRouter(prefix="/providers", tags=["LLM Providers"])
class ProviderCreate(BaseModel):
    """Request body accepted by ``create_provider``.

    ``name``, ``base_url`` and ``api_key`` are required; the remaining
    fields fall back to the defaults declared below.
    """

    # Display name chosen by the user.
    name: str
    # Provider flavour; defaults to "openai".
    provider_type: str = "openai"
    # Root URL of the provider's API.
    base_url: str
    # Credential stored with the provider record.
    api_key: str
    # Model used when none is specified; defaults to "gpt-4".
    default_model: str = "gpt-4"
    # When true, create_provider clears the flag on the user's other providers.
    is_default: bool = False
class ProviderUpdate(BaseModel):
    """Request body accepted by ``update_provider``.

    Every field is optional and defaults to ``None``; the update handler
    reads only the fields the client actually set (``exclude_unset=True``).
    """

    # New display name.
    name: Optional[str] = None
    # New provider flavour.
    provider_type: Optional[str] = None
    # New API root URL.
    base_url: Optional[str] = None
    # New credential; the update handler ignores an empty string here.
    api_key: Optional[str] = None
    # New default model name.
    default_model: Optional[str] = None
    # When true, the update handler clears the flag on the user's other providers.
    is_default: Optional[bool] = None
    # Enabled flag; has no counterpart in ProviderCreate.
    enabled: Optional[bool] = None
@router.get("/", response_model=dict)
def list_providers(
    current_user: User = Depends(get_current_user)
):
    """Return every LLM provider owned by the authenticated user.

    Rows are ordered with default providers first, then by creation time
    (newest first). The response data carries the serialized providers under
    ``"providers"`` and their count under ``"total"``.
    """
    session = SessionLocal()
    try:
        owned = session.query(LLMProvider).filter(
            LLMProvider.user_id == current_user.id
        )
        ordered = owned.order_by(
            LLMProvider.is_default.desc(),
            LLMProvider.created_at.desc(),
        )
        rows = ordered.all()
        serialized = [row.to_dict() for row in rows]
        return success_response(data={
            "providers": serialized,
            "total": len(rows)
        })
    finally:
        # Always hand the session back, whether or not the query succeeded.
        session.close()
@router.post("/", response_model=dict)
def create_provider(
    provider: ProviderCreate,
    current_user: User = Depends(get_current_user)
):
    """Create a new LLM provider for the authenticated user.

    If the new provider is flagged as default, the flag is first cleared on
    all of the user's existing providers. Returns the stored record
    serialized with ``to_dict(include_key=True)``.

    Raises:
        HTTPException: 400 with the error text if anything fails; the
            transaction is rolled back first.
    """
    session = SessionLocal()
    try:
        if provider.is_default:
            # Only one provider per user may carry the default flag.
            owned = session.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id
            )
            owned.update({"is_default": False})

        record = LLMProvider(
            user_id=current_user.id,
            name=provider.name,
            provider_type=provider.provider_type,
            base_url=provider.base_url,
            api_key=provider.api_key,
            default_model=provider.default_model,
            is_default=provider.is_default
        )
        session.add(record)
        session.commit()
        # Reload the row after commit before serializing it.
        session.refresh(record)

        payload = record.to_dict(include_key=True)
        return success_response(data=payload)
    except Exception as exc:
        session.rollback()
        raise HTTPException(status_code=400, detail=str(exc))
    finally:
        session.close()
@router.get("/{provider_id}", response_model=dict)
def get_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Return one provider owned by the authenticated user.

    The lookup is scoped to ``current_user``, so a provider belonging to a
    different user is reported as not found.

    Raises:
        HTTPException: 404 when no matching provider exists.
    """
    session = SessionLocal()
    try:
        match = session.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        )
        record = match.first()
        if record:
            details = record.to_dict(include_key=True)
            return success_response(data=details)
        raise HTTPException(status_code=404, detail="Provider not found")
    finally:
        session.close()
@router.put("/{provider_id}", response_model=dict)
def update_provider(
    provider_id: int,
    update: ProviderUpdate,
    current_user: User = Depends(get_current_user)
):
    """Update a provider owned by the authenticated user.

    Only fields the client actually sent are applied. A field sent as an
    explicit ``null`` is treated the same as an omitted field, because
    ``ProviderUpdate`` uses ``None`` to mean "not supplied"; this prevents a
    stray ``null`` from wiping a stored value such as the API key. An empty
    string for ``api_key`` likewise keeps the existing key.

    When ``is_default`` is set to true, the flag is cleared on the user's
    other providers. Returns the updated record serialized with
    ``to_dict(include_key=True)``.

    Raises:
        HTTPException: 404 when the provider does not exist for this user;
            400 with the error text if the update fails (after rollback).
    """
    db = SessionLocal()
    try:
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        ).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Fields present in the request, minus explicit nulls: None is the
        # model's "not supplied" marker, so it must never be written to the row.
        update_data = {
            key: value
            for key, value in update.dict(exclude_unset=True).items()
            if value is not None
        }

        # Keep existing API key if the new one is empty
        if update_data.get("api_key") == "":
            update_data.pop("api_key")

        # If setting as default, unset others
        if update_data.get("is_default"):
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id,
                LLMProvider.id != provider_id
            ).update({"is_default": False})

        for key, value in update_data.items():
            setattr(provider, key, value)

        db.commit()
        db.refresh(provider)
        return success_response(data=provider.to_dict(include_key=True))
    except HTTPException:
        # Let the 404 above propagate unchanged instead of becoming a 400.
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        db.close()
@router.delete("/{provider_id}", response_model=dict)
def delete_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Delete a provider owned by the authenticated user.

    Raises:
        HTTPException: 404 when the provider does not exist for this user.
    """
    session = SessionLocal()
    try:
        match = session.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        )
        record = match.first()
        if not record:
            raise HTTPException(status_code=404, detail="Provider not found")

        session.delete(record)
        session.commit()
        return success_response(message="Provider deleted")
    finally:
        session.close()
@router.post("/{provider_id}/test", response_model=dict)
def test_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Test connectivity to a provider owned by the authenticated user.

    Sends a minimal chat-completion request (a single "Hi" message,
    ``max_tokens`` 10, the provider's default model) to
    ``<base_url>/chat/completions`` with the stored key as a Bearer token,
    using a 10 second timeout.

    The probe outcome is always reported in the response data rather than as
    an HTTP error:

    * HTTP 200 from the provider -> ``success: True`` plus ``status_code``.
    * Any other status -> ``success: False``, ``status_code`` and the first
      500 characters of the response body.
    * Any exception while probing -> ``success: False``, the error text and
      ``error_type`` (the exception class name).

    Raises:
        HTTPException: 404 when the provider does not exist for this user.
    """
    db = SessionLocal()
    try:
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        ).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")
        # Copy what the probe needs so the session can be released before the
        # (potentially slow) outbound request instead of being held for it.
        base_url = provider.base_url
        api_key = provider.api_key
        default_model = provider.default_model
    finally:
        db.close()

    try:
        # Tolerate a trailing slash on the stored base URL so the request
        # does not go to ".../v1//chat/completions".
        endpoint = f"{base_url.rstrip('/')}/chat/completions"
        # This handler is a plain ``def``, so a synchronous client is used
        # directly rather than creating a throwaway event loop per request.
        with httpx.Client(timeout=10.0) as client:
            response = client.post(
                endpoint,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": default_model,
                    "messages": [{"role": "user", "content": "Hi"}],
                    "max_tokens": 10
                }
            )
        status_code = response.status_code
        response_body = response.text[:500] if response.text else None
    except Exception as e:
        # Deliberately broad: any failure (bad URL, DNS, TLS, timeout, ...)
        # is a test result to show the user, not a server error.
        return success_response(data={
            "success": False,
            "message": f"连接失败: {str(e)}",
            "error_type": type(e).__name__
        })

    if status_code == 200:
        return success_response(data={
            "success": True,
            "message": "连接成功",
            "status_code": status_code
        })

    # Non-200 responses are not raised by httpx here (raise_for_status is
    # never called), so they are reported from the response itself.
    return success_response(data={
        "success": False,
        "message": f"HTTP {status_code}",
        "status_code": status_code,
        "response_body": response_body
    })