Files
ClassManager/backend/middleware/auth_middleware.py
2026-04-14 13:34:48 +08:00

160 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 后端服务
#
# 开发者: Canglan
# 联系方式: admin@sea-studio.top
# 版权归属: Sea Network Technology Studio
# 许可证: MIT License
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from fastapi import Request
from typing import Dict, Any
import re
import json
from utils.jwt_handler import jwt_handler
from utils.redis_client import RedisClient
from utils.logger import get_logger
# Module-level logger obtained from the project's utils.logger helper.
logger = get_logger(__name__)
# Anchored regex patterns for routes that skip authentication entirely;
# AuthMiddleware tests the request path against these via is_public_path().
PUBLIC_PATHS = [
    r"^/$",
    r"^/health$",
    r"^/api/auth/login$",
    r"^/api/auth/logout$",
    # Debug entry points.
    # NOTE(review): every /debug/ route bypasses authentication here; confirm
    # these routes are disabled or separately protected in production.
    r"^/debug/.*$",
]

# Routes intended to skip token verification while still having access logged.
# NOTE(review): nothing in this module reads OPEN_PATHS -- the middleware only
# consults PUBLIC_PATHS, so these routes currently still require a token.
OPEN_PATHS = [
    r"^/api/auth/change-password$",
]
def is_public_path(path: str, patterns=None) -> bool:
    """Return True if *path* matches one of the unauthenticated route patterns.

    Args:
        path: The request path taken from the ASGI scope.
        patterns: Optional iterable of regex strings to test against. When
            omitted (None) the module-level ``PUBLIC_PATHS`` list is used.

    Returns:
        True if any pattern matches the *entire* path, otherwise False.
    """
    if patterns is None:
        patterns = PUBLIC_PATHS
    # fullmatch instead of match: with re.match a trailing "$" also matches
    # just before a final "\n", so e.g. "/health\n" would be treated as public.
    return any(re.fullmatch(pattern, path) for pattern in patterns)
class AuthMiddleware:
    """JWT authentication middleware implemented as a plain ASGI callable.

    Written as raw ASGI so that it stays compatible with CORS handling:
    CORS preflight (``OPTIONS``) requests and paths accepted by
    ``is_public_path`` are passed straight through. Every other HTTP or
    WebSocket connection must carry an ``Authorization: Bearer <token>``
    header.

    A token is accepted only if the JWT handler verifies it AND it equals the
    token stored in Redis for that user. On success the token claims are
    copied into ``scope["state"]`` (so they are visible through
    ``request.state``) and the stored token's TTL is refreshed.

    Rejected HTTP requests get a JSON 401 response. Rejected WebSocket
    handshakes are closed with code 1008, because an HTTP response message is
    not valid on a websocket scope.
    """

    # Token claims copied into scope["state"] for downstream handlers.
    _STATE_CLAIMS = ("user_id", "username", "user_type", "student_id", "role")

    def __init__(self, app):
        # The wrapped (downstream) ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        # Lifespan and any other non-request scopes are not authenticated.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        # WebSocket scopes carry no "method" key, so this is "" for them.
        method = scope.get("method", "")

        # CORS preflight requests and public routes skip authentication.
        if method == "OPTIONS" or is_public_path(path):
            await self.app(scope, receive, send)
            return

        token, error = self._extract_bearer_token(scope)
        if error is not None:
            await self._deny(scope, send, error)
            return

        try:
            error = await self._authenticate(scope, token)
        except Exception as e:
            # Any failure while verifying the JWT, querying Redis or loading
            # settings is reported to the client as a generic, retryable error.
            logger.error(f"认证中间件异常: {e}", exc_info=True)
            await self._deny(scope, send, "认证服务异常,请稍后重试")
            return
        if error is not None:
            await self._deny(scope, send, error)
            return

        # Downstream exceptions are deliberately NOT caught by the try above.
        await self.app(scope, receive, send)

    @staticmethod
    def _extract_bearer_token(scope):
        """Pull the bearer token out of the scope's Authorization header.

        Returns:
            ``(token, None)`` on success, or ``(None, error_message)`` when the
            header is missing/empty or is not of the form ``Bearer <token>``.
        """
        # ASGI headers are (name, value) byte pairs with lower-cased names;
        # with duplicates, the last occurrence wins.
        headers = dict(scope.get("headers", []))
        raw = headers.get(b"authorization")
        if not raw:
            return None, "缺少认证令牌"
        # latin-1 maps every byte to a character, so a malformed header cannot
        # raise UnicodeDecodeError here (UTF-8 decoding could, causing a 500).
        parts = raw.decode("latin-1").split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None, "认证格式错误"
        return parts[1], None

    async def _authenticate(self, scope, token):
        """Validate *token*; on success publish its claims on *scope*.

        Returns:
            None if the token is accepted, otherwise the error message to send.
            Exceptions from the JWT handler, Redis or settings propagate.
        """
        payload = jwt_handler.verify_token(token)
        if not payload:
            return "令牌无效或已过期"

        # The JWT must also match the token stored in Redis for this user, so a
        # token removed or replaced there is rejected even if not yet expired.
        user_id = payload.get("user_id")
        stored_token = await RedisClient.get_user_token(user_id)
        if not stored_token or stored_token != token:
            return "令牌已失效,请重新登录"

        # Stored under scope["state"] so handlers can read it via request.state.
        state = scope.setdefault("state", {})
        for claim in self._STATE_CLAIMS:
            state[claim] = payload.get(claim)

        # Refresh the stored token's expiry on every authenticated request.
        # NOTE(review): local import, presumably to avoid an import cycle; the
        # key format must match the one RedisClient.get_user_token reads.
        from config import settings
        await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)
        return None

    async def _deny(self, scope, send, message: str):
        """Reject the connection in the way its scope type allows."""
        if scope["type"] == "websocket":
            # Closing before accept makes the ASGI server refuse the handshake;
            # 1008 is the RFC 6455 "policy violation" close code.
            await send({"type": "websocket.close", "code": 1008})
            return
        await self._send_unauthorized(send, message)

    async def _send_unauthorized(self, send, message: str):
        """Send a 401 response with the project's standard JSON envelope."""
        body = json.dumps({
            "success": False,
            "code": 401,
            "message": message,
            "data": None
        }).encode("utf-8")
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
            ],
        })
        await send({
            "type": "http.response.body",
            "body": body,
        })
async def get_current_user(request: Request) -> Dict[str, Any]:
    """Return the logged-in user's identity fields as a plain dict.

    The values are read from ``request.state``, where AuthMiddleware places
    them after a successful token check. If they were never set (e.g. on a
    public path), attribute access on ``request.state`` fails.
    """
    state = request.state
    field_names = ("user_id", "username", "user_type", "student_id", "role")
    return {name: getattr(state, name) for name in field_names}
async def get_current_user_id(request: Request) -> int:
    """Return just the logged-in user's id from ``request.state``."""
    state = request.state
    return state.user_id