Files
SharedClassManager/backend/middleware/auth_middleware.py

161 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 后端服务
#
# 开发者: Canglan
# 联系方式: admin@sea-studio.top
# 版权归属: Sea Network Technology Studio
# 许可证: MIT License
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from fastapi.responses import JSONResponse
from typing import Optional, Dict, Any
import re
from config import settings
from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient
from utils.response import unauthorized_response
from utils.logger import get_logger
logger = get_logger(__name__)

# Regex patterns for routes that bypass authentication entirely.
# is_public_path() checks the request URL path against each of these.
PUBLIC_PATHS = [
    r"^/$",  # root
    r"^/health$",  # health check
    r"^/api/auth/login$",  # login must be reachable without a token
    r"^/api/auth/logout$",
]
def is_public_path(path: str) -> bool:
    """Return True if *path* may be accessed without authentication.

    A path is public when it fully matches one of the ``PUBLIC_PATHS`` regex
    patterns, or when it equals ``settings.DEBUG_PATH`` (only if that setting
    is non-empty).

    Args:
        path: The request URL path, e.g. ``request.url.path``.

    Returns:
        True for public paths, False otherwise.
    """
    # fullmatch rather than match: with re.match a "$" anchor also matches just
    # before a trailing newline, so "/health\n" would wrongly count as public.
    if any(re.fullmatch(pattern, path) for pattern in PUBLIC_PATHS):
        return True
    # The debug entry point is configured at runtime instead of being listed
    # in PUBLIC_PATHS; an empty/unset value disables it.
    debug_path = settings.DEBUG_PATH
    return bool(debug_path) and path == debug_path
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    Some requests pass through without authentication:

    - CORS preflights (``OPTIONS``).
    - Requests to paths accepted by ``is_public_path``.

    Every other request must send ``Authorization: Bearer <token>``. The
    token must pass ``jwt_handler.verify_token`` and must equal the token
    stored in Redis for the token's ``user_id``.

    On success, the token claims are copied onto ``request.state`` and the
    Redis key's TTL is refreshed.

    On failure, the request is short-circuited with a JSON error response.
    That response carries CORS headers for allowed origins, so the browser
    can read the error body.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate *request* when required, then forward it downstream.

        Returns the downstream response on success. Returns a 401 JSON
        response when authentication fails or raises. Returns a 500 JSON
        response when the downstream handler raises.
        """
        path = request.url.path

        # Preflight requests carry no credentials, so never authenticate them.
        if request.method == "OPTIONS":
            logger.debug(f"[Auth] OPTIONS {path} - 跳过认证")
            return await call_next(request)

        if is_public_path(path):
            logger.debug(f"[Auth] {request.method} {path} - 公开路径,跳过认证")
            return await call_next(request)

        logger.info(f"[Auth] {request.method} {path} - 开始认证")
        try:
            auth_error = await self._authenticate(request, path)
        except Exception as e:
            # NOTE(review): every exception here is reported to the client as
            # 401, including ones raised by the Redis calls. Confirm intended.
            logger.error(f"认证中间件异常: {e}", exc_info=True)
            return self._cors_response(request, 401, "认证服务异常,请稍后重试")
        if auth_error is not None:
            return auth_error

        try:
            response = await call_next(request)
            # Re-add CORS headers in case a downstream layer dropped them.
            # Never overwrite an Allow-Origin header that is already present.
            if not response.headers.get("access-control-allow-origin"):
                for name, value in self._cors_headers(request).items():
                    response.headers[name] = value
            return response
        except Exception as e:
            logger.error(f"[Auth] call_next异常: {e}", exc_info=True)
            return self._cors_response(request, 500, "服务器内部错误")

    async def _authenticate(self, request: Request, path: str) -> Optional[JSONResponse]:
        """Validate the bearer token and populate ``request.state`` with its claims.

        Args:
            request: The incoming request.
            path: The request path, used only in log messages.

        Returns:
            ``None`` on success, or a ready-to-send 401 response on failure.
            Exceptions (e.g. from the Redis client) propagate to the caller.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            logger.warning(f"[Auth] {path} - 缺少Authorization header")
            return self._cors_response(request, 401, "缺少认证令牌")

        # Expect exactly "<scheme> <token>"; any other shape is malformed.
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            logger.warning(f"[Auth] {path} - Authorization header格式错误")
            return self._cors_response(request, 401, "认证格式错误")
        token = parts[1]

        payload = jwt_handler.verify_token(token)
        if not payload:
            logger.warning(f"[Auth] {path} - JWT验证失败")
            return self._cors_response(request, 401, "令牌无效或已过期")

        # A signed token without a user_id cannot be tied to a session; reject
        # it instead of looking up a Redis key for "None".
        user_id = payload.get("user_id")
        if user_id is None:
            logger.warning(f"[Auth] {path} - JWT缺少user_id")
            return self._cors_response(request, 401, "令牌无效或已过期")

        # The token must also be the one currently stored for this user, so a
        # token that was replaced or removed in Redis is rejected even though
        # its signature is still valid.
        stored_token = await RedisClient.get_user_token(user_id)
        if not stored_token or stored_token != token:
            logger.warning(
                f"[Auth] {path} - Redis Token不匹配, user_id={user_id}, "
                f"stored={'存在' if stored_token else '不存在'}"
            )
            return self._cors_response(request, 401, "令牌已失效,请重新登录")

        # Expose the token claims to route handlers.
        request.state.user_id = user_id
        request.state.username = payload.get("username")
        request.state.user_type = payload.get("user_type")
        request.state.student_id = payload.get("student_id")
        request.state.role = payload.get("role")

        # Refresh the stored token's TTL on every authenticated request.
        # NOTE(review): this key format is assumed to match the key that
        # RedisClient.get_user_token reads. Confirm.
        await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)
        logger.debug(f"[Auth] {path} - 认证成功, user_id={user_id}, username={payload.get('username')}")
        return None

    @staticmethod
    def _cors_headers(request: Request) -> Dict[str, str]:
        """Return CORS headers that echo the request's Origin, if it is allowed.

        Returns an empty dict in two cases: the request has no ``Origin``
        header, or the origin is not in ``settings.CORS_ORIGINS``.
        """
        origin = request.headers.get("origin", "")
        allowed_origins = settings.CORS_ORIGINS or []
        # Check for an empty origin explicitly. Otherwise an empty Origin could
        # be echoed back if CORS_ORIGINS contains "" or is a plain string.
        if not origin or origin not in allowed_origins:
            return {}
        return {
            "Access-Control-Allow-Origin": origin,
            "Access-Control-Allow-Credentials": "true",
        }

    def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse:
        """Build a JSON error response with CORS headers for allowed origins.

        The body uses the API envelope
        ``{"success": False, "code": status_code, "message": message, "data": None}``.
        """
        return JSONResponse(
            status_code=status_code,
            content={
                "success": False,
                "code": status_code,
                "message": message,
                "data": None,
            },
            headers=self._cors_headers(request),
        )
async def get_current_user(request: Request) -> Dict[str, Any]:
    """Return the logged-in user's claims that AuthMiddleware stored on ``request.state``.

    The result has the keys ``user_id``, ``username``, ``user_type``,
    ``student_id`` and ``role``.

    AuthMiddleware does not set these attributes for OPTIONS requests or
    public paths. Calling this for such requests fails on attribute lookup.
    """
    state = request.state
    claim_names = ("user_id", "username", "user_type", "student_id", "role")
    return {name: getattr(state, name) for name in claim_names}
async def get_current_user_id(request: Request) -> int:
    """Return the user ID that AuthMiddleware stored on ``request.state``."""
    current_state = request.state
    return current_state.user_id