# ===========================================
# Class Conduct Score Management System - Backend Service
#
# Developer: Canglan
# Contact: admin@sea-studio.top
# Copyright holder: Sea Network Technology Studio
# License: MIT License
#
# Copyright © Sea Network Technology Studio
# ===========================================

import re
from typing import Optional, Dict, Any

from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from config import settings
from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient
from utils.response import unauthorized_response
from utils.logger import get_logger

logger = get_logger(__name__)

# Regex patterns for routes that are served without authentication.
# NOTE(review): every path under /debug/ bypasses authentication entirely;
# confirm this entry is not reachable in production deployments.
PUBLIC_PATHS = [
    r'^/$',
    r'^/health$',
    r'^/api/auth/login$',
    r'^/api/auth/logout$',
    r'^/debug/.*$',  # debug entry point
]


def is_public_path(path: str) -> bool:
    """Return True when *path* matches one of the PUBLIC_PATHS patterns.

    ``re.fullmatch`` is used rather than ``re.match`` because ``$`` also
    matches just before a trailing newline, so a path such as
    ``"/health\\n"`` would otherwise be treated as public.
    PUBLIC_PATHS is read on every call, so runtime changes to the list are
    honoured.
    """
    return any(re.fullmatch(pattern, path) for pattern in PUBLIC_PATHS)


class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    For every non-OPTIONS request on a non-public path it validates the
    ``Authorization: Bearer <token>`` header, checks the token against the
    copy stored through ``RedisClient``, and exposes the token claims on
    ``request.state``.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate the request, then forward it to the next handler.

        Returns a 401 JSON response when authentication fails (or when the
        authentication step itself raises), and a 500 JSON response when
        the downstream handler raises.
        """
        path = request.url.path

        # CORS preflight requests skip authentication.
        if request.method == "OPTIONS":
            logger.debug(f"[Auth] OPTIONS {path} - 跳过认证")
            return await call_next(request)

        # Public routes skip authentication.
        if is_public_path(path):
            logger.debug(f"[Auth] {request.method} {path} - 公开路径,跳过认证")
            return await call_next(request)

        logger.info(f"[Auth] {request.method} {path} - 开始认证")

        try:
            auth_header = request.headers.get("Authorization")
            if not auth_header:
                logger.warning(f"[Auth] {path} - 缺少Authorization header")
                return self._cors_response(request, 401, "缺少认证令牌")

            # Expect exactly two whitespace-separated parts: "Bearer <token>".
            parts = auth_header.split()
            if len(parts) != 2 or parts[0].lower() != "bearer":
                logger.warning(f"[Auth] {path} - Authorization header格式错误")
                return self._cors_response(request, 401, "认证格式错误")
            token = parts[1]

            payload = jwt_handler.verify_token(token)
            if not payload:
                logger.warning(f"[Auth] {path} - JWT验证失败")
                return self._cors_response(request, 401, "令牌无效或已过期")

            # A token without a user_id claim cannot identify anyone; reject
            # it before doing a Redis lookup with a None id.
            user_id = payload.get("user_id")
            if user_id is None:
                logger.warning(f"[Auth] {path} - JWT验证失败")
                return self._cors_response(request, 401, "令牌无效或已过期")

            # The presented token must equal the one stored for this user.
            stored_token = await RedisClient.get_user_token(user_id)
            if not stored_token or stored_token != token:
                logger.warning(f"[Auth] {path} - Redis Token不匹配, user_id={user_id}, stored={'有' if stored_token else '无'}")
                return self._cors_response(request, 401, "令牌已失效,请重新登录")

            # Expose the token claims to downstream handlers.
            request.state.user_id = user_id
            request.state.username = payload.get("username")
            request.state.user_type = payload.get("user_type")
            request.state.student_id = payload.get("student_id")
            request.state.role = payload.get("role")

            # Refresh the expiry of the stored token key.
            await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)

            logger.debug(f"[Auth] {path} - 认证成功, user_id={user_id}, username={payload.get('username')}")

        except Exception as e:
            # Any failure during authentication (including the RedisClient
            # calls above) is reported to the client as 401.
            logger.error(f"认证中间件异常: {e}", exc_info=True)
            return self._cors_response(request, 401, "认证服务异常,请稍后重试")

        try:
            response = await call_next(request)
            # Make sure CORS headers are present on every response, so they
            # are not lost when the routing layer fails to add them.
            if not response.headers.get("access-control-allow-origin"):
                for header_name, header_value in self._cors_headers(request).items():
                    response.headers[header_name] = header_value
            return response
        except Exception as e:
            logger.error(f"[Auth] call_next异常: {e}", exc_info=True)
            return self._cors_response(request, 500, "服务器内部错误")

    def _cors_headers(self, request: Request) -> Dict[str, str]:
        """Return the CORS headers to attach for the request's Origin.

        The result is empty unless the request carries a non-empty
        ``Origin`` header that is listed in ``settings.CORS_ORIGINS``. The
        non-empty check prevents an empty ``Access-Control-Allow-Origin``
        from being emitted when the configured list contains ``""``.
        """
        origin = request.headers.get("origin", "")
        allowed_origins = settings.CORS_ORIGINS or []
        if origin and origin in allowed_origins:
            return {
                "Access-Control-Allow-Origin": origin,
                "Access-Control-Allow-Credentials": "true",
            }
        return {}

    def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse:
        """Build a JSON error response that carries CORS headers.

        The body has the keys ``success`` (always False), ``code`` (equal
        to *status_code*), ``message`` and ``data`` (always None).
        """
        return JSONResponse(
            status_code=status_code,
            content={
                "success": False,
                "code": status_code,
                "message": message,
                "data": None
            },
            headers=self._cors_headers(request)
        )


async def get_current_user(request: Request) -> Dict[str, Any]:
    """Return the current user's claims as stored on ``request.state``.

    The attributes are populated by ``AuthMiddleware.dispatch`` only after
    successful authentication; the middleware does not set them for
    OPTIONS requests or public paths.
    """
    return {
        "user_id": request.state.user_id,
        "username": request.state.username,
        "user_type": request.state.user_type,
        "student_id": request.state.student_id,
        "role": request.state.role
    }


async def get_current_user_id(request: Request) -> int:
    """Return ``request.state.user_id`` as set by ``AuthMiddleware``."""
    return request.state.user_id