355 lines
13 KiB
Python
355 lines
13 KiB
Python
# ===========================================
|
||
# 班级操行分管理系统 - 升级管理路由
|
||
#
|
||
# 开发者: Canglan
|
||
# 版权归属: Sea Network Technology Studio
|
||
#
|
||
# 版权所有 © Sea Network Technology Studio
|
||
# ===========================================
|
||
|
||
from fastapi import APIRouter, Request
|
||
from utils.database import execute_query, execute_update, get_pool
|
||
from utils.response import success_response, error_response
|
||
from utils.logger import setup_logger
|
||
from middleware.permission import PermissionChecker
|
||
import os
|
||
import re
|
||
|
||
logger = setup_logger()
|
||
router = APIRouter()
|
||
|
||
# 版本列表(按顺序)
|
||
ALL_VERSIONS = {
|
||
'1.0': 'v1.0.sql',
|
||
'1.1': 'v1.1.sql',
|
||
'1.2': 'v1.2.sql',
|
||
'1.3': 'v1.3.sql',
|
||
'1.4': 'v1.4.sql',
|
||
'1.5': 'v1.5.sql',
|
||
'1.6': 'v1.6.sql',
|
||
'1.7': 'v1.7.sql',
|
||
'1.8': 'v1.8.sql',
|
||
'2.0': 'v2.0.sql',
|
||
'2.0.1': 'v2.0.1.sql',
|
||
'2.1': 'v2.1.sql',
|
||
'2.2': 'v2.2.sql',
|
||
'2.3': 'v2.3.sql',
|
||
'2.3': 'v2.3.sql',
|
||
'2.4': 'v2.4.sql',
|
||
'2.5': 'v2.5.sql',
|
||
}
|
||
|
||
# 版本特征标记(按优先级从高到低)
|
||
VERSION_MARKERS = [
|
||
('2.0', 'students', 'dormitory_number'),
|
||
('1.8', 'conduct_records', 'related_type'),
|
||
('1.7', 'subjects', 'sort_order'),
|
||
]
|
||
|
||
|
||
async def _detect_current_version() -> str:
|
||
"""检测当前数据库版本,优先从 system_settings 读取,否则通过列特征推断"""
|
||
# 1. 尝试从 system_settings 读取 db_version
|
||
try:
|
||
row = await execute_query(
|
||
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
|
||
)
|
||
if row:
|
||
return row[0]['setting_value']
|
||
except Exception as e:
|
||
logger.warning(f"查询 system_settings 表失败,将通过列特征推断版本: {e}")
|
||
|
||
# 2. 通过列特征推断版本
|
||
inferred_version = '1.0'
|
||
for version, table, column in VERSION_MARKERS:
|
||
try:
|
||
result = await execute_query(
|
||
"SELECT COUNT(*) as cnt FROM INFORMATION_SCHEMA.COLUMNS "
|
||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND COLUMN_NAME = %s",
|
||
(table, column)
|
||
)
|
||
if result and result[0]['cnt'] > 0:
|
||
inferred_version = version
|
||
break
|
||
except Exception as e:
|
||
logger.warning(f"检查列特征失败 ({table}.{column}): {e}")
|
||
|
||
logger.info(f"通过列特征推断数据库版本为: {inferred_version}")
|
||
|
||
# 3. 确保 system_settings 表存在并写入推断版本
|
||
try:
|
||
await execute_update(
|
||
"CREATE TABLE IF NOT EXISTS `system_settings` ("
|
||
"`setting_key` VARCHAR(50) PRIMARY KEY,"
|
||
"`setting_value` VARCHAR(255) NOT NULL,"
|
||
"`updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
|
||
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci"
|
||
)
|
||
await execute_update(
|
||
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
|
||
"ON DUPLICATE KEY UPDATE setting_value = %s",
|
||
(inferred_version, inferred_version)
|
||
)
|
||
logger.info(f"已将推断版本 {inferred_version} 写入 system_settings")
|
||
except Exception as e:
|
||
logger.error(f"写入推断版本失败: {e}")
|
||
|
||
return inferred_version
|
||
|
||
|
||
@router.get("/check")
|
||
async def check_upgrade(request: Request):
|
||
"""检查数据库版本是否需要升级"""
|
||
# 权限检查:仅班主任可执行升级操作
|
||
user_type = getattr(request.state, 'user_type', None)
|
||
if user_type != 'admin':
|
||
return error_response(message="仅管理员可执行升级操作", code=403)
|
||
|
||
is_teacher = await PermissionChecker.check_is_teacher(
|
||
getattr(request.state, 'user_id', 0)
|
||
)
|
||
if not is_teacher:
|
||
return error_response(message="仅班主任可执行升级操作", code=403)
|
||
|
||
# 检测当前数据库版本(支持自动推断)
|
||
current_version = await _detect_current_version()
|
||
|
||
# 读取目标版本(从 VERSION 文件)
|
||
version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION')
|
||
version_file = os.path.normpath(version_file)
|
||
target_version = '0.0.0'
|
||
try:
|
||
if os.path.exists(version_file):
|
||
with open(version_file, 'r') as f:
|
||
target_version = f.read().strip()
|
||
except Exception:
|
||
pass
|
||
|
||
# 计算需要升级的步骤
|
||
needs_upgrade = _compare_versions(target_version, current_version) > 0
|
||
|
||
steps = []
|
||
for version, file_name in sorted(ALL_VERSIONS.items(), key=lambda x: _version_tuple(x[0])):
|
||
if _compare_versions(version, current_version) > 0 and _compare_versions(version, target_version) <= 0:
|
||
steps.append({'version': version, 'file': file_name})
|
||
|
||
return success_response(data={
|
||
'needs_upgrade': needs_upgrade,
|
||
'current': current_version,
|
||
'target': target_version,
|
||
'steps': steps
|
||
})
|
||
|
||
|
||
async def _verify_upgrade(expected_version: str) -> dict:
|
||
"""验证升级结果:检查版本号是否已正确更新
|
||
|
||
Returns:
|
||
{'ok': bool, 'message': str}
|
||
"""
|
||
try:
|
||
row = await execute_query(
|
||
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
|
||
)
|
||
if not row:
|
||
return {'ok': False, 'message': 'db_version 记录不存在'}
|
||
actual = row[0]['setting_value']
|
||
if actual != expected_version:
|
||
return {'ok': False, 'message': f'版本号不匹配:期望 {expected_version},实际 {actual}'}
|
||
return {'ok': True, 'message': '验证通过'}
|
||
except Exception as e:
|
||
return {'ok': False, 'message': f'验证查询失败: {str(e)}'}
|
||
|
||
|
||
MAX_RETRIES = 2
|
||
|
||
|
||
@router.post("/step")
|
||
async def execute_upgrade_step(request: Request):
|
||
"""执行单个升级步骤(含验证与重试)"""
|
||
# 权限检查:仅班主任可执行升级操作
|
||
user_type = getattr(request.state, 'user_type', None)
|
||
if user_type != 'admin':
|
||
return error_response(message="仅管理员可执行升级操作", code=403)
|
||
|
||
is_teacher = await PermissionChecker.check_is_teacher(
|
||
getattr(request.state, 'user_id', 0)
|
||
)
|
||
if not is_teacher:
|
||
return error_response(message="仅班主任可执行升级操作", code=403)
|
||
|
||
body = await request.json()
|
||
version = body.get('version', '')
|
||
|
||
if not version:
|
||
return error_response(message='缺少版本号参数', code=400)
|
||
|
||
if version not in ALL_VERSIONS:
|
||
return error_response(message=f'未知版本: {version}', code=400)
|
||
|
||
# SQL 文件路径
|
||
sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades')
|
||
sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version]))
|
||
|
||
if not os.path.exists(sql_file):
|
||
return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)
|
||
|
||
last_error = None
|
||
|
||
for attempt in range(1, MAX_RETRIES + 1):
|
||
try:
|
||
# 读取并执行 SQL
|
||
with open(sql_file, 'r', encoding='utf-8') as f:
|
||
sql_content = f.read().strip()
|
||
|
||
if sql_content and sql_content != '--':
|
||
pool = get_pool()
|
||
async with pool.acquire() as conn:
|
||
async with conn.cursor() as cursor:
|
||
await _execute_sql_content(cursor, sql_content)
|
||
await conn.commit()
|
||
|
||
# 更新版本号
|
||
await execute_update(
|
||
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
|
||
"ON DUPLICATE KEY UPDATE setting_value = %s",
|
||
(version, version)
|
||
)
|
||
|
||
# 验证版本号是否正确写入
|
||
verify = await _verify_upgrade(version)
|
||
if verify['ok']:
|
||
new_version = version
|
||
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
|
||
return success_response(data={
|
||
'success': True,
|
||
'version': version,
|
||
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
|
||
'current': new_version
|
||
})
|
||
|
||
# 验证失败,准备重试
|
||
last_error = f"升级验证失败: {verify['message']}"
|
||
if attempt < MAX_RETRIES:
|
||
logger.warning(f"v{version} 升级验证失败,准备第 {attempt + 1} 次重试: {last_error}")
|
||
continue
|
||
|
||
except Exception as e:
|
||
last_error = str(e)
|
||
logger.warning(f"v{version} 升级第 {attempt} 次失败: {last_error}")
|
||
if attempt < MAX_RETRIES:
|
||
continue
|
||
|
||
# 所有重试均失败
|
||
logger.error(f"数据库升级失败: v{version} (尝试 {MAX_RETRIES} 次) - {last_error}")
|
||
return error_response(
|
||
message=f"升级至 v{version} 失败 (尝试 {MAX_RETRIES} 次): {last_error}",
|
||
code=500
|
||
)
|
||
|
||
|
||
def _compare_versions(v1: str, v2: str) -> int:
|
||
"""比较两个版本号,返回 1/0/-1"""
|
||
t1 = _version_tuple(v1)
|
||
t2 = _version_tuple(v2)
|
||
if t1 > t2:
|
||
return 1
|
||
elif t1 < t2:
|
||
return -1
|
||
return 0
|
||
|
||
|
||
def _version_tuple(v: str) -> tuple:
|
||
"""将版本字符串转为可比较的元组"""
|
||
parts = []
|
||
for p in v.split('.'):
|
||
try:
|
||
parts.append(int(p))
|
||
except ValueError:
|
||
parts.append(0)
|
||
return tuple(parts)
|
||
|
||
|
||
async def _execute_sql_content(cursor, sql_content: str):
|
||
"""执行 SQL 内容,处理存储过程中的 DELIMITER"""
|
||
sql_content = sql_content.strip()
|
||
if not sql_content or sql_content == '--':
|
||
return # 空文件或纯注释,无需执行
|
||
|
||
# 如果包含 DELIMITER,需要特殊处理
|
||
if 'DELIMITER' in sql_content.upper():
|
||
lines = sql_content.split('\n')
|
||
current_block = []
|
||
in_procedure = False
|
||
buffer = '' # 使用局部变量而非函数属性,避免跨调用泄漏
|
||
|
||
for line in lines:
|
||
stripped = line.strip()
|
||
# 跳过纯注释行
|
||
if stripped.startswith('--') or stripped.startswith('#'):
|
||
if not in_procedure:
|
||
continue
|
||
else:
|
||
current_block.append(line)
|
||
continue
|
||
|
||
if stripped.upper().startswith('DELIMITER $$'):
|
||
# 开始存储过程定义
|
||
in_procedure = True
|
||
current_block = []
|
||
continue
|
||
elif stripped.upper() == 'DELIMITER ;':
|
||
# 执行缓冲区中剩余的存储过程
|
||
if current_block:
|
||
proc_sql = '\n'.join(current_block).strip()
|
||
if proc_sql:
|
||
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
|
||
if proc_sql:
|
||
await cursor.execute(proc_sql)
|
||
in_procedure = False
|
||
current_block = []
|
||
continue
|
||
elif stripped.upper().startswith('DELIMITER'):
|
||
# 其他 DELIMITER 指令,跳过
|
||
continue
|
||
|
||
if in_procedure:
|
||
current_block.append(line)
|
||
# 遇到 $$ 结尾的行,说明一个存储过程定义结束,立即执行
|
||
if stripped.endswith('$$'):
|
||
proc_sql = '\n'.join(current_block).strip()
|
||
if proc_sql:
|
||
# 移除结尾的 $$ 定界符
|
||
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
|
||
if proc_sql:
|
||
await cursor.execute(proc_sql)
|
||
current_block = []
|
||
else:
|
||
# 普通SQL,按完整语句分割(以分号结尾)
|
||
if stripped:
|
||
# 累积多行直到遇到分号
|
||
if buffer:
|
||
buffer += ' ' + stripped
|
||
else:
|
||
buffer = stripped
|
||
|
||
# 如果以分号结尾,执行并清空缓冲区
|
||
if buffer.rstrip().endswith(';'):
|
||
stmt = buffer.rstrip(';').strip()
|
||
if stmt:
|
||
await cursor.execute(stmt)
|
||
buffer = ''
|
||
|
||
# 处理缓冲区中剩余的语句
|
||
if buffer:
|
||
stmt = buffer.rstrip(';').strip()
|
||
if stmt:
|
||
await cursor.execute(stmt)
|
||
else:
|
||
# 无 DELIMITER,按分号+换行分割语句
|
||
statements = re.split(r';\s*\n', sql_content)
|
||
for stmt in statements:
|
||
stmt = stmt.strip()
|
||
if stmt and stmt != '--':
|
||
await cursor.execute(stmt)
|