v0.1测试
This commit is contained in:
11
backend/middleware/__init__.py
Normal file
11
backend/middleware/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# ===========================================
|
||||
# 班级操行分管理系统 - 后端服务
|
||||
#
|
||||
# 开发者: Canglan
|
||||
# 联系方式: admin@sea-studio.top
|
||||
# 版权归属: Sea Network Technology Studio
|
||||
# 许可证: MIT License
|
||||
#
|
||||
# 版权所有 © Sea Network Technology Studio
|
||||
# ===========================================
|
||||
|
||||
111
backend/middleware/auth_middleware.py
Normal file
111
backend/middleware/auth_middleware.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# ===========================================
|
||||
# 班级操行分管理系统 - 后端服务
|
||||
#
|
||||
# 开发者: Canglan
|
||||
# 联系方式: admin@sea-studio.top
|
||||
# 版权归属: Sea Network Technology Studio
|
||||
# 许可证: MIT License
|
||||
#
|
||||
# 版权所有 © Sea Network Technology Studio
|
||||
# ===========================================
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import re
|
||||
|
||||
from utils.jwt_handler import jwt_handler
|
||||
from utils.redis_client import RedisClient
|
||||
from utils.response import unauthorized_response
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 不需要认证的路由
|
||||
PUBLIC_PATHS = [
|
||||
r'^/$',
|
||||
r'^/health$',
|
||||
r'^/api/auth/login$',
|
||||
r'^/api/auth/logout$',
|
||||
r'^/debug/.*$', # 调试入口
|
||||
]
|
||||
|
||||
# 不需要Token验证但需要记录访问的路由
|
||||
OPEN_PATHS = [
|
||||
r'^/api/auth/change-password$',
|
||||
]
|
||||
|
||||
|
||||
def is_public_path(path: str) -> bool:
|
||||
"""检查是否为公开路径"""
|
||||
for pattern in PUBLIC_PATHS:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT认证中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
path = request.url.path
|
||||
|
||||
# 公开路径跳过认证
|
||||
if is_public_path(path):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取Authorization头
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
if not auth_header:
|
||||
return unauthorized_response("缺少认证令牌")
|
||||
|
||||
# 解析Bearer Token
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
return unauthorized_response("认证格式错误")
|
||||
except ValueError:
|
||||
return unauthorized_response("认证格式错误")
|
||||
|
||||
# 验证Token
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
return unauthorized_response("令牌无效或已过期")
|
||||
|
||||
# 验证Redis中的Token
|
||||
user_id = payload.get("user_id")
|
||||
stored_token = await RedisClient.get_user_token(user_id)
|
||||
|
||||
if not stored_token or stored_token != token:
|
||||
return unauthorized_response("令牌已失效,请重新登录")
|
||||
|
||||
# 将用户信息存储到request.state
|
||||
request.state.user_id = payload.get("user_id")
|
||||
request.state.username = payload.get("username")
|
||||
request.state.user_type = payload.get("user_type")
|
||||
request.state.student_id = payload.get("student_id")
|
||||
request.state.role = payload.get("role")
|
||||
|
||||
# 刷新Token过期时间
|
||||
from config import settings
|
||||
await RedisClient.expire(f"user_token:{user_id}", settings.JWT_EXPIRE_MINUTES * 60)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
"""获取当前登录用户信息"""
|
||||
return {
|
||||
"user_id": request.state.user_id,
|
||||
"username": request.state.username,
|
||||
"user_type": request.state.user_type,
|
||||
"student_id": request.state.student_id,
|
||||
"role": request.state.role
|
||||
}
|
||||
|
||||
|
||||
async def get_current_user_id(request: Request) -> int:
|
||||
"""获取当前用户ID"""
|
||||
return request.state.user_id
|
||||
197
backend/middleware/permission.py
Normal file
197
backend/middleware/permission.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# ===========================================
|
||||
# 班级操行分管理系统 - 后端服务
|
||||
#
|
||||
# 开发者: Canglan
|
||||
# 联系方式: admin@sea-studio.top
|
||||
# 版权归属: Sea Network Technology Studio
|
||||
# 许可证: MIT License
|
||||
#
|
||||
# 版权所有 © Sea Network Technology Studio
|
||||
# ===========================================
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from typing import List, Optional, Callable
|
||||
from functools import wraps
|
||||
|
||||
from utils.response import forbidden_response
|
||||
from utils.database import execute_one
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PermissionChecker:
|
||||
"""权限检查器"""
|
||||
|
||||
@staticmethod
|
||||
async def get_user_role(user_id: int) -> Optional[str]:
|
||||
"""获取用户的管理员角色"""
|
||||
sql = """
|
||||
SELECT role_type FROM admin_roles
|
||||
WHERE user_id = %s
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await execute_one(sql, (user_id,))
|
||||
return result["role_type"] if result else None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_class_id(user_id: int) -> Optional[int]:
|
||||
"""获取用户管理的班级ID"""
|
||||
sql = """
|
||||
SELECT class_id FROM admin_roles
|
||||
WHERE user_id = %s
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await execute_one(sql, (user_id,))
|
||||
return result["class_id"] if result else None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_subject_ids(user_id: int) -> List[int]:
|
||||
"""获取科代表管理的科目ID列表"""
|
||||
sql = """
|
||||
SELECT subject_id FROM admin_roles
|
||||
WHERE user_id = %s AND role_type = '科代表'
|
||||
"""
|
||||
results = await execute_one(sql, (user_id,))
|
||||
if results:
|
||||
return [r["subject_id"] for r in results] if isinstance(results, list) else [results["subject_id"]]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def check_is_teacher(user_id: int) -> bool:
|
||||
"""检查是否为班主任"""
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
return role == "班主任"
|
||||
|
||||
@staticmethod
|
||||
async def check_is_monitor(user_id: int) -> bool:
|
||||
"""检查是否为班长"""
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
return role == "班长"
|
||||
|
||||
@staticmethod
|
||||
async def check_is_subject_rep(user_id: int, subject_id: int = None) -> bool:
|
||||
"""检查是否为科代表"""
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
if role != "科代表":
|
||||
return False
|
||||
if subject_id:
|
||||
subject_ids = await PermissionChecker.get_user_subject_ids(user_id)
|
||||
return subject_id in subject_ids
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def check_is_attendance_rep(user_id: int) -> bool:
|
||||
"""检查是否为考勤委员"""
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
return role == "考勤委员"
|
||||
|
||||
@staticmethod
|
||||
async def check_is_labor_rep(user_id: int) -> bool:
|
||||
"""检查是否为劳动委员"""
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
return role == "劳动委员"
|
||||
|
||||
@staticmethod
|
||||
async def check_can_revoke(user_id: int, record_id: int) -> bool:
|
||||
"""
|
||||
检查是否可以撤销扣分记录
|
||||
班主任:可以撤销任何记录
|
||||
班长:可以撤销任何记录
|
||||
其他:只能撤销自己的记录
|
||||
"""
|
||||
# 获取记录信息
|
||||
sql = "SELECT recorder_id FROM conduct_records WHERE record_id = %s"
|
||||
record = await execute_one(sql, (record_id,))
|
||||
if not record:
|
||||
return False
|
||||
|
||||
role = await PermissionChecker.get_user_role(user_id)
|
||||
|
||||
# 班主任或班长可以撤销任何记录
|
||||
if role in ["班主任", "班长"]:
|
||||
return True
|
||||
|
||||
# 其他人只能撤销自己的记录
|
||||
return record["recorder_id"] == user_id
|
||||
|
||||
@staticmethod
|
||||
async def check_can_manage_student(user_id: int, student_id: int) -> bool:
|
||||
"""检查是否可以管理该学生(同班级)"""
|
||||
# 获取学生班级
|
||||
sql = "SELECT class_id FROM students WHERE student_id = %s"
|
||||
student = await execute_one(sql, (student_id,))
|
||||
if not student:
|
||||
return False
|
||||
|
||||
# 获取管理员管理的班级
|
||||
admin_class_id = await PermissionChecker.get_user_class_id(user_id)
|
||||
|
||||
return admin_class_id == student["class_id"]
|
||||
|
||||
|
||||
def require_auth(func: Callable):
|
||||
"""需要认证的装饰器"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs.get('request')
|
||||
if not request or not hasattr(request, 'state') or not hasattr(request.state, 'user_id'):
|
||||
return forbidden_response("请先登录")
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_role(roles: List[str]):
|
||||
"""需要特定角色的装饰器"""
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs.get('request')
|
||||
if not request or not hasattr(request, 'state'):
|
||||
return forbidden_response("请先登录")
|
||||
|
||||
user_id = request.state.user_id
|
||||
user_role = await PermissionChecker.get_user_role(user_id)
|
||||
|
||||
if user_role not in roles:
|
||||
return forbidden_response(f"需要{','.join(roles)}权限")
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_teacher(func: Callable):
|
||||
"""需要班主任权限的装饰器"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs.get('request')
|
||||
if not request or not hasattr(request, 'state'):
|
||||
return forbidden_response("请先登录")
|
||||
|
||||
user_id = request.state.user_id
|
||||
is_teacher = await PermissionChecker.check_is_teacher(user_id)
|
||||
|
||||
if not is_teacher:
|
||||
return forbidden_response("需要班主任权限")
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_monitor(func: Callable):
|
||||
"""需要班长权限的装饰器"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs.get('request')
|
||||
if not request or not hasattr(request, 'state'):
|
||||
return forbidden_response("请先登录")
|
||||
|
||||
user_id = request.state.user_id
|
||||
is_monitor = await PermissionChecker.check_is_monitor(user_id)
|
||||
|
||||
if not is_monitor:
|
||||
return forbidden_response("需要班长权限")
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
121
backend/middleware/sanitize.py
Normal file
121
backend/middleware/sanitize.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# ===========================================
|
||||
# 班级操行分管理系统 - 后端服务
|
||||
#
|
||||
# 开发者: Canglan
|
||||
# 联系方式: admin@sea-studio.top
|
||||
# 版权归属: Sea Network Technology Studio
|
||||
# 许可证: MIT License
|
||||
#
|
||||
# 版权所有 © Sea Network Technology Studio
|
||||
# ===========================================
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
|
||||
|
||||
class SanitizeMiddleware(BaseHTTPMiddleware):
|
||||
"""输入过滤中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# 只处理POST、PUT、PATCH请求
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
# 获取请求体
|
||||
body = await request.body()
|
||||
if body:
|
||||
import json
|
||||
try:
|
||||
data = json.loads(body)
|
||||
# 清理数据
|
||||
cleaned_data = self._sanitize_data(data)
|
||||
# 替换请求体
|
||||
request._body = json.dumps(cleaned_data).encode()
|
||||
except:
|
||||
pass
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _sanitize_data(self, data: Any) -> Any:
|
||||
"""递归清理数据"""
|
||||
if isinstance(data, dict):
|
||||
return {k: self._sanitize_data(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self._sanitize_data(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
return self._sanitize_string(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
def _sanitize_string(self, value: str) -> str:
|
||||
"""清理字符串"""
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
# 去除首尾空格
|
||||
value = value.strip()
|
||||
|
||||
# 限制长度
|
||||
if len(value) > 1000:
|
||||
value = value[:1000]
|
||||
|
||||
# 转义HTML特殊字符
|
||||
html_chars = {
|
||||
'&': '&',
|
||||
'<': '<',
|
||||
'>': '>',
|
||||
'"': '"',
|
||||
"'": ''',
|
||||
'/': '/'
|
||||
}
|
||||
for char, escape in html_chars.items():
|
||||
value = value.replace(char, escape)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def sanitize_input(value: str, max_length: int = 255) -> str:
|
||||
"""清理单个输入值"""
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
value = value.strip()
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_points(points: int, min_val: int = -100, max_val: int = 100) -> tuple:
|
||||
"""
|
||||
验证分值
|
||||
返回: (是否有效, 错误信息)
|
||||
"""
|
||||
if points == 0:
|
||||
return False, "分值不能为0"
|
||||
if points < min_val or points > max_val:
|
||||
return False, f"分值必须在{min_val}到{max_val}之间"
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_reason(reason: str) -> tuple:
|
||||
"""
|
||||
验证原因
|
||||
返回: (是否有效, 错误信息)
|
||||
"""
|
||||
if not reason or not reason.strip():
|
||||
return False, "原因不能为空"
|
||||
if len(reason) > 255:
|
||||
return False, "原因长度不能超过255个字符"
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_date(date_str: str) -> bool:
|
||||
"""验证日期格式 YYYY-MM-DD"""
|
||||
if not date_str:
|
||||
return False
|
||||
pattern = r'^\d{4}-\d{2}-\d{2}$'
|
||||
if not re.match(pattern, date_str):
|
||||
return False
|
||||
return True
|
||||
Reference in New Issue
Block a user