Luxx/alcor/routes/auth.py

155 lines
4.1 KiB
Python

"""认证路由"""
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr
from sqlalchemy.orm import Session
from alcor.database import get_db
from alcor.models import User
from alcor.utils.helpers import (
hash_password,
verify_password,
create_access_token,
success_response,
error_response
)
# Router for the authentication endpoints; every path declared below is served under /auth.
router = APIRouter(prefix="/auth", tags=["认证"])
# Bearer-token extractor used by get_current_user (reads the Authorization: Bearer header).
# NOTE(review): tokenUrl is where the OpenAPI "Authorize" dialog posts *form-encoded*
# credentials and expects a top-level "access_token" field, but /auth/login reads a JSON
# UserLogin body and wraps the token in success_response(...), so the interactive docs
# login flow likely does not work. The absolute path also assumes this router is mounted
# at the application root — TODO confirm.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
class UserRegister(BaseModel):
    """Request body for user registration (``POST /auth/register``)."""
    # Desired login name; register() rejects it if already taken.
    username: str
    # Optional e-mail address, format-validated by pydantic's EmailStr.
    email: Optional[EmailStr] = None
    # Plain-text password; register() stores only hash_password(password).
    password: str
class UserLogin(BaseModel):
    """Request body for user login (``POST /auth/login``), sent as JSON."""
    # Login name looked up via User.username.
    username: str
    # Plain-text password, checked with verify_password() against the stored hash.
    password: str
class UserResponse(BaseModel):
    """Public representation of a user.

    NOTE(review): the routes shown in this module declare ``response_model=dict``
    and serialize with ``User.to_dict()`` instead of this model — confirm whether
    it is used elsewhere.
    """
    # Primary key of the user row.
    id: int
    username: str
    # Plain str here (not EmailStr) so stored values are echoed without re-validation.
    email: Optional[str] = None
    # Role name; register() assigns "user" to new accounts.
    role: str
    # False for disabled accounts, which login() refuses with 403.
    is_active: bool
class TokenResponse(BaseModel):
    """Access-token payload returned to a client after authentication.

    NOTE(review): login() currently builds an equivalent plain dict inside
    success_response(...) rather than using this model.
    """
    # Signed token string produced by create_access_token().
    access_token: str
    # Token type for the Authorization header scheme.
    token_type: str = "bearer"
def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the authenticated, active user from a bearer token.

    FastAPI dependency: decodes the token with ``decode_access_token`` and
    loads the ``User`` whose id is carried in the token's ``sub`` claim.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        The ``User`` the token belongs to.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no usable
            ``sub`` claim, or refers to a user that no longer exists; 403 when
            the user has been disabled (mirrors the check in ``login``).
    """
    # Kept as a function-local import like the original (presumably to avoid an
    # import cycle — TODO confirm).
    from alcor.utils.helpers import decode_access_token

    def _unauthorized(detail: str) -> HTTPException:
        # RFC 6750: every 401 for bearer auth should advertise the scheme.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = decode_access_token(token)
    if payload is None:
        raise _unauthorized("无效的认证凭证")

    # JWT "sub" is a string per RFC 7519; accept both the string form and the
    # integer form found in tokens issued before this fix. int(None) raises
    # TypeError, which covers a missing claim.
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise _unauthorized("无效的认证凭证") from None

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise _unauthorized("用户不存在")

    # A token issued before the account was disabled must stop working too,
    # otherwise disabling a user only takes effect when the token expires.
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )
    return user
@router.post("/register", response_model=dict)
def register(user_data: UserRegister, db: Session = Depends(get_db)):
    """Register a new account with the default ``"user"`` role.

    Args:
        user_data: Validated registration body (username, optional email, password).
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        ``success_response`` with the new user's ``id`` and ``username``, or
        ``error_response(..., 400)`` when the username or email is already taken.
    """
    # sqlalchemy is already a dependency of this module (sqlalchemy.orm above).
    from sqlalchemy.exc import IntegrityError

    # Reject duplicate usernames up front for a specific error message.
    if db.query(User).filter(User.username == user_data.username).first():
        return error_response("用户名已存在", 400)
    # Email is optional; only check uniqueness when one was supplied.
    if user_data.email and db.query(User).filter(User.email == user_data.email).first():
        return error_response("邮箱已被注册", 400)

    # Only the password hash is persisted, never the plain-text password.
    user = User(
        username=user_data.username,
        email=user_data.email,
        password_hash=hash_password(user_data.password),
        role="user",
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # A concurrent request can register the same username/email between the
        # checks above and this commit. Roll back so the session stays usable and
        # report a client error instead of a 500.
        # NOTE(review): assumes the User table has unique constraints on
        # username/email — confirm in the model definition.
        db.rollback()
        return error_response("用户名或邮箱已存在", 400)
    db.refresh(user)

    return success_response(
        data={"id": user.id, "username": user.username},
        message="注册成功"
    )
@router.post("/login", response_model=dict)
def login(user_data: UserLogin, db: Session = Depends(get_db)):
    """Authenticate with username/password and issue a 7-day access token.

    Args:
        user_data: JSON login body with ``username`` and ``password``.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        ``success_response`` whose data holds ``access_token``, ``token_type``
        (``"bearer"``) and ``user`` (``User.to_dict()``); ``error_response(..., 401)``
        for an unknown user or wrong password; ``error_response(..., 403)`` for a
        disabled account.
    """
    user = db.query(User).filter(User.username == user_data.username).first()
    # Same message for "no such user" and "wrong password" to avoid user enumeration;
    # a missing hash is treated as an empty string so verification simply fails.
    if not user or not verify_password(user_data.password, user.password_hash or ""):
        return error_response("用户名或密码错误", 401)
    # Checked only after the password, so the disabled state is not revealed to
    # callers who do not know the password.
    if not user.is_active:
        return error_response("用户已被禁用", 403)

    # RFC 7519 requires "sub" to be a string; JWT libraries may reject an integer
    # subject when the token is decoded later.
    access_token = create_access_token(
        data={"sub": str(user.id), "username": user.username},
        expires_delta=timedelta(days=7)
    )
    # NOTE(review): confirm User.to_dict() does not include password_hash.
    return success_response(
        data={
            "access_token": access_token,
            "token_type": "bearer",
            "user": user.to_dict()
        },
        message="登录成功"
    )
@router.post("/logout")
def logout(current_user: User = Depends(get_current_user)):
    """Acknowledge a logout request.

    Nothing is revoked on the server; the client is expected to discard its
    token. ``current_user`` is required only so that unauthenticated callers are
    rejected by ``get_current_user`` before reaching this handler.
    """
    result = success_response(message="登出成功")
    return result
@router.get("/me", response_model=dict)
def get_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile as produced by ``User.to_dict()``."""
    profile = current_user.to_dict()
    return success_response(data=profile)