回滚bug修复

This commit is contained in:
2026-04-14 13:42:57 +08:00
parent 13917b337e
commit 0f88844e10

View File

@@ -9,13 +9,16 @@
# 版权所有 © Sea Network Technology Studio # 版权所有 © Sea Network Technology Studio
# =========================================== # ===========================================
from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware
from typing import Dict, Any from starlette.requests import Request
from starlette.responses import Response
from fastapi.responses import JSONResponse
from typing import Optional, Dict, Any
import re import re
import json
from utils.jwt_handler import jwt_handler from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient from utils.redis_client import RedisClient
from utils.response import unauthorized_response
from utils.logger import get_logger from utils.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -43,72 +46,53 @@ def is_public_path(path: str) -> bool:
return False return False
class AuthMiddleware: class AuthMiddleware(BaseHTTPMiddleware):
"""JWT认证中间件纯ASGI实现兼容CORS""" """JWT认证中间件"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
# 从 scope 中获取请求信息
path = scope.get("path", "")
method = scope.get("method", "")
async def dispatch(self, request: Request, call_next):
# OPTIONS 预检请求跳过认证 # OPTIONS 预检请求跳过认证
if method == "OPTIONS": if request.method == "OPTIONS":
await self.app(scope, receive, send) return await call_next(request)
return
path = request.url.path
# 公开路径跳过认证 # 公开路径跳过认证
if is_public_path(path): if is_public_path(path):
await self.app(scope, receive, send) return await call_next(request)
return
# 从 headers 中获取 Authorization
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode("utf-8") if b"authorization" in headers else None
if not auth_header:
await self._send_unauthorized(send, "缺少认证令牌")
return
# 解析Bearer Token
try: try:
scheme, token = auth_header.split() # 获取Authorization头
if scheme.lower() != "bearer": auth_header = request.headers.get("Authorization")
await self._send_unauthorized(send, "认证格式错误")
return
except ValueError:
await self._send_unauthorized(send, "认证格式错误")
return
# 验证Token if not auth_header:
try: return self._cors_response(request, 401, "缺少认证令牌")
# 解析Bearer Token
try:
scheme, token = auth_header.split()
if scheme.lower() != "bearer":
return self._cors_response(request, 401, "认证格式错误")
except ValueError:
return self._cors_response(request, 401, "认证格式错误")
# 验证Token
payload = jwt_handler.verify_token(token) payload = jwt_handler.verify_token(token)
if not payload: if not payload:
await self._send_unauthorized(send, "令牌无效或已过期") return self._cors_response(request, 401, "令牌无效或已过期")
return
# 验证Redis中的Token # 验证Redis中的Token
user_id = payload.get("user_id") user_id = payload.get("user_id")
stored_token = await RedisClient.get_user_token(user_id) stored_token = await RedisClient.get_user_token(user_id)
if not stored_token or stored_token != token: if not stored_token or stored_token != token:
await self._send_unauthorized(send, "令牌已失效,请重新登录") return self._cors_response(request, 401, "令牌已失效,请重新登录")
return
# 将用户信息存储到scope的state中request.state兼容) # 将用户信息存储到request.state
if "state" not in scope: request.state.user_id = payload.get("user_id")
scope["state"] = {} request.state.username = payload.get("username")
scope["state"]["user_id"] = payload.get("user_id") request.state.user_type = payload.get("user_type")
scope["state"]["username"] = payload.get("username") request.state.student_id = payload.get("student_id")
scope["state"]["user_type"] = payload.get("user_type") request.state.role = payload.get("role")
scope["state"]["student_id"] = payload.get("student_id")
scope["state"]["role"] = payload.get("role")
# 刷新Token过期时间 # 刷新Token过期时间
from config import settings from config import settings
@@ -116,32 +100,30 @@ class AuthMiddleware:
except Exception as e: except Exception as e:
logger.error(f"认证中间件异常: {e}", exc_info=True) logger.error(f"认证中间件异常: {e}", exc_info=True)
await self._send_unauthorized(send, "认证服务异常,请稍后重试") return self._cors_response(request, 401, "认证服务异常,请稍后重试")
return
await self.app(scope, receive, send) return await call_next(request)
async def _send_unauthorized(self, send, message: str): def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse:
"""发送401未授权响应""" """创建带CORS头的响应"""
body = json.dumps({ origin = request.headers.get("origin", "")
"success": False, allowed_origins = ["https://class.sea-studio.top", "https://classbackendapi.sea-studio.top"]
"code": 401,
"message": message,
"data": None
}).encode("utf-8")
await send({ headers = {}
"type": "http.response.start", if origin in allowed_origins:
"status": 401, headers["Access-Control-Allow-Origin"] = origin
"headers": [ headers["Access-Control-Allow-Credentials"] = "true"
[b"content-type", b"application/json"],
[b"content-length", str(len(body)).encode()], return JSONResponse(
], status_code=status_code,
}) content={
await send({ "success": False,
"type": "http.response.body", "code": status_code,
"body": body, "message": message,
}) "data": None
},
headers=headers
)
async def get_current_user(request: Request) -> Dict[str, Any]: async def get_current_user(request: Request) -> Dict[str, Any]: