136 lines
3.8 KiB
Python
136 lines
3.8 KiB
Python
# ===========================================
|
||
# 班级操行分管理系统 - 后端服务
|
||
#
|
||
# 开发者: Canglan
|
||
# 联系方式: admin@sea-studio.top
|
||
# 版权归属: Sea Network Technology Studio
|
||
# 许可证: MIT License
|
||
#
|
||
# 版权所有 © Sea Network Technology Studio
|
||
# ===========================================
|
||
|
||
import aiomysql
|
||
from typing import Optional, Dict, Any, List
|
||
from contextlib import asynccontextmanager
|
||
|
||
from config import settings
|
||
from utils.logger import get_logger
|
||
|
||
# Logger shared by every database helper in this module.
logger = get_logger(__name__)

# The one process-wide connection pool. It is None until init_db_pool()
# assigns it; get_pool() refuses to hand it out while it is None.
_pool: Optional[aiomysql.Pool] = None
|
||
|
||
|
||
async def init_db_pool() -> None:
    """Create the module-wide aiomysql connection pool.

    Connection parameters are read from ``settings`` and the resulting pool is
    stored in the module-level ``_pool`` so that ``get_pool()`` can return it.

    Raises:
        Exception: whatever ``aiomysql.create_pool`` raises; it is logged and
            re-raised unchanged.
    """
    global _pool
    try:
        _pool = await aiomysql.create_pool(
            host=settings.DB_HOST,
            port=settings.DB_PORT,
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            db=settings.DB_NAME,
            minsize=1,
            # ``maxsize`` used to be passed twice, which is a SyntaxError.
            # DB_POOL_SIZE / DB_MAX_OVERFLOW mirror SQLAlchemy's
            # pool_size / max_overflow, where the connection ceiling is their
            # sum. NOTE(review): confirm this is the intended upper bound.
            maxsize=settings.DB_POOL_SIZE + settings.DB_MAX_OVERFLOW,
            # Work is committed explicitly by this module's context managers.
            autocommit=False,
            charset='utf8mb4',
            # Rows come back as dicts instead of tuples.
            cursorclass=aiomysql.DictCursor,
        )
        logger.info("数据库连接池初始化成功")
    except Exception as e:
        logger.error(f"数据库连接池初始化失败: {e}")
        raise
|
||
|
||
|
||
async def close_db_pool() -> None:
    """Close the module-wide pool, if any, and forget it.

    After this returns, ``_pool`` is None, so ``get_pool()`` reports the pool
    as uninitialized instead of returning a closed pool. Calling it again (or
    before ``init_db_pool()``) is a no-op.
    """
    global _pool
    if _pool is None:
        return
    # Detach first so nothing can borrow from the pool while it shuts down.
    pool, _pool = _pool, None
    pool.close()
    await pool.wait_closed()
    logger.info("数据库连接池已关闭")
|
||
|
||
|
||
def get_pool() -> aiomysql.Pool:
    """Return the shared connection pool.

    Raises:
        RuntimeError: if the module-level pool is still None (not initialized).
    """
    pool = _pool
    if pool is not None:
        return pool
    raise RuntimeError("数据库连接池未初始化")
