Files
SharedClassManager/backend/routes/upgrade.py
2026-05-29 08:32:28 +08:00

353 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 升级管理路由
#
# 开发者: Canglan
# 版权归属: Sea Network Technology Studio
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from fastapi import APIRouter, Request
from utils.database import execute_query, execute_update, get_pool
from utils.response import success_response, error_response
from utils.logger import setup_logger
from middleware.permission import PermissionChecker
import os
import re
logger = setup_logger()
router = APIRouter()
# 版本列表(按顺序)
# 版本列表(按顺序)
ALL_VERSIONS = {
'1.0': 'v1.0.sql',
'1.1': 'v1.1.sql',
'1.2': 'v1.2.sql',
'1.3': 'v1.3.sql',
'1.4': 'v1.4.sql',
'1.5': 'v1.5.sql',
'1.6': 'v1.6.sql',
'1.7': 'v1.7.sql',
'1.8': 'v1.8.sql',
'2.0': 'v2.0.sql',
'2.0.1': 'v2.0.1.sql',
'2.1': 'v2.1.sql',
'2.2': 'v2.2.sql',
'2.3': 'v2.3.sql',
}
# 版本特征标记(按优先级从高到低)
VERSION_MARKERS = [
('2.0', 'students', 'dormitory_number'),
('1.8', 'conduct_records', 'related_type'),
('1.7', 'subjects', 'sort_order'),
]
async def _detect_current_version() -> str:
"""检测当前数据库版本,优先从 system_settings 读取,否则通过列特征推断"""
# 1. 尝试从 system_settings 读取 db_version
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
)
if row:
return row[0]['setting_value']
except Exception as e:
logger.warning(f"查询 system_settings 表失败,将通过列特征推断版本: {e}")
# 2. 通过列特征推断版本
inferred_version = '1.0'
for version, table, column in VERSION_MARKERS:
try:
result = await execute_query(
"SELECT COUNT(*) as cnt FROM INFORMATION_SCHEMA.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND COLUMN_NAME = %s",
(table, column)
)
if result and result[0]['cnt'] > 0:
inferred_version = version
break
except Exception as e:
logger.warning(f"检查列特征失败 ({table}.{column}): {e}")
logger.info(f"通过列特征推断数据库版本为: {inferred_version}")
# 3. 确保 system_settings 表存在并写入推断版本
try:
await execute_update(
"CREATE TABLE IF NOT EXISTS `system_settings` ("
"`setting_key` VARCHAR(50) PRIMARY KEY,"
"`setting_value` VARCHAR(255) NOT NULL,"
"`updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci"
)
await execute_update(
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
"ON DUPLICATE KEY UPDATE setting_value = %s",
(inferred_version, inferred_version)
)
logger.info(f"已将推断版本 {inferred_version} 写入 system_settings")
except Exception as e:
logger.error(f"写入推断版本失败: {e}")
return inferred_version
@router.get("/check")
async def check_upgrade(request: Request):
"""检查数据库版本是否需要升级"""
# 权限检查:仅班主任可执行升级操作
user_type = getattr(request.state, 'user_type', None)
if user_type != 'admin':
return error_response(message="仅管理员可执行升级操作", code=403)
is_teacher = await PermissionChecker.check_is_teacher(
getattr(request.state, 'user_id', 0)
)
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
# 检测当前数据库版本(支持自动推断)
current_version = await _detect_current_version()
# 读取目标版本(从 VERSION 文件)
version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION')
version_file = os.path.normpath(version_file)
target_version = '0.0.0'
try:
if os.path.exists(version_file):
with open(version_file, 'r') as f:
target_version = f.read().strip()
except Exception:
pass
# 计算需要升级的步骤
needs_upgrade = _compare_versions(target_version, current_version) > 0
steps = []
for version, file_name in sorted(ALL_VERSIONS.items(), key=lambda x: _version_tuple(x[0])):
if _compare_versions(version, current_version) > 0 and _compare_versions(version, target_version) <= 0:
steps.append({'version': version, 'file': file_name})
return success_response(data={
'needs_upgrade': needs_upgrade,
'current': current_version,
'target': target_version,
'steps': steps
})
async def _verify_upgrade(expected_version: str) -> dict:
"""验证升级结果:检查版本号是否已正确更新
Returns:
{'ok': bool, 'message': str}
"""
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
)
if not row:
return {'ok': False, 'message': 'db_version 记录不存在'}
actual = row[0]['setting_value']
if actual != expected_version:
return {'ok': False, 'message': f'版本号不匹配:期望 {expected_version},实际 {actual}'}
return {'ok': True, 'message': '验证通过'}
except Exception as e:
return {'ok': False, 'message': f'验证查询失败: {str(e)}'}
MAX_RETRIES = 2
@router.post("/step")
async def execute_upgrade_step(request: Request):
"""执行单个升级步骤(含验证与重试)"""
# 权限检查:仅班主任可执行升级操作
user_type = getattr(request.state, 'user_type', None)
if user_type != 'admin':
return error_response(message="仅管理员可执行升级操作", code=403)
is_teacher = await PermissionChecker.check_is_teacher(
getattr(request.state, 'user_id', 0)
)
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
body = await request.json()
version = body.get('version', '')
if not version:
return error_response(message='缺少版本号参数', code=400)
if version not in ALL_VERSIONS:
return error_response(message=f'未知版本: {version}', code=400)
# SQL 文件路径
sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades')
sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version]))
if not os.path.exists(sql_file):
return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
# 读取并执行 SQL
with open(sql_file, 'r', encoding='utf-8') as f:
sql_content = f.read().strip()
if sql_content and sql_content != '--':
pool = get_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
await _execute_sql_content(cursor, sql_content)
await conn.commit()
# 更新版本号
await execute_update(
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
"ON DUPLICATE KEY UPDATE setting_value = %s",
(version, version)
)
# 验证版本号是否正确写入
verify = await _verify_upgrade(version)
if verify['ok']:
new_version = version
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
return success_response(data={
'success': True,
'version': version,
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
'current': new_version
})
# 验证失败,准备重试
last_error = f"升级验证失败: {verify['message']}"
if attempt < MAX_RETRIES:
logger.warning(f"v{version} 升级验证失败,准备第 {attempt + 1} 次重试: {last_error}")
continue
except Exception as e:
last_error = str(e)
logger.warning(f"v{version} 升级第 {attempt} 次失败: {last_error}")
if attempt < MAX_RETRIES:
continue
# 所有重试均失败
logger.error(f"数据库升级失败: v{version} (尝试 {MAX_RETRIES} 次) - {last_error}")
return error_response(
message=f"升级至 v{version} 失败 (尝试 {MAX_RETRIES} 次): {last_error}",
code=500
)
def _compare_versions(v1: str, v2: str) -> int:
"""比较两个版本号,返回 1/0/-1"""
t1 = _version_tuple(v1)
t2 = _version_tuple(v2)
if t1 > t2:
return 1
elif t1 < t2:
return -1
return 0
def _version_tuple(v: str) -> tuple:
"""将版本字符串转为可比较的元组"""
parts = []
for p in v.split('.'):
try:
parts.append(int(p))
except ValueError:
parts.append(0)
return tuple(parts)
async def _execute_sql_content(cursor, sql_content: str):
"""执行 SQL 内容,处理存储过程中的 DELIMITER"""
sql_content = sql_content.strip()
if not sql_content or sql_content == '--':
return # 空文件或纯注释,无需执行
# 如果包含 DELIMITER需要特殊处理
if 'DELIMITER' in sql_content.upper():
lines = sql_content.split('\n')
current_block = []
in_procedure = False
buffer = '' # 使用局部变量而非函数属性,避免跨调用泄漏
for line in lines:
stripped = line.strip()
# 跳过纯注释行
if stripped.startswith('--') or stripped.startswith('#'):
if not in_procedure:
continue
else:
current_block.append(line)
continue
if stripped.upper().startswith('DELIMITER $$'):
# 开始存储过程定义
in_procedure = True
current_block = []
continue
elif stripped.upper() == 'DELIMITER ;':
# 执行缓冲区中剩余的存储过程
if current_block:
proc_sql = '\n'.join(current_block).strip()
if proc_sql:
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
if proc_sql:
await cursor.execute(proc_sql)
in_procedure = False
current_block = []
continue
elif stripped.upper().startswith('DELIMITER'):
# 其他 DELIMITER 指令,跳过
continue
if in_procedure:
current_block.append(line)
# 遇到 $$ 结尾的行,说明一个存储过程定义结束,立即执行
if stripped.endswith('$$'):
proc_sql = '\n'.join(current_block).strip()
if proc_sql:
# 移除结尾的 $$ 定界符
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
if proc_sql:
await cursor.execute(proc_sql)
current_block = []
else:
# 普通SQL按完整语句分割以分号结尾
if stripped:
# 累积多行直到遇到分号
if buffer:
buffer += ' ' + stripped
else:
buffer = stripped
# 如果以分号结尾,执行并清空缓冲区
if buffer.rstrip().endswith(';'):
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
buffer = ''
# 处理缓冲区中剩余的语句
if buffer:
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
else:
# 无 DELIMITER按分号+换行分割语句
statements = re.split(r';\s*\n', sql_content)
for stmt in statements:
stmt = stmt.strip()
if stmt and stmt != '--':
await cursor.execute(stmt)