Files
SharedClassManager/backend/middleware/auth_middleware.py
2026-04-14 13:42:57 +08:00

143 lines
4.7 KiB
Python

# ===========================================
# 班级操行分管理系统 - 后端服务
#
# 开发者: Canglan
# 联系方式: admin@sea-studio.top
# 版权归属: Sea Network Technology Studio
# 许可证: MIT License
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from fastapi.responses import JSONResponse
from typing import Optional, Dict, Any
import re
from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient
from utils.response import unauthorized_response
from utils.logger import get_logger
# Module-level logger for this middleware, obtained from the project's utils.logger helper.
logger = get_logger(__name__)
# Routes that need no authentication: regular expressions matched (via
# is_public_path) against request.url.path. AuthMiddleware forwards matching
# requests without checking any token.
# Do not add broad debug/admin prefixes (e.g. r'^/debug/.*$') here: every
# route under such a prefix would be reachable by anonymous callers.
PUBLIC_PATHS = [
    r'^/$',
    r'^/health$',
    r'^/api/auth/login$',
    r'^/api/auth/logout$',
]
# Routes intended to skip token verification while still being access-logged.
# NOTE(review): AuthMiddleware in this module never reads this list, so these
# paths currently go through the normal token check. Confirm whether anything
# else uses it before relying on the description above.
OPEN_PATHS = [
    r'^/api/auth/change-password$',
]
def is_public_path(path: str) -> bool:
    """Tell whether *path* may be served without authentication.

    Each pattern in ``PUBLIC_PATHS`` is tried with ``re.match`` (anchored at
    the start of *path*). The result is True as soon as one matches, False
    otherwise.
    """
    return any(re.match(pattern, path) for pattern in PUBLIC_PATHS)
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    Requests pass through untouched when they are ``OPTIONS`` preflights or
    when their path matches ``PUBLIC_PATHS`` (via ``is_public_path``). Every
    other request must carry ``Authorization: Bearer <token>``. The token must
    verify with ``jwt_handler`` and must also equal the token currently stored
    in Redis for that user.

    On success the token's claims are copied onto ``request.state`` and the
    Redis entry's TTL is extended. On any failure a 401 JSON error is returned
    and the downstream handler is not called.
    """

    # Origins whose ``Origin`` request header is echoed back (with credentials
    # allowed) in the CORS headers of the error responses built by
    # ``_cors_response``.
    ALLOWED_ORIGINS = (
        "https://class.sea-studio.top",
        "https://classbackendapi.sea-studio.top",
    )

    async def dispatch(self, request: Request, call_next):
        """Authenticate *request*, then forward it to *call_next*.

        Returns the downstream response on success. Returns a 401 response
        from ``_cors_response`` when authentication fails or raises.
        """
        # CORS preflight requests skip authentication.
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path
        if is_public_path(path):
            return await call_next(request)

        try:
            auth_header = request.headers.get("Authorization")
            if not auth_header:
                return self._cors_response(request, 401, "缺少认证令牌")

            # Expect exactly "<scheme> <token>". Any other number of
            # whitespace-separated parts is a format error.
            try:
                scheme, token = auth_header.split()
            except ValueError:
                return self._cors_response(request, 401, "认证格式错误")
            if scheme.lower() != "bearer":
                return self._cors_response(request, 401, "认证格式错误")

            payload = jwt_handler.verify_token(token)
            if not payload:
                return self._cors_response(request, 401, "令牌无效或已过期")

            # A verified JWT is not enough on its own. It must also match the
            # token stored in Redis for this user, so a token that has been
            # removed or replaced there is rejected.
            user_id = payload.get("user_id")
            stored_token = await RedisClient.get_user_token(user_id)
            if not stored_token or stored_token != token:
                return self._cors_response(request, 401, "令牌已失效,请重新登录")

            # Expose the token's claims to route handlers and dependencies.
            request.state.user_id = user_id
            request.state.username = payload.get("username")
            request.state.user_type = payload.get("user_type")
            request.state.student_id = payload.get("student_id")
            request.state.role = payload.get("role")

            # Sliding expiry: each authenticated request resets the Redis
            # entry's TTL to JWT_EXPIRE_MINUTES.
            # NOTE(review): the key format is hard-coded here. Presumably it is
            # the same key RedisClient.get_user_token reads; confirm, or
            # the refresh silently targets a different key.
            # The import stays function-local as in the original, possibly to
            # avoid an import cycle; confirm before hoisting it.
            from config import settings
            await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)
        except Exception as e:
            logger.error(f"认证中间件异常: {e}", exc_info=True)
            # NOTE(review): unexpected failures here (e.g. Redis unreachable)
            # are reported as 401. Clients may treat that as "logged out", and
            # 503 may be more accurate; confirm with the frontend first.
            return self._cors_response(request, 401, "认证服务异常,请稍后重试")

        return await call_next(request)

    def _cors_response(self, request: Request, status_code: int, message: str) -> JSONResponse:
        """Build the standard JSON error envelope, with CORS headers when allowed.

        Args:
            request: the incoming request. Its ``Origin`` header is echoed back
                only if it is listed in ``ALLOWED_ORIGINS``.
            status_code: HTTP status of the response, also copied into the
                body's ``code`` field.
            message: human-readable error message placed in the body.

        Returns:
            A ``JSONResponse`` whose body is
            ``{"success": False, "code": status_code, "message": message, "data": None}``.
        """
        headers = {
            # Access-Control-Allow-Origin below depends on the request's
            # Origin, so caches must not reuse this response across origins.
            "Vary": "Origin",
        }
        origin = request.headers.get("origin", "")
        if origin in self.ALLOWED_ORIGINS:
            headers["Access-Control-Allow-Origin"] = origin
            headers["Access-Control-Allow-Credentials"] = "true"
        if status_code == 401:
            # RFC 7235 §3.1: a 401 response MUST include a WWW-Authenticate
            # challenge; this middleware only accepts Bearer tokens.
            headers["WWW-Authenticate"] = "Bearer"
        return JSONResponse(
            status_code=status_code,
            content={
                "success": False,
                "code": status_code,
                "message": message,
                "data": None,
            },
            headers=headers,
        )
async def get_current_user(request: Request) -> Dict[str, Any]:
    """Collect the authenticated user's claims from ``request.state``.

    AuthMiddleware sets these attributes only for requests it authenticates.
    For OPTIONS requests and paths matched by ``PUBLIC_PATHS`` this module does
    not set them, so the lookups fail unless something else provides them.

    Returns a dict with keys ``user_id``, ``username``, ``user_type``,
    ``student_id`` and ``role``, in that order.
    """
    state = request.state
    claim_names = ("user_id", "username", "user_type", "student_id", "role")
    return {name: getattr(state, name) for name in claim_names}
async def get_current_user_id(request: Request) -> int:
    """Return the ``user_id`` claim that AuthMiddleware stored on ``request.state``."""
    state = request.state
    return state.user_id