"""认证路由""" from datetime import timedelta from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel, EmailStr from sqlalchemy.orm import Session from alcor.database import get_db from alcor.models import User from alcor.utils.helpers import ( hash_password, verify_password, create_access_token, success_response, error_response ) router = APIRouter(prefix="/auth", tags=["认证"]) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") class UserRegister(BaseModel): """用户注册模型""" username: str email: Optional[EmailStr] = None password: str class UserLogin(BaseModel): """用户登录模型""" username: str password: str class UserResponse(BaseModel): """用户响应模型""" id: int username: str email: Optional[str] = None role: str is_active: bool class TokenResponse(BaseModel): """令牌响应模型""" access_token: str token_type: str = "bearer" def get_current_user( token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: """获取当前用户""" from alcor.utils.helpers import decode_access_token payload = decode_access_token(token) if payload is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭证", headers={"WWW-Authenticate": "Bearer"}, ) user_id = payload.get("sub") if user_id is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭证" ) user = db.query(User).filter(User.id == user_id).first() if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在" ) return user @router.post("/register", response_model=dict) def register(user_data: UserRegister, db: Session = Depends(get_db)): """用户注册""" # 检查用户名是否存在 existing_user = db.query(User).filter(User.username == user_data.username).first() if existing_user: return error_response("用户名已存在", 400) # 检查邮箱是否存在 if user_data.email: existing_email = db.query(User).filter(User.email == user_data.email).first() if existing_email: return error_response("邮箱已被注册", 400) # 创建用户 password_hash = hash_password(user_data.password) user = User( username=user_data.username, email=user_data.email, password_hash=password_hash, role="user" ) db.add(user) db.commit() db.refresh(user) return success_response( data={"id": user.id, "username": user.username}, message="注册成功" ) @router.post("/login", response_model=dict) def login(user_data: UserLogin, db: Session = Depends(get_db)): """用户登录""" user = db.query(User).filter(User.username == user_data.username).first() if not user or not verify_password(user_data.password, user.password_hash or ""): return error_response("用户名或密码错误", 401) if not user.is_active: return error_response("用户已被禁用", 403) # 创建访问令牌 access_token = create_access_token( data={"sub": user.id, "username": user.username}, expires_delta=timedelta(days=7) ) return success_response( data={ "access_token": access_token, "token_type": "bearer", "user": user.to_dict() }, message="登录成功" ) @router.post("/logout") def logout(current_user: User = Depends(get_current_user)): """用户登出(前端清除令牌即可)""" return success_response(message="登出成功") @router.get("/me", response_model=dict) def get_me(current_user: User = Depends(get_current_user)): """获取当前用户信息""" return success_response(data=current_user.to_dict())