218 lines
6.5 KiB
Python
218 lines
6.5 KiB
Python
"""LLM Provider routes"""
|
|
from typing import Optional
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from luxx.database import get_db, SessionLocal
|
|
from luxx.models import User, LLMProvider
|
|
from luxx.routes.auth import get_current_user
|
|
from luxx.utils.helpers import success_response
|
|
import httpx
|
|
import asyncio
|
|
|
|
# Shared router for this module: every endpoint below is mounted under the
# "/providers" prefix and carries the "LLM Providers" OpenAPI tag.
router = APIRouter(tags=["LLM Providers"], prefix="/providers")
|
|
|
|
|
|
class ProviderCreate(BaseModel):
    """Request body for creating an LLM provider (POST /providers/)."""

    # Display name chosen by the user.
    name: str

    # Provider kind; "openai" when omitted. NOTE(review): the connection test
    # in this module always speaks the /chat/completions protocol regardless
    # of this value.
    provider_type: str = "openai"

    # API root URL; the connection test appends "/chat/completions" to it.
    base_url: str

    # Credential sent as "Authorization: Bearer <api_key>" by the connection test.
    api_key: str

    # Model name used when none is specified; "gpt-4" when omitted.
    default_model: str = "gpt-4"

    # When true, create_provider first clears is_default on all of the
    # user's other providers.
    is_default: bool = False
|
|
|
|
|
|
class ProviderUpdate(BaseModel):
    """Partial-update body for PUT /providers/{provider_id}.

    Every field is optional; update_provider only applies fields the client
    actually included in the request body.
    """

    name: Optional[str] = None
    provider_type: Optional[str] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    default_model: Optional[str] = None

    # Setting this to true makes update_provider clear is_default on the
    # user's other providers.
    is_default: Optional[bool] = None

    # Has no counterpart in ProviderCreate; presumably toggles whether the
    # provider is usable — TODO confirm against the LLMProvider model.
    enabled: Optional[bool] = None
|
|
|
|
|
|
@router.get("/", response_model=dict)
def list_providers(
    current_user: User = Depends(get_current_user)
):
    """List all LLM providers owned by the current user.

    Rows are ordered by ``is_default`` descending, then by ``created_at``
    descending (newest first). Each row is serialized with ``to_dict()``
    (without ``include_key=True``).

    Returns:
        success_response payload ``{"providers": [...], "total": <count>}``.
    """
    session = SessionLocal()
    try:
        ordering = (
            LLMProvider.is_default.desc(),
            LLMProvider.created_at.desc(),
        )
        owned = (
            session.query(LLMProvider)
            .filter(LLMProvider.user_id == current_user.id)
            .order_by(*ordering)
            .all()
        )
        serialized = [item.to_dict() for item in owned]
        return success_response(data={
            "providers": serialized,
            "total": len(owned)
        })
    finally:
        session.close()
|
|
|
|
|
|
@router.post("/", response_model=dict)
def create_provider(
    provider: ProviderCreate,
    current_user: User = Depends(get_current_user)
):
    """Create a new LLM provider for the current user.

    If ``provider.is_default`` is true, every provider the user already owns
    is demoted (``is_default = False``) first, so the new row becomes the
    user's only default.

    Returns:
        success_response whose data is the new row serialized with
        ``to_dict(include_key=True)``.

    Raises:
        HTTPException: 400 if any step fails; the transaction is rolled back
            and the exception text is used as the detail.
    """
    db = SessionLocal()
    try:
        # Demote existing defaults before inserting the new default.
        if provider.is_default:
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id
            ).update({"is_default": False})

        db_provider = LLMProvider(
            user_id=current_user.id,
            name=provider.name,
            provider_type=provider.provider_type,
            base_url=provider.base_url,
            api_key=provider.api_key,
            default_model=provider.default_model,
            is_default=provider.is_default
        )
        db.add(db_provider)
        db.commit()
        db.refresh(db_provider)

        return success_response(data=db_provider.to_dict(include_key=True))
    except HTTPException:
        # Intentional HTTP errors pass through unchanged (same pattern as
        # update_provider) instead of being rewrapped as 400.
        raise
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) of a database error can expose SQL text and
        # bound parameters to the client; consider a generic message plus
        # server-side logging.
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=400, detail=str(e)) from e
    finally:
        db.close()
