回滚bug修复

This commit is contained in:
2026-04-14 13:42:57 +08:00
parent 13917b337e
commit 0f88844e10

View File

@@ -9,13 +9,16 @@
# 版权所有 © Sea Network Technology Studio
# ===========================================
from fastapi import Request
from typing import Dict, Any
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from fastapi.responses import JSONResponse
from typing import Optional, Dict, Any
import re
import json
from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient
from utils.response import unauthorized_response
from utils.logger import get_logger
logger = get_logger(__name__)
@@ -43,72 +46,53 @@ def is_public_path(path: str) -> bool:
return False
class AuthMiddleware:
"""JWT认证中间件纯ASGI实现兼容CORS"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
# 从 scope 中获取请求信息
path = scope.get("path", "")
method = scope.get("method", "")
class AuthMiddleware(BaseHTTPMiddleware):
"""JWT认证中间件"""
async def dispatch(self, request: Request, call_next):
# OPTIONS 预检请求跳过认证
if method == "OPTIONS":
await self.app(scope, receive, send)
return
if request.method == "OPTIONS":
return await call_next(request)
path = request.url.path
# 公开路径跳过认证
if is_public_path(path):
await self.app(scope, receive, send)
return
return await call_next(request)
# 从 headers 中获取 Authorization
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode("utf-8") if b"authorization" in headers else None
try:
# 获取Authorization头
auth_header = request.headers.get("Authorization")
if not auth_header:
await self._send_unauthorized(send, "缺少认证令牌")
return
return self._cors_response(request, 401, "缺少认证令牌")
# 解析Bearer Token
try:
scheme, token = auth_header.split()
if scheme.lower() != "bearer":
await self._send_unauthorized(send, "认证格式错误")
return
return self._cors_response(request, 401, "认证格式错误")
except ValueError:
await self._send_unauthorized(send, "认证格式错误")
return
return self._cors_response(request, 401, "认证格式错误")
# 验证Token
try:
payload = jwt_handler.verify_token(token)
if not payload:
await self._send_unauthorized(send, "令牌无效或已过期")
return
return self._cors_response(request, 401, "令牌无效或已过期")
# 验证Redis中的Token
user_id = payload.get("user_id")
stored_token = await RedisClient.get_user_token(user_id)
if not stored_token or stored_token != token:
await self._send_unauthorized(send, "令牌已失效,请重新登录")
return
return self._cors_response(request, 401, "令牌已失效,请重新登录")
# 将用户信息存储到scope的state中request.state兼容)
if "state" not in scope:
scope["state"] = {}
scope["state"]["user_id"] = payload.get("user_id")
scope["state"]["username"] = payload.get("username")
scope["state"]["user_type"] = payload.get("user_type")
scope["state"]["student_id"] = payload.get("student_id")
scope["state"]["role"] = payload.get("role")
# 将用户信息存储到request.state
request.state.user_id = payload.get("user_id")
request.state.username = payload.get("username")
request.state.user_type = payload.get("user_type")
request.state.student_id = payload.get("student_id")
request.state.role = payload.get("role")
# 刷新Token过期时间
from config import settings
@@ -116,32 +100,30 @@ class AuthMiddleware:
except Exception as e:
logger.error(f"认证中间件异常: {e}", exc_info=True)
await self._send_unauthorized(send, "认证服务异常,请稍后重试")
return
return self._cors_response(request, 401, "认证服务异常,请稍后重试")
await self.app(scope, receive, send)
return await call_next(request)
async def _send_unauthorized(self, send, message: str):
"""发送401未授权响应"""
body = json.dumps({
def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse:
"""创建带CORS头的响应"""
origin = request.headers.get("origin", "")
allowed_origins = ["https://class.sea-studio.top", "https://classbackendapi.sea-studio.top"]
headers = {}
if origin in allowed_origins:
headers["Access-Control-Allow-Origin"] = origin
headers["Access-Control-Allow-Credentials"] = "true"
return JSONResponse(
status_code=status_code,
content={
"success": False,
"code": 401,
"code": status_code,
"message": message,
"data": None
}).encode("utf-8")
await send({
"type": "http.response.start",
"status": 401,
"headers": [
[b"content-type", b"application/json"],
[b"content-length", str(len(body)).encode()],
],
})
await send({
"type": "http.response.body",
"body": body,
})
},
headers=headers
)
async def get_current_user(request: Request) -> Dict[str, Any]: