# ===========================================
# 班级操行分管理系统 - 后端服务
#
# 开发者: Canglan
# 联系方式: admin@sea-studio.top
# 版权归属: Sea Network Technology Studio
# 许可证: MIT License
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
|
||
|
||
from fastapi import Request
|
||
from typing import Dict, Any
|
||
import re
|
||
import json
|
||
|
||
from utils.jwt_handler import jwt_handler
|
||
from utils.redis_client import RedisClient
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger; AuthMiddleware uses it to record unexpected errors
# raised while validating a token.
logger = get_logger(__name__)
|
||
|
||
# Routes that bypass authentication entirely (each pattern is tested with
# re.match by is_public_path).
# SECURITY: debug routes are deliberately NOT listed here. Whitelisting
# r'^/debug/.*$' let every /debug/... endpoint be called without a token.
# If debug endpoints are needed, protect them with auth or register them
# only in development builds.
PUBLIC_PATHS = [
    r'^/$',
    r'^/health$',
    r'^/api/auth/login$',
    r'^/api/auth/logout$',
]

# Routes intended to skip token validation but still have their access logged.
# NOTE(review): AuthMiddleware in this module never consults OPEN_PATHS, so
# /api/auth/change-password currently still requires a valid token. Confirm
# whether another module reads this list.
OPEN_PATHS = [
    r'^/api/auth/change-password$',
]
|
||
|
||
|
||
def is_public_path(path: str) -> bool:
    """Report whether *path* is exempt from authentication.

    Each entry of PUBLIC_PATHS is applied with ``re.match`` (anchored at the
    start of *path*); the first hit short-circuits to True.
    """
    return any(re.match(regex, path) for regex in PUBLIC_PATHS)
|
||
|
||
|
||
class AuthMiddleware:
    """JWT authentication middleware (pure ASGI implementation, CORS-friendly).

    ``http`` and ``websocket`` connections pass straight through when they are
    CORS preflight requests (``OPTIONS``) or match ``is_public_path``. Every
    other connection must carry ``Authorization: Bearer <token>`` where:

    1. ``jwt_handler.verify_token(token)`` returns a truthy payload, and
    2. the token equals the one returned by ``RedisClient.get_user_token``
       for the payload's ``user_id``.

    On success the payload claims are copied into ``scope["state"]`` (which
    backs ``request.state``), the TTL of the Redis key ``user_token:<user_id>``
    is refreshed, and the wrapped app is called. On failure an HTTP request
    gets a JSON 401 response and a WebSocket handshake is closed with code
    1008 (policy violation).
    """

    # JWT claims copied (in this order) into scope["state"] for handlers.
    _STATE_CLAIMS = ("user_id", "username", "user_type", "student_id", "role")

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        """ASGI entry point: authenticate the connection, then delegate to the app."""
        if scope["type"] not in ("http", "websocket"):
            # e.g. "lifespan": there is nothing to authenticate.
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        method = scope.get("method", "")

        # CORS preflight requests carry no credentials, so never challenge them.
        if method == "OPTIONS" or is_public_path(path):
            await self.app(scope, receive, send)
            return

        auth_header = self._get_authorization_header(scope)
        if not auth_header:
            await self._send_unauthorized(send, "缺少认证令牌", scope)
            return

        # Expect exactly "<scheme> <token>" with a case-insensitive Bearer scheme.
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            await self._send_unauthorized(send, "认证格式错误", scope)
            return
        token = parts[1]

        try:
            error = await self._authenticate(scope, token)
        except Exception as e:
            logger.error(f"认证中间件异常: {e}", exc_info=True)
            error = "认证服务异常,请稍后重试"

        # Responses are sent outside the try block so that a failing `send`
        # is not caught and followed by a second response on the same
        # connection.
        if error is not None:
            await self._send_unauthorized(send, error, scope)
            return

        # Also outside the try: errors raised by the application itself must
        # not be misreported as authentication failures.
        await self.app(scope, receive, send)

    @staticmethod
    def _get_authorization_header(scope):
        """Return the Authorization header value as str, or None if absent.

        ASGI header names are lower-cased bytes. The value is decoded as
        latin-1, which maps every byte and never raises, so a malformed header
        leads to a 401 rather than an unhandled UnicodeDecodeError. As with a
        dict built from the header list, the last duplicate header wins.
        """
        headers = dict(scope.get("headers", []))
        raw = headers.get(b"authorization")
        return None if raw is None else raw.decode("latin-1")

    async def _authenticate(self, scope, token: str):
        """Validate *token* and store its claims in ``scope["state"]``.

        Returns None on success, otherwise the user-facing rejection message.
        Exceptions from the JWT handler, Redis, or config propagate to the
        caller.
        """
        payload = jwt_handler.verify_token(token)
        if not payload:
            return "令牌无效或已过期"

        # The presented token must match the one stored in Redis for this user;
        # presumably this lets the server revoke tokens (e.g. on logout).
        user_id = payload.get("user_id")
        stored_token = await RedisClient.get_user_token(user_id)
        if not stored_token or stored_token != token:
            return "令牌已失效,请重新登录"

        # Starlette's request.state is backed by scope["state"].
        state = scope.setdefault("state", {})
        for claim in self._STATE_CLAIMS:
            state[claim] = payload.get(claim)

        # Sliding expiry: every authenticated request extends the stored
        # token's TTL (JWT_EXPIRE_MINUTES * 60, presumably seconds).
        # Imported lazily, presumably to avoid an import cycle (TODO confirm).
        from config import settings
        await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)
        return None

    async def _send_unauthorized(self, send, message: str, scope=None):
        """Reject the connection with *message*.

        HTTP (the default): send a 401 response whose JSON body is
        ``{"success": False, "code": 401, "message": message, "data": None}``.

        WebSocket (when *scope* is a websocket scope): ``http.response.*``
        messages are not valid on a websocket connection under ASGI, so the
        handshake is closed with code 1008 (policy violation) instead.
        """
        if scope is not None and scope.get("type") == "websocket":
            await send({"type": "websocket.close", "code": 1008})
            return

        body = json.dumps({
            "success": False,
            "code": 401,
            "message": message,
            "data": None
        }).encode("utf-8")

        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
                # RFC 6750 section 3: a 401 from a bearer-protected resource
                # should name the expected scheme.
                [b"www-authenticate", b"Bearer"],
            ],
        })
        await send({
            "type": "http.response.body",
            "body": body,
        })
|
||
|
||
|
||
async def get_current_user(request: Request) -> Dict[str, Any]:
    """Return the current user's identity claims as a plain dict.

    The values are read from ``request.state``; in this module AuthMiddleware
    stores user_id, username, user_type, student_id and role there after a
    successful token check. Accessing a missing attribute presumably raises
    AttributeError (e.g. on a route that skipped authentication).
    """
    state = request.state
    claim_names = ("user_id", "username", "user_type", "student_id", "role")
    return {name: getattr(state, name) for name in claim_names}
|
||
|
||
|
||
async def get_current_user_id(request: Request) -> int:
    """Return the ``user_id`` stored on ``request.state`` for this request."""
    state = request.state
    current_id = state.user_id
    return current_id