|
|
|
|
|
|
@router.get("/{provider_id}", response_model=dict)
def get_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Fetch one provider owned by the current user, including its API key.

    Returns:
        success_response whose data is ``to_dict(include_key=True)``.

    Raises:
        HTTPException: 404 when no provider with ``provider_id`` belongs to
            the current user.
    """
    session = SessionLocal()
    try:
        owned_by_user = (
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id,
        )
        found = session.query(LLMProvider).filter(*owned_by_user).first()

        if not found:
            raise HTTPException(status_code=404, detail="Provider not found")

        payload = found.to_dict(include_key=True)
        return success_response(data=payload)
    finally:
        session.close()
|
|
|
|
|
|
@router.put("/{provider_id}", response_model=dict)
def update_provider(
    provider_id: int,
    update: ProviderUpdate,
    current_user: User = Depends(get_current_user)
):
    """Apply a partial update to one of the current user's providers.

    Only fields present in the request body with a non-null value are
    written. An explicit ``null`` is treated the same as omitting the field,
    because ``ProviderCreate`` never allows ``None`` for ``name``,
    ``base_url``, ``api_key``, ``provider_type``, ``default_model`` or
    ``is_default``.

    Setting ``is_default`` to true demotes every other provider of the user.

    Returns:
        success_response whose data is the updated row serialized with
        ``to_dict(include_key=True)``.

    Raises:
        HTTPException: 404 if the provider does not exist or belongs to
            another user; 400 (after rollback) if applying the update fails.
    """
    db = SessionLocal()
    try:
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        ).first()

        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Making this provider the default: demote all of the user's others.
        if update.is_default:
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id,
                LLMProvider.id != provider_id
            ).update({"is_default": False})

        # exclude_unset: only fields the client sent.
        # exclude_none: an explicit null means "leave unchanged", not "write NULL".
        update_data = update.dict(exclude_unset=True, exclude_none=True)
        for key, value in update_data.items():
            setattr(provider, key, value)

        db.commit()
        db.refresh(provider)

        return success_response(data=provider.to_dict(include_key=True))
    except HTTPException:
        # Let the 404 above (and any other intentional HTTP error) through.
        raise
    except Exception as e:
        db.rollback()
        # NOTE(review): str(e) of a database error can expose SQL text and
        # bound parameters to the client; consider a generic message.
        raise HTTPException(status_code=400, detail=str(e)) from e
    finally:
        db.close()
|
|
|
|
|
|
@router.delete("/{provider_id}", response_model=dict)
def delete_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Delete one provider owned by the current user.

    Returns:
        success_response with the message "Provider deleted".

    Raises:
        HTTPException: 404 when no provider with ``provider_id`` belongs to
            the current user.
    """
    session = SessionLocal()
    try:
        target = (
            session.query(LLMProvider)
            .filter(
                LLMProvider.id == provider_id,
                LLMProvider.user_id == current_user.id
            )
            .first()
        )

        if target:
            session.delete(target)
            session.commit()
            return success_response(message="Provider deleted")

        raise HTTPException(status_code=404, detail="Provider not found")
    finally:
        session.close()
|
|
|
|
|
|
@router.post("/{provider_id}/test", response_model=dict)
def test_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Test a provider's connection with a minimal chat-completion request.

    Sends a POST with a one-message prompt (``max_tokens`` 10) to
    ``<base_url>/chat/completions``, using the stored API key as a Bearer
    token and a 10-second timeout. Only an HTTP 200 response counts as
    success.

    Returns:
        success_response with data ``{"success": bool, "message": str}``.
        Network or request errors are reported in ``message``; they are not
        raised.

    Raises:
        HTTPException: 404 when no provider with ``provider_id`` belongs to
            the current user.
    """
    db = SessionLocal()
    try:
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        ).first()

        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Copy out plain values so the session (and its pooled connection)
        # can be released before the potentially slow network call below.
        base_url = provider.base_url
        api_key = provider.api_key
        model = provider.default_model
    finally:
        db.close()

    # NOTE(review): base_url is user-supplied, so this endpoint makes the
    # server send requests to arbitrary hosts (SSRF risk). Consider an
    # allow-list or blocking private address ranges.
    try:
        # Strip a trailing slash to avoid ".../v1//chat/completions".
        url = f"{base_url.rstrip('/')}/chat/completions"
        # Synchronous client: this is a sync endpoint, so there is no need to
        # spin up a fresh event loop per request via asyncio.run().
        with httpx.Client(timeout=10.0) as client:
            response = client.post(
                url,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": "Hi"}],
                    "max_tokens": 10
                }
            )
        success = response.status_code == 200
        return success_response(data={"success": success, "message": "连接成功" if success else "连接失败"})
    except Exception as e:
        # Best-effort diagnostic endpoint: report the failure rather than raise.
        return success_response(data={"success": False, "message": f"连接失败: {str(e)}"})
|