Files
SharedClassManager/backend/routes/upgrade.py
2026-05-28 20:48:29 +08:00

294 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 升级管理路由
#
# 开发者: Canglan
# 版权归属: Sea Network Technology Studio
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from fastapi import APIRouter, Request
from utils.database import execute_query, execute_update, get_pool
from utils.response import success_response, error_response
from utils.logger import setup_logger
from middleware.permission import PermissionChecker
import os
import re
logger = setup_logger()
router = APIRouter()
# 版本列表(按顺序)
# 版本列表(按顺序)
ALL_VERSIONS = {
'1.0': 'v1.0.sql',
'1.1': 'v1.1.sql',
'1.2': 'v1.2.sql',
'1.3': 'v1.3.sql',
'1.4': 'v1.4.sql',
'1.5': 'v1.5.sql',
'1.6': 'v1.6.sql',
'1.7': 'v1.7.sql',
'1.8': 'v1.8.sql',
'2.0': 'v2.0.sql',
'2.0.1': 'v2.0.1.sql',
'2.1': 'v2.1.sql',
'2.2': 'v2.2.sql',
'2.3': 'v2.3.sql',
}
@router.get("/check")
async def check_upgrade(request: Request):
"""检查数据库版本是否需要升级"""
# 权限检查:仅班主任可执行升级操作
user_type = getattr(request.state, 'user_type', None)
if user_type != 'admin':
return error_response(message="仅管理员可执行升级操作", code=403)
is_teacher = await PermissionChecker.check_is_teacher(
getattr(request.state, 'user_id', 0)
)
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
# 检测当前数据库版本
current_version = '0.0.0'
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
)
if row:
current_version = row[0]['setting_value']
except Exception:
pass # 表不存在时使用默认值
# 读取目标版本(从 VERSION 文件)
version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION')
version_file = os.path.normpath(version_file)
target_version = '0.0.0'
try:
if os.path.exists(version_file):
with open(version_file, 'r') as f:
target_version = f.read().strip()
except Exception:
pass
# 计算需要升级的步骤
needs_upgrade = _compare_versions(target_version, current_version) > 0
steps = []
for version, file_name in sorted(ALL_VERSIONS.items(), key=lambda x: _version_tuple(x[0])):
if _compare_versions(version, current_version) > 0 and _compare_versions(version, target_version) <= 0:
steps.append({'version': version, 'file': file_name})
return success_response(data={
'needs_upgrade': needs_upgrade,
'current': current_version,
'target': target_version,
'steps': steps
})
async def _verify_upgrade(expected_version: str) -> dict:
"""验证升级结果:检查版本号是否已正确更新
Returns:
{'ok': bool, 'message': str}
"""
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
)
if not row:
return {'ok': False, 'message': 'db_version 记录不存在'}
actual = row[0]['setting_value']
if actual != expected_version:
return {'ok': False, 'message': f'版本号不匹配:期望 {expected_version},实际 {actual}'}
return {'ok': True, 'message': '验证通过'}
except Exception as e:
return {'ok': False, 'message': f'验证查询失败: {str(e)}'}
MAX_RETRIES = 2
@router.post("/step")
async def execute_upgrade_step(request: Request):
"""执行单个升级步骤(含验证与重试)"""
# 权限检查:仅班主任可执行升级操作
user_type = getattr(request.state, 'user_type', None)
if user_type != 'admin':
return error_response(message="仅管理员可执行升级操作", code=403)
is_teacher = await PermissionChecker.check_is_teacher(
getattr(request.state, 'user_id', 0)
)
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
body = await request.json()
version = body.get('version', '')
if not version:
return error_response(message='缺少版本号参数', code=400)
if version not in ALL_VERSIONS:
return error_response(message=f'未知版本: {version}', code=400)
# SQL 文件路径
sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades')
sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version]))
if not os.path.exists(sql_file):
return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
# 读取并执行 SQL
with open(sql_file, 'r', encoding='utf-8') as f:
sql_content = f.read().strip()
if sql_content and sql_content != '--':
pool = get_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
await _execute_sql_content(cursor, sql_content)
await conn.commit()
# 更新版本号
await execute_update(
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
"ON DUPLICATE KEY UPDATE setting_value = %s",
(version, version)
)
# 验证版本号是否正确写入
verify = await _verify_upgrade(version)
if verify['ok']:
new_version = version
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
return success_response(data={
'success': True,
'version': version,
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
'current': new_version
})
# 验证失败,准备重试
last_error = f"升级验证失败: {verify['message']}"
if attempt < MAX_RETRIES:
logger.warning(f"v{version} 升级验证失败,准备第 {attempt + 1} 次重试: {last_error}")
continue
except Exception as e:
last_error = str(e)
logger.warning(f"v{version} 升级第 {attempt} 次失败: {last_error}")
if attempt < MAX_RETRIES:
continue
# 所有重试均失败
logger.error(f"数据库升级失败: v{version} (尝试 {MAX_RETRIES} 次) - {last_error}")
return error_response(
message=f"升级至 v{version} 失败 (尝试 {MAX_RETRIES} 次): {last_error}",
code=500
)
def _compare_versions(v1: str, v2: str) -> int:
"""比较两个版本号,返回 1/0/-1"""
t1 = _version_tuple(v1)
t2 = _version_tuple(v2)
if t1 > t2:
return 1
elif t1 < t2:
return -1
return 0
def _version_tuple(v: str) -> tuple:
"""将版本字符串转为可比较的元组"""
parts = []
for p in v.split('.'):
try:
parts.append(int(p))
except ValueError:
parts.append(0)
return tuple(parts)
async def _execute_sql_content(cursor, sql_content: str):
"""执行 SQL 内容,处理存储过程中的 DELIMITER"""
sql_content = sql_content.strip()
if not sql_content or sql_content == '--':
return # 空文件或纯注释,无需执行
# 如果包含 DELIMITER需要特殊处理
if 'DELIMITER' in sql_content.upper():
lines = sql_content.split('\n')
current_block = []
in_procedure = False
buffer = '' # 使用局部变量而非函数属性,避免跨调用泄漏
for line in lines:
stripped = line.strip()
# 跳过纯注释行
if stripped.startswith('--') or stripped.startswith('#'):
if not in_procedure:
continue
else:
current_block.append(line)
continue
if stripped.upper().startswith('DELIMITER $$'):
# 开始存储过程定义
in_procedure = True
current_block = []
continue
elif stripped.upper() == 'DELIMITER ;':
# 执行累积的存储过程块
if current_block:
proc_sql = '\n'.join(current_block).strip()
if proc_sql:
# 移除存储过程结尾的 $$ 定界符(发送给 MySQL 服务器时不需要)
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
await cursor.execute(proc_sql)
in_procedure = False
current_block = []
continue
elif stripped.upper().startswith('DELIMITER'):
# 其他 DELIMITER 指令,跳过
continue
if in_procedure:
current_block.append(line)
else:
# 普通SQL按完整语句分割以分号结尾
if stripped:
# 累积多行直到遇到分号
if buffer:
buffer += ' ' + stripped
else:
buffer = stripped
# 如果以分号结尾,执行并清空缓冲区
if buffer.rstrip().endswith(';'):
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
buffer = ''
# 处理缓冲区中剩余的语句
if buffer:
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
else:
# 无 DELIMITER按分号+换行分割语句
statements = re.split(r';\s*\n', sql_content)
for stmt in statements:
stmt = stmt.strip()
if stmt and stmt != '--':
await cursor.execute(stmt)