# ===========================================
# Class Conduct Score Management System - Backend Service
#
# Developer: Canglan
# Contact: admin@sea-studio.top
# Copyright holder: Sea Network Technology Studio
# License: MIT License
#
# Copyright © Sea Network Technology Studio
# ===========================================

"""Async Redis connection management and small cache helpers.

A single process-wide ``redis.asyncio.Redis`` client is created by
:func:`init_redis_pool`, handed out by :func:`get_redis`, and released by
:func:`close_redis_pool`. :class:`RedisClient` wraps common cache operations
(plain/JSON values, per-user tokens, login-failure counters) on top of it.
"""

import contextlib
import json
from typing import Any, Optional

import redis.asyncio as redis

from config import settings
from utils.logger import get_logger

logger = get_logger(__name__)

# Lifetime (seconds) of a login-failure counter after the most recent failed
# attempt. Every new failure restarts this window.
LOGIN_ATTEMPTS_TTL_SECONDS = 300

# Process-wide Redis client; None until init_redis_pool() succeeds and again
# after close_redis_pool().
_redis_client: Optional[redis.Redis] = None


async def _close_client(client: redis.Redis) -> None:
    """Close *client*, preferring ``aclose()`` (redis-py >= 5.0.1) over ``close()``."""
    aclose = getattr(client, "aclose", None)
    if aclose is not None:
        await aclose()
    else:
        # Older redis-py releases only provide the (now deprecated) close().
        await client.close()


async def init_redis_pool() -> None:
    """Create the shared Redis client and verify connectivity with PING.

    The global client is published only after PING succeeds. If
    initialisation fails, ``get_redis()`` keeps raising ``RuntimeError``
    instead of returning a client that could not connect, and the
    half-created client is closed.

    Raises:
        Exception: whatever ``redis.from_url`` or ``ping`` raised; it is
            logged and re-raised unchanged.
    """
    global _redis_client
    client: Optional[redis.Redis] = None
    try:
        client = redis.from_url(
            settings.REDIS_URL,
            max_connections=settings.REDIS_MAX_CONNECTIONS,
            decode_responses=True,
        )
        # Test the connection before exposing the client to the rest of the app.
        await client.ping()
    except Exception as e:
        logger.error(f"Redis连接池初始化失败: {e}")
        if client is not None:
            # Best-effort cleanup; never let it mask the original error.
            with contextlib.suppress(Exception):
                await _close_client(client)
        raise
    _redis_client = client
    logger.info("Redis连接池初始化成功")


async def close_redis_pool() -> None:
    """Close the shared Redis client, if any, and forget it.

    Resetting the global makes later ``get_redis()`` calls raise
    ``RuntimeError`` rather than return a closed client. Calling this when
    no client exists is a no-op.
    """
    global _redis_client
    client, _redis_client = _redis_client, None
    if client is not None:
        await _close_client(client)
        logger.info("Redis连接池已关闭")


def get_redis() -> redis.Redis:
    """Return the shared Redis client.

    Raises:
        RuntimeError: if ``init_redis_pool()`` has not completed successfully
            (or the client has since been closed).
    """
    if _redis_client is None:
        raise RuntimeError("Redis客户端未初始化")
    return _redis_client


class RedisClient:
    """Static helpers for cache operations on the shared Redis client."""

    @staticmethod
    async def set(key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Store *value* under *key*, optionally with a TTL.

        Dicts and lists are JSON-encoded with non-ASCII characters kept
        as-is; every other value is stored as ``str(value)``.

        Args:
            key: Redis key.
            value: Value to store.
            expire: TTL in seconds. A falsy value (``None`` or ``0``) stores
                the key without expiry.

        Returns:
            True when Redis acknowledged the write.
        """
        client = get_redis()
        if isinstance(value, (dict, list)):
            value = json.dumps(value, ensure_ascii=False)
        else:
            value = str(value)
        if expire:
            return bool(await client.setex(key, expire, value))
        return bool(await client.set(key, value))

    @staticmethod
    async def get(key: str) -> Optional[str]:
        """Return the string stored at *key*, or None if the key is missing."""
        client = get_redis()
        return await client.get(key)

    @staticmethod
    async def get_json(key: str) -> Optional[Any]:
        """Return the JSON-decoded value at *key*.

        Values that are not valid JSON are returned as the raw string.
        Missing keys and empty strings yield None.
        """
        value = await RedisClient.get(key)
        if value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value
        return None

    @staticmethod
    async def delete(key: str) -> int:
        """Delete *key* and return the number of keys removed (0 or 1)."""
        client = get_redis()
        return await client.delete(key)

    @staticmethod
    async def exists(key: str) -> bool:
        """Return True if *key* exists."""
        client = get_redis()
        return await client.exists(key) > 0

    @staticmethod
    async def expire(key: str, seconds: int) -> bool:
        """Set a TTL of *seconds* on *key*; returns False if the key is missing."""
        client = get_redis()
        return await client.expire(key, seconds)

    @staticmethod
    async def set_user_token(user_id: int, token: str, expire: Optional[int] = None) -> bool:
        """Cache *token* for *user_id*.

        The TTL defaults to ``settings.JWT_EXPIRE_MINUTES`` minutes when
        *expire* is falsy.
        """
        key = f"user_token:{user_id}"
        expire = expire or settings.JWT_EXPIRE_MINUTES * 60
        return await RedisClient.set(key, token, expire)

    @staticmethod
    async def get_user_token(user_id: int) -> Optional[str]:
        """Return the cached token for *user_id*, or None."""
        key = f"user_token:{user_id}"
        return await RedisClient.get(key)

    @staticmethod
    async def delete_user_token(user_id: int) -> int:
        """Remove the cached token for *user_id*; returns keys removed."""
        key = f"user_token:{user_id}"
        return await RedisClient.delete(key)

    @staticmethod
    async def set_login_attempts(username: str) -> int:
        """Record one failed login for *username* and return the new count.

        INCR and EXPIRE run in a single MULTI/EXEC transaction, so
        concurrent failures are all counted. The former GET-then-SET could
        lose increments under concurrency. The counter also never lingers
        without a TTL. Each failure restarts the
        ``LOGIN_ATTEMPTS_TTL_SECONDS`` window (5 minutes).
        Presumably callers lock the account based on the returned count;
        confirm at the call site.
        """
        client = get_redis()
        key = f"login_attempts:{username}"
        async with client.pipeline(transaction=True) as pipe:
            # Pipeline commands are buffered; only execute() talks to Redis.
            pipe.incr(key)
            pipe.expire(key, LOGIN_ATTEMPTS_TTL_SECONDS)
            attempts, _ = await pipe.execute()
        return int(attempts)

    @staticmethod
    async def clear_login_attempts(username: str) -> None:
        """Delete the login-failure counter for *username*."""
        key = f"login_attempts:{username}"
        await RedisClient.delete(key)