Files
ClassManager/backend/utils/database.py
2026-04-09 15:17:47 +08:00

135 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 数据库连接池
#
# 开发者: Canglan
# 联系方式: admin@sea-studio.top
# 版权归属: Sea Network Technology Studio
# 许可证: MIT License
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
import aiomysql
from typing import Optional, Dict, Any, List
from contextlib import asynccontextmanager
from config import settings
from utils.logger import get_logger
logger = get_logger(__name__)
# Module-level connection pool instance; stays None until init_db_pool() assigns it.
_pool: Optional[aiomysql.Pool] = None
async def init_db_pool() -> None:
    """Create the module-wide aiomysql pool from ``settings`` and store it in ``_pool``.

    Connections are opened with autocommit disabled, the ``utf8mb4`` charset
    and ``aiomysql.DictCursor`` as the cursor class. The pool keeps at least
    one connection and at most ``settings.DB_POOL_SIZE``.

    Raises:
        Exception: any error raised while building the pool is logged and
            then re-raised unchanged.
    """
    global _pool
    try:
        # Collect the options first so the create_pool call stays short.
        pool_options = dict(
            host=settings.DB_HOST,
            port=settings.DB_PORT,
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            db=settings.DB_NAME,
            minsize=1,
            maxsize=settings.DB_POOL_SIZE,
            autocommit=False,
            charset='utf8mb4',
            cursorclass=aiomysql.DictCursor,
        )
        _pool = await aiomysql.create_pool(**pool_options)
        logger.info("数据库连接池初始化成功")
    except Exception as exc:
        logger.error(f"数据库连接池初始化失败: {exc}")
        raise
async def close_db_pool() -> None:
    """Close the module-wide connection pool and forget it.

    Does nothing when no pool is set. Otherwise the pool is closed, the
    coroutine waits until the close has completed, and ``_pool`` is reset to
    ``None`` so that ``get_pool()`` reports "not initialised" instead of
    handing out a closed pool.
    """
    global _pool
    if _pool is None:
        return
    # Detach the global before awaiting so no caller can pick up a pool that
    # is already shutting down.
    pool, _pool = _pool, None
    pool.close()
    await pool.wait_closed()
    logger.info("数据库连接池已关闭")
def get_pool() -> aiomysql.Pool:
    """Return the module-wide connection pool.

    Raises:
        RuntimeError: if ``_pool`` is still ``None``, i.e. no pool has been
            created yet.
    """
    if _pool is not None:
        return _pool
    raise RuntimeError("数据库连接池未初始化")
@asynccontextmanager
async def get_connection():
    """Yield a cursor on a pooled connection and commit when the caller is done.

    Usage: ``async with get_connection() as cursor: ...``. The pool is created
    with autocommit disabled, so the work done through the cursor is committed
    once the ``async with`` body finishes normally. If the body (or the commit
    itself) raises, the connection is rolled back before the exception
    propagates, matching ``get_transaction``.

    Raises:
        RuntimeError: if the pool has not been initialised (from ``get_pool``).
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.cursor() as cursor:
            try:
                yield cursor
                await conn.commit()
            except Exception:
                # Do not hand a connection with a half-finished transaction
                # back to the pool.
                await conn.rollback()
                raise
@asynccontextmanager
async def get_transaction():
    """Yield a cursor whose work is committed on success and rolled back on error.

    Usage: ``async with get_transaction() as cursor: ...``. When the body
    finishes normally the connection is committed. If the body or the commit
    raises an ``Exception``, the connection is rolled back and the exception
    is re-raised.

    Raises:
        RuntimeError: if the pool has not been initialised (from ``get_pool``).
    """
    async with get_pool().acquire() as db_conn, db_conn.cursor() as tx_cursor:
        try:
            yield tx_cursor
            # Commit stays inside the try so a failed commit is rolled back too.
            await db_conn.commit()
        except Exception:
            await db_conn.rollback()
            raise
async def execute_query(sql: str, params: tuple = None) -> List[Dict[str, Any]]:
    """Execute a SQL statement and return every row it produced.

    Args:
        sql: statement text, passed to the cursor unchanged.
        params: optional parameters handed to ``cursor.execute`` with ``sql``.

    Returns:
        Whatever ``cursor.fetchall()`` yields for the statement.
    """
    async with get_connection() as cur:
        await cur.execute(sql, params)
        rows = await cur.fetchall()
    return rows
async def execute_one(sql: str, params: tuple = None) -> Optional[Dict[str, Any]]:
    """Execute a SQL statement and return only its first row.

    Args:
        sql: statement text, passed to the cursor unchanged.
        params: optional parameters handed to ``cursor.execute`` with ``sql``.

    Returns:
        Whatever ``cursor.fetchone()`` yields for the statement.
    """
    async with get_connection() as cur:
        await cur.execute(sql, params)
        first_row = await cur.fetchone()
    return first_row
async def execute_insert(sql: str, params: tuple = None) -> int:
    """Execute an INSERT statement and return the generated row id.

    Args:
        sql: statement text, passed to the cursor unchanged.
        params: optional parameters handed to ``cursor.execute`` with ``sql``.

    Returns:
        The cursor's ``lastrowid`` as read right after the statement ran.
    """
    async with get_connection() as cur:
        await cur.execute(sql, params)
        new_id = cur.lastrowid
    return new_id
async def execute_update(sql: str, params: tuple = None) -> int:
    """Execute an UPDATE-style statement and return the affected-row count.

    Args:
        sql: statement text, passed to the cursor unchanged.
        params: optional parameters handed to ``cursor.execute`` with ``sql``.

    Returns:
        The value returned by ``cursor.execute`` for the statement.
    """
    async with get_connection() as cur:
        return await cur.execute(sql, params)
async def execute_many(sql: str, params_list: list) -> int:
    """Execute one SQL statement once per parameter set.

    Args:
        sql: statement text, passed to the cursor unchanged.
        params_list: sequence of parameter sets handed to
            ``cursor.executemany``.

    Returns:
        The cursor's ``rowcount`` after the batch, or ``0`` when
        ``params_list`` is empty.
    """
    if not params_list:
        # Nothing to run: skip the connection round-trip, and report zero rows
        # rather than the cursor's unset rowcount (aiomysql starts it at -1).
        return 0
    async with get_connection() as cursor:
        await cursor.executemany(sql, params_list)
        return cursor.rowcount
async def call_procedure(proc_name: str, args: tuple = None) -> List[Dict[str, Any]]:
    """Call a stored procedure and return the rows of all its result sets.

    Args:
        proc_name: name of the stored procedure.
        args: optional positional arguments for the procedure; ``None`` or an
            empty tuple calls it without arguments.

    Returns:
        A single flat list holding the rows of every result set the procedure
        produced, in the order they were returned.
    """
    async with get_connection() as cursor:
        if args:
            await cursor.callproc(proc_name, args)
        else:
            await cursor.callproc(proc_name)

        rows: List[Dict[str, Any]] = []
        while True:
            # fetchall() hands back an awaitable; it has to be awaited to get
            # the rows of the current result set.
            current_set = await cursor.fetchall()
            if current_set:
                rows.extend(current_set)
            # nextset() is falsy once there are no more result sets.
            if not await cursor.nextset():
                break
        return rows