From 0f88844e10d3d4fdd0b00466e8e7e501244582b7 Mon Sep 17 00:00:00 2001 From: canglan Date: Tue, 14 Apr 2026 13:42:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E6=BB=9Abug=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/middleware/auth_middleware.py | 136 +++++++++++--------------- 1 file changed, 59 insertions(+), 77 deletions(-) diff --git a/backend/middleware/auth_middleware.py b/backend/middleware/auth_middleware.py index 47e4333..487641d 100644 --- a/backend/middleware/auth_middleware.py +++ b/backend/middleware/auth_middleware.py @@ -9,13 +9,16 @@ # 版权所有 © Sea Network Technology Studio # =========================================== -from fastapi import Request -from typing import Dict, Any +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from fastapi.responses import JSONResponse +from typing import Optional, Dict, Any import re -import json from utils.jwt_handler import jwt_handler from utils.redis_client import RedisClient +from utils.response import unauthorized_response from utils.logger import get_logger logger = get_logger(__name__) @@ -43,72 +46,53 @@ def is_public_path(path: str) -> bool: return False -class AuthMiddleware: - """JWT认证中间件(纯ASGI实现,兼容CORS)""" +class AuthMiddleware(BaseHTTPMiddleware): + """JWT认证中间件""" - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - if scope["type"] not in ("http", "websocket"): - await self.app(scope, receive, send) - return - - # 从 scope 中获取请求信息 - path = scope.get("path", "") - method = scope.get("method", "") - + async def dispatch(self, request: Request, call_next): # OPTIONS 预检请求跳过认证 - if method == "OPTIONS": - await self.app(scope, receive, send) - return + if request.method == "OPTIONS": + return await call_next(request) + + path = request.url.path # 公开路径跳过认证 if is_public_path(path): - await self.app(scope, receive, send) - return + return await call_next(request) - # 从 headers 中获取 Authorization - headers = dict(scope.get("headers", [])) - auth_header = headers.get(b"authorization", b"").decode("utf-8") if b"authorization" in headers else None - - if not auth_header: - await self._send_unauthorized(send, "缺少认证令牌") - return - - # 解析Bearer Token - try: - scheme, token = auth_header.split() - if scheme.lower() != "bearer": - await self._send_unauthorized(send, "认证格式错误") - return - except ValueError: - await self._send_unauthorized(send, "认证格式错误") - return - - # 验证Token try: + # 获取Authorization头 + auth_header = request.headers.get("Authorization") + + if not auth_header: + return self._cors_response(request, 401, "缺少认证令牌") + + # 解析Bearer Token + try: + scheme, token = auth_header.split() + if scheme.lower() != "bearer": + return self._cors_response(request, 401, "认证格式错误") + except ValueError: + return self._cors_response(request, 401, "认证格式错误") + + # 验证Token payload = jwt_handler.verify_token(token) if not payload: - await self._send_unauthorized(send, "令牌无效或已过期") - return + return self._cors_response(request, 401, "令牌无效或已过期") # 验证Redis中的Token user_id = payload.get("user_id") stored_token = await RedisClient.get_user_token(user_id) if not stored_token or stored_token != token: - await self._send_unauthorized(send, "令牌已失效,请重新登录") - return + return self._cors_response(request, 401, "令牌已失效,请重新登录") - # 将用户信息存储到scope的state中(与request.state兼容) - if "state" not in scope: - scope["state"] = {} - scope["state"]["user_id"] = payload.get("user_id") - scope["state"]["username"] = payload.get("username") - scope["state"]["user_type"] = payload.get("user_type") - scope["state"]["student_id"] = payload.get("student_id") - scope["state"]["role"] = payload.get("role") + # 将用户信息存储到request.state + request.state.user_id = payload.get("user_id") + request.state.username = payload.get("username") + request.state.user_type = payload.get("user_type") + request.state.student_id = payload.get("student_id") + request.state.role = payload.get("role") # 刷新Token过期时间 from config import settings @@ -116,32 +100,30 @@ class AuthMiddleware: except Exception as e: logger.error(f"认证中间件异常: {e}", exc_info=True) - await self._send_unauthorized(send, "认证服务异常,请稍后重试") - return + return self._cors_response(request, 401, "认证服务异常,请稍后重试") - await self.app(scope, receive, send) + return await call_next(request) - async def _send_unauthorized(self, send, message: str): - """发送401未授权响应""" - body = json.dumps({ - "success": False, - "code": 401, - "message": message, - "data": None - }).encode("utf-8") + def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse: + """创建带CORS头的响应""" + origin = request.headers.get("origin", "") + allowed_origins = ["https://class.sea-studio.top", "https://classbackendapi.sea-studio.top"] - await send({ - "type": "http.response.start", - "status": 401, - "headers": [ - [b"content-type", b"application/json"], - [b"content-length", str(len(body)).encode()], - ], - }) - await send({ - "type": "http.response.body", - "body": body, - }) + headers = {} + if origin in allowed_origins: + headers["Access-Control-Allow-Origin"] = origin + headers["Access-Control-Allow-Credentials"] = "true" + + return JSONResponse( + status_code=status_code, + content={ + "success": False, + "code": status_code, + "message": message, + "data": None + }, + headers=headers + ) async def get_current_user(request: Request) -> Dict[str, Any]: @@ -157,4 +139,4 @@ async def get_current_user(request: Request) -> Dict[str, Any]: async def get_current_user_id(request: Request) -> int: """获取当前用户ID""" - return request.state.user_id \ No newline at end of file + return request.state.user_id