|
||
|
||
|
||
@asynccontextmanager
async def get_connection():
    """Yield a cursor on a pooled connection; commit on success.

    The connection is borrowed from ``get_pool()`` and returned when the
    ``async with`` block exits. If the block (or the commit) raises an
    ``Exception``, the pending work is rolled back before the exception
    propagates.

    Raises:
        RuntimeError: from ``get_pool()`` if the pool is not initialized.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                yield cursor
                await conn.commit()
            except Exception:
                # The pool runs with autocommit=False; without an explicit
                # rollback the connection would be released mid-transaction.
                await conn.rollback()
                raise
|
||
|
||
|
||
@asynccontextmanager
async def get_transaction():
    """Yield a cursor whose statements form one transaction.

    Everything executed through the cursor is committed when the block
    finishes normally; if the block or the commit raises an ``Exception``,
    the transaction is rolled back and the exception is re-raised.
    """
    async with get_pool().acquire() as conn, conn.cursor() as cur:
        try:
            yield cur
            await conn.commit()
        except Exception:
            await conn.rollback()
            raise
|
||
|
||
|
||
async def execute_query(sql: str, params: tuple = None) -> List[Dict[str, Any]]:
    """Run a query and return all resulting rows.

    Args:
        sql: parameterized SQL statement.
        params: values bound to the statement's placeholders, or None.

    Returns:
        A ``list`` of rows (dicts, since the pool is created with
        ``cursorclass=aiomysql.DictCursor``).
    """
    async with get_connection() as cursor:
        await cursor.execute(sql, params)
        rows = await cursor.fetchall()
    # The driver may hand back a tuple (e.g. for an empty result set); always
    # return a real list so the declared return type holds and ``== []`` works.
    return list(rows)
|
||
|
||
|
||
async def execute_one(sql: str, params: tuple = None) -> Optional[Dict[str, Any]]:
    """Run a query and return only its first row.

    Args:
        sql: parameterized SQL statement.
        params: values bound to the statement's placeholders, or None.

    Returns:
        The first row, or whatever ``fetchone`` yields when there is none.
    """
    async with get_connection() as cur:
        await cur.execute(sql, params)
        first_row = await cur.fetchone()
    return first_row
|
||
|
||
|
||
async def execute_insert(sql: str, params: tuple = None) -> int:
    """Run an INSERT and report the generated id.

    Args:
        sql: parameterized INSERT statement.
        params: values bound to the statement's placeholders, or None.

    Returns:
        The cursor's ``lastrowid`` after the statement ran.
    """
    async with get_connection() as cur:
        await cur.execute(sql, params)
        new_id = cur.lastrowid
    return new_id
|
||
|
||
|
||
async def execute_update(sql: str, params: tuple = None) -> int:
    """Run an UPDATE/DELETE-style statement.

    Args:
        sql: parameterized SQL statement.
        params: values bound to the statement's placeholders, or None.

    Returns:
        Whatever ``cursor.execute`` returns (the affected-row count).
    """
    async with get_connection() as cur:
        affected = await cur.execute(sql, params)
    return affected
|
||
|
||
|
||
async def execute_many(sql: str, params_list: list) -> int:
    """Execute one statement once per parameter set.

    Args:
        sql: parameterized SQL statement.
        params_list: sequence of parameter tuples, one per execution.

    Returns:
        The cursor's total affected-row count; 0 when ``params_list`` is empty.
    """
    # Nothing to run: don't borrow a connection, and don't leak the cursor's
    # "no statement executed" rowcount (-1 per PEP 249) to the caller.
    if not params_list:
        return 0
    async with get_connection() as cursor:
        await cursor.executemany(sql, params_list)
        return cursor.rowcount
|
||
|
||
|
||
async def call_procedure(proc_name: str, args: tuple = None) -> List[Dict[str, Any]]:
    """Call a stored procedure and collect the rows of all its result sets.

    Args:
        proc_name: name of the stored procedure.
        args: arguments for the procedure; None or empty means no arguments.

    Returns:
        Rows from every result set the procedure produced, concatenated in
        order.
    """
    async with get_connection() as cursor:
        if args:
            await cursor.callproc(proc_name, args)
        else:
            await cursor.callproc(proc_name)

        rows: List[Dict[str, Any]] = []
        # fetchall() is awaitable and returns the rows of the *current* result
        # set only (the old code iterated the un-awaited call, which never
        # produced rows). Walk every remaining set with nextset().
        while True:
            batch = await cursor.fetchall()
            if batch:
                rows.extend(batch)
            if not await cursor.nextset():
                break
        return rows