115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
# ===========================================
|
|
# 班级操行分管理系统 - 后端服务
|
|
#
|
|
# 开发者: Canglan
|
|
# 联系方式: admin@sea-studio.top
|
|
# 版权归属: Sea Network Technology Studio
|
|
# 许可证: MIT License
|
|
#
|
|
# 版权所有 © Sea Network Technology Studio
|
|
# ===========================================
|
|
|
|
from fastapi import Request, HTTPException
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from typing import Optional, Dict, Any, Tuple
|
|
import re
|
|
|
|
from utils.jwt_handler import jwt_handler
|
|
from utils.redis_client import RedisClient
|
|
from utils.response import unauthorized_response
|
|
from utils.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's logging helper.
# NOTE(review): `logger` is not referenced anywhere in the visible code of
# this module (auth failures are returned without being logged).
logger = get_logger(__name__)
|
|
|
|
# Routes that bypass authentication entirely (checked by is_public_path).
# Each entry is a regular expression applied with ``re.match``. Keep the
# explicit ``^``/``$`` anchors so a pattern never covers more paths than
# intended.
#
# SECURITY: do not add wildcard entries such as ``^/debug/.*$``. Any route
# matching a pattern here is served without a token. If debug endpoints are
# genuinely needed, list their exact paths, and only in non-production
# deployments.
PUBLIC_PATHS = [
    r'^/$',
    r'^/health$',
    r'^/api/auth/login$',
    r'^/api/auth/logout$',
]


# Routes intended to skip token verification while still having their
# access recorded.
# NOTE(review): AuthMiddleware in this module never consults this list, so
# /api/auth/change-password currently still requires a valid token. Check
# whether another module uses it before wiring it in: exempting a
# password-change endpoint from authentication would be unsafe.
OPEN_PATHS = [
    r'^/api/auth/change-password$',
]
|
|
|
|
|
|
def is_public_path(path: str) -> bool:
    """Tell whether *path* may be served without authentication.

    Each regular expression in ``PUBLIC_PATHS`` is tried with ``re.match``,
    which anchors at the start of the string. The patterns carry their own
    ``$`` anchors.

    Args:
        path: The request URL path, e.g. ``"/health"``.

    Returns:
        True as soon as one pattern matches, otherwise False.
    """
    return any(re.match(public_pattern, path) for public_pattern in PUBLIC_PATHS)
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    OPTIONS requests and paths accepted by ``is_public_path`` are passed
    straight through. Every other request must:

    1. carry an ``Authorization: Bearer <token>`` header,
    2. hold a token accepted by ``jwt_handler.verify_token``,
    3. contain a ``user_id`` claim, and
    4. hold a token equal to the one ``RedisClient.get_user_token`` returns
       for that user. A token that was replaced or removed in Redis is
       therefore rejected even if its signature is still valid.

    On success the identity claims are copied onto ``request.state`` and the
    TTL of the ``user_token:{user_id}`` Redis key is reset before the request
    continues. Any failure short-circuits with ``unauthorized_response``.
    """

    @staticmethod
    def _extract_bearer_token(auth_header: str) -> Optional[str]:
        """Return the token from a ``Bearer <token>`` header value, or None.

        The header must split into exactly two whitespace-separated parts.
        The scheme is compared case-insensitively.
        """
        parts = auth_header.split()
        if len(parts) != 2:
            return None
        scheme, token = parts
        if scheme.lower() != "bearer":
            return None
        return token

    async def dispatch(self, request: Request, call_next):
        # CORS preflight (OPTIONS) requests skip authentication.
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path

        # Public paths skip authentication.
        if is_public_path(path):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return unauthorized_response("缺少认证令牌")

        token = self._extract_bearer_token(auth_header)
        if token is None:
            return unauthorized_response("认证格式错误")

        payload = jwt_handler.verify_token(token)
        if not payload:
            return unauthorized_response("令牌无效或已过期")

        user_id = payload.get("user_id")
        # A signed token without a user_id cannot be tied to a stored
        # session. Reject it here instead of querying Redis for
        # user_id=None and refreshing the "user_token:None" key below.
        if user_id is None:
            return unauthorized_response("令牌无效或已过期")

        # The presented token must equal the one currently stored for this
        # user. Otherwise it has been superseded or revoked server-side.
        stored_token = await RedisClient.get_user_token(user_id)
        if not stored_token or stored_token != token:
            return unauthorized_response("令牌已失效,请重新登录")

        # Expose the identity claims to route handlers via request.state.
        request.state.user_id = user_id
        request.state.username = payload.get("username")
        request.state.user_type = payload.get("user_type")
        request.state.student_id = payload.get("student_id")
        request.state.role = payload.get("role")

        # Imported lazily, presumably to avoid a circular import with
        # `config` (TODO confirm).
        from config import settings

        # Reset the stored token's TTL on every authenticated request. This
        # only extends the Redis entry; the JWT itself is not reissued here.
        await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)

        return await call_next(request)
|
|
|
|
|
|
async def get_current_user(request: Request) -> Dict[str, Any]:
    """Collect the authenticated user's identity from ``request.state``.

    The values are the ones ``AuthMiddleware.dispatch`` stores after a
    successful token check. On requests where the middleware returned early
    (public paths, OPTIONS) those attributes are never set, and the lookup
    raises ``AttributeError``.

    Returns:
        A dict with the keys ``user_id``, ``username``, ``user_type``,
        ``student_id`` and ``role``, in that order.
    """
    state = request.state
    identity_fields = ("user_id", "username", "user_type", "student_id", "role")
    return {field: getattr(state, field) for field in identity_fields}
|
|
|
|
|
|
async def get_current_user_id(request: Request) -> int:
    """Return the ``user_id`` that ``AuthMiddleware`` placed on ``request.state``.

    The value is the token's ``user_id`` claim, copied unchanged. The
    ``int`` annotation is assumed rather than enforced here. Raises
    ``AttributeError`` if the middleware did not authenticate this request.
    """
    current_state = request.state
    return current_state.user_id