# ===========================================
# Class Conduct Score Management System - Database Connection Pool
#
# Developer: Canglan
# Contact: admin@sea-studio.top
# Copyright holder: Sea Network Technology Studio
# License: MIT License
#
# Copyright © Sea Network Technology Studio
# ===========================================

import aiomysql
from typing import Optional, Dict, Any, List
from contextlib import asynccontextmanager

from config import settings
from utils.logger import get_logger

logger = get_logger(__name__)

# Module-wide pool instance: created by init_db_pool(), cleared by close_db_pool().
_pool: Optional[aiomysql.Pool] = None


async def init_db_pool() -> None:
    """Create the module-wide aiomysql connection pool from ``settings``.

    Connections are opened with ``autocommit=False``, ``utf8mb4`` and
    ``aiomysql.DictCursor``. Every helper below therefore returns rows as
    dicts and commits explicitly.

    Raises:
        Exception: whatever ``aiomysql.create_pool`` raises; it is logged
            and re-raised unchanged.
    """
    global _pool
    try:
        _pool = await aiomysql.create_pool(
            host=settings.DB_HOST,
            port=settings.DB_PORT,
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            db=settings.DB_NAME,
            minsize=1,
            maxsize=settings.DB_POOL_SIZE,
            autocommit=False,
            charset='utf8mb4',
            cursorclass=aiomysql.DictCursor,
        )
        logger.info("数据库连接池初始化成功")
    except Exception as e:
        logger.error(f"数据库连接池初始化失败: {e}")
        raise


async def close_db_pool() -> None:
    """Close the pool, if one exists, and wait for its connections to shut down.

    The module-level reference is cleared before closing. Afterwards
    ``get_pool()`` raises instead of handing out a closed pool, and a
    repeated call is a no-op.
    """
    global _pool
    if _pool is None:
        return
    pool, _pool = _pool, None
    pool.close()
    await pool.wait_closed()
    logger.info("数据库连接池已关闭")


def get_pool() -> aiomysql.Pool:
    """Return the initialised pool.

    Raises:
        RuntimeError: if ``init_db_pool()`` has not been called, or the pool
            has since been closed.
    """
    if _pool is None:
        raise RuntimeError("数据库连接池未初始化")
    return _pool


@asynccontextmanager
async def _cursor_scope():
    """Yield a cursor on a pooled connection.

    Commits on normal exit and rolls back if the body (or the commit) raises.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                yield cursor
                await conn.commit()
            except Exception:
                # autocommit is off, so a failure would otherwise leave an
                # open transaction on a connection that returns to the pool.
                await conn.rollback()
                raise


@asynccontextmanager
async def get_connection():
    """Yield a database cursor (async context manager).

    The work is committed when the block exits normally and rolled back if
    it raises.
    """
    async with _cursor_scope() as cursor:
        yield cursor


@asynccontextmanager
async def get_transaction():
    """Yield a cursor whose statements form one transaction.

    All statements in the block are committed together on success, or
    rolled back if anything raises.
    """
    async with _cursor_scope() as cursor:
        yield cursor


async def execute_query(sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
    """Run a query and return every row as a dict."""
    async with get_connection() as cursor:
        await cursor.execute(sql, params)
        return await cursor.fetchall()


async def execute_one(sql: str, params: Optional[tuple] = None) -> Optional[Dict[str, Any]]:
    """Run a query and return the first row as a dict, or ``None`` if no row matched."""
    async with get_connection() as cursor:
        await cursor.execute(sql, params)
        return await cursor.fetchone()


async def execute_insert(sql: str, params: Optional[tuple] = None) -> int:
    """Run an INSERT and return the auto-increment id it generated."""
    async with get_connection() as cursor:
        await cursor.execute(sql, params)
        return cursor.lastrowid


async def execute_update(sql: str, params: Optional[tuple] = None) -> int:
    """Run an UPDATE/DELETE (or other DML) and return the affected row count."""
    async with get_connection() as cursor:
        return await cursor.execute(sql, params)


async def execute_many(sql: str, params_list: list) -> int:
    """Run ``sql`` once per parameter set in ``params_list``; return rows affected.

    An empty ``params_list`` returns 0 without touching the database.
    Otherwise aiomysql would skip execution and leave ``rowcount`` at -1.
    """
    if not params_list:
        return 0
    async with get_connection() as cursor:
        await cursor.executemany(sql, params_list)
        return cursor.rowcount


async def call_procedure(proc_name: str, args: Optional[tuple] = None) -> List[Dict[str, Any]]:
    """Call a stored procedure and return the rows of all its result sets, concatenated.

    Args:
        proc_name: name of the stored procedure.
        args: positional IN arguments; ``None`` or empty calls it without arguments.

    Returns:
        Every row (as a dict) from every non-empty result set, in order.
    """
    async with get_connection() as cursor:
        if args:
            await cursor.callproc(proc_name, args)
        else:
            await cursor.callproc(proc_name)

        result: List[Dict[str, Any]] = []
        # A procedure may produce several result sets. fetchall() must be
        # awaited for each one, and nextset() advances to the next (falsy
        # when none remain). Empty sets, such as the CALL status result,
        # contribute nothing.
        while True:
            rows = await cursor.fetchall()
            if rows:
                result.extend(rows)
            if not await cursor.nextset():
                break
        return result