Files
ClassManager/backend/routes/upgrade.py
2026-05-28 15:38:32 +08:00

232 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===========================================
# 班级操行分管理系统 - 升级管理路由
#
# 开发者: Canglan
# 版权归属: Sea Network Technology Studio
#
# 版权所有 © Sea Network Technology Studio
# ===========================================
from fastapi import APIRouter, Request
from utils.database import execute_query, execute_update, get_pool
from utils.response import success_response, error_response
from utils.logger import setup_logger
from middleware.permission import PermissionChecker
import os
import re
# Module-level logger and the router that the upgrade endpoints register on.
logger = setup_logger()
router = APIRouter()
# Known database schema versions mapped to the file name of the upgrade
# script for that version. The endpoints in this module resolve these names
# under the sql/upgrades directory and sort the keys numerically before use,
# so the literal order of the entries below is not relied upon.
ALL_VERSIONS = {
'1.7': 'v1.7.sql',
'1.8': 'v1.8.sql',
'2.0': 'v2.0.sql',
'2.0.1': 'v2.0.1.sql',
'2.1': 'v2.1.sql',
'2.2': 'v2.2.sql',
}
@router.get("/check")
async def check_upgrade(request: Request):
    """Report whether the database schema needs to be upgraded.

    The caller must have ``user_type == 'admin'`` on ``request.state`` and
    pass ``PermissionChecker.check_is_teacher``; otherwise a 403 error
    response is returned.

    The current version is read from the ``db_version`` row of
    ``system_settings`` and the target version from the ``VERSION`` file two
    directories above this module's directory. Both default to ``'0.0.0'``
    when they cannot be read.

    Returns:
        A success response whose data holds ``needs_upgrade``, ``current``,
        ``target`` and ``steps`` (the entries of ``ALL_VERSIONS`` that are
        newer than ``current`` and not newer than ``target``, ascending).
    """
    # Permission check: both conditions must hold.
    user_type = getattr(request.state, 'user_type', None)
    if user_type != 'admin':
        return error_response(message="仅管理员可执行升级操作", code=403)
    is_teacher = await PermissionChecker.check_is_teacher(
        getattr(request.state, 'user_id', 0)
    )
    if not is_teacher:
        return error_response(message="仅班主任可执行升级操作", code=403)

    # Detect the current database version. A failing query (for example the
    # settings table not existing yet) deliberately falls back to the
    # default, but the reason is logged instead of being discarded.
    current_version = '0.0.0'
    try:
        rows = await execute_query(
            "SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
        )
        # An empty/NULL stored value would break version parsing; keep the
        # default in that case.
        if rows and rows[0]['setting_value']:
            current_version = rows[0]['setting_value']
    except Exception as exc:
        logger.info(f"读取数据库版本失败，按 0.0.0 处理: {exc}")

    # Read the target version from the VERSION file.
    version_file = os.path.normpath(
        os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION')
    )
    target_version = '0.0.0'
    try:
        with open(version_file, 'r', encoding='utf-8') as f:
            target_version = f.read().strip()
    except FileNotFoundError:
        # No VERSION file: keep the default, as before.
        pass
    except (OSError, UnicodeDecodeError) as exc:
        logger.error(f"读取 VERSION 文件失败: {exc}")

    # Work out which upgrade steps are still pending.
    needs_upgrade = _compare_versions(target_version, current_version) > 0
    steps = [
        {'version': version, 'file': file_name}
        for version, file_name in sorted(
            ALL_VERSIONS.items(), key=lambda item: _version_tuple(item[0])
        )
        if _compare_versions(version, current_version) > 0
        and _compare_versions(version, target_version) <= 0
    ]
    return success_response(data={
        'needs_upgrade': needs_upgrade,
        'current': current_version,
        'target': target_version,
        'steps': steps
    })
@router.post("/step")
async def execute_upgrade_step(request: Request):
    """Run the upgrade script for one version and record the new version.

    The caller must have ``user_type == 'admin'`` on ``request.state`` and
    pass ``PermissionChecker.check_is_teacher``; otherwise a 403 error
    response is returned.

    The request body must be a JSON object with a ``version`` key naming an
    entry of ``ALL_VERSIONS``. A missing, malformed or unknown version yields
    a 400 error response; a missing script file or any failure while running
    it yields a 500 error response.

    Returns:
        On success, a success response whose data holds ``success``,
        ``version``, ``message`` and ``current`` (the ``db_version`` value
        re-read from ``system_settings`` after the update).
    """
    # Permission check: both conditions must hold.
    user_type = getattr(request.state, 'user_type', None)
    if user_type != 'admin':
        return error_response(message="仅管理员可执行升级操作", code=403)
    is_teacher = await PermissionChecker.check_is_teacher(
        getattr(request.state, 'user_id', 0)
    )
    if not is_teacher:
        return error_response(message="仅班主任可执行升级操作", code=403)

    # A body that is not valid JSON, or not a JSON object, is a client error
    # rather than an unhandled server exception.
    try:
        body = await request.json()
    except ValueError:
        body = None
    if not isinstance(body, dict):
        return error_response(message='缺少版本号参数', code=400)

    version = body.get('version', '')
    if not version:
        return error_response(message='缺少版本号参数', code=400)
    # Non-string values (e.g. a list) cannot be looked up in the dict safely.
    if not isinstance(version, str) or version not in ALL_VERSIONS:
        return error_response(message=f'未知版本: {version}', code=400)

    # NOTE(review): nothing here prevents running a step older than the
    # stored db_version, which would move the recorded version backwards —
    # confirm whether that is intended.

    # Locate the SQL script for this version.
    sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades')
    sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version]))
    if not os.path.exists(sql_file):
        return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)

    try:
        # Read and execute the script.
        with open(sql_file, 'r', encoding='utf-8') as f:
            sql_content = f.read().strip()
        if sql_content and sql_content != '--':
            # Run the statements on a dedicated connection from the pool.
            pool = get_pool()
            async with pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await _execute_sql_content(cursor, sql_content)
                await conn.commit()

        # Record the version that has just been applied.
        await execute_update(
            "INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
            "ON DUPLICATE KEY UPDATE setting_value = %s",
            (version, version)
        )

        # Re-read the stored version for the response. This is best-effort:
        # a failure keeps the default but is logged.
        new_version = '0.0.0'
        try:
            rows = await execute_query(
                "SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
            )
            if rows:
                new_version = rows[0]['setting_value']
        except Exception as exc:
            logger.error(f"升级后读取数据库版本失败: {exc}")

        logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
        return success_response(data={
            'success': True,
            'version': version,
            'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
            'current': new_version
        })
    except Exception as e:
        logger.error(f"数据库升级失败: v{version} - {str(e)}")
        return error_response(message=f"升级至 v{version} 失败: {str(e)}", code=500)
def _compare_versions(v1: str, v2: str) -> int:
    """Compare two version strings.

    Both strings are converted with ``_version_tuple`` and the resulting
    tuples are compared.

    Returns:
        1 if ``v1`` is greater than ``v2``, -1 if it is smaller, 0 otherwise.
    """
    left, right = _version_tuple(v1), _version_tuple(v2)
    # bool arithmetic yields exactly 1, 0 or -1.
    return (left > right) - (left < right)
def _version_tuple(v: str) -> tuple:
    """Convert a dotted version string into a tuple that can be compared.

    Each dot-separated component is parsed as an integer; a component that
    is not an integer counts as 0. Trailing zero components are dropped so
    that equivalent spellings such as ``'2.2'`` and ``'2.2.0'`` produce the
    same tuple (plain tuple comparison would otherwise rank the longer one
    higher).

    Args:
        v: Version string such as ``'2.0.1'``.

    Returns:
        Tuple of ints without trailing zeros; ``'0.0.0'`` yields ``()``.
    """
    parts = []
    for piece in v.split('.'):
        try:
            parts.append(int(piece))
        except ValueError:
            parts.append(0)
    # Normalise the length: (2, 0) and (2, 0, 0) must compare equal.
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)
# Matches a client-side "DELIMITER <token>" directive line (any letter case).
_DELIMITER_DIRECTIVE = re.compile(r'^DELIMITER(?:\s+(\S+).*)?$', re.IGNORECASE)


def _split_sql_statements(sql_content: str) -> list:
    """Split a SQL script into individual statements.

    The script is scanned line by line while tracking the active statement
    delimiter (``;`` by default, changed by ``DELIMITER`` directive lines),
    quoted strings/identifiers and comments, so that:

    * statements may span several lines;
    * a delimiter inside quotes or inside a comment does not end a statement;
    * stored-routine bodies between ``DELIMITER $$`` and ``DELIMITER ;`` are
      split on ``$$`` and keep their inner ``;`` characters;
    * ``--`` and ``#`` line comments are removed, and chunks that contain
      nothing but comments or whitespace are skipped.

    Returns:
        The statements in script order, without their terminating delimiter.
    """
    statements = []
    delimiter = ';'
    fragments = []            # text pieces of the statement being collected
    has_code = False          # True once the statement holds real SQL
    quote = None              # active quote character, if inside a literal
    in_block_comment = False  # True while inside /* ... */

    def flush():
        nonlocal fragments, has_code
        text = ''.join(fragments).strip()
        if text and has_code:
            statements.append(text)
        fragments = []
        has_code = False

    for line in sql_content.splitlines():
        if quote is None and not in_block_comment:
            stripped = line.strip()
            directive = _DELIMITER_DIRECTIVE.match(stripped)
            if directive:
                # DELIMITER is a client directive, never sent to the server.
                # Anything still pending is treated as a finished statement.
                flush()
                delimiter = directive.group(1) or delimiter
                continue
            if stripped.startswith('--'):
                # Whole-line comment.
                continue

        pos = 0
        seg_start = 0
        length = len(line)
        while pos < length:
            ch = line[pos]
            if in_block_comment:
                close = line.find('*/', pos)
                if close == -1:
                    pos = length
                else:
                    pos = close + 2
                    in_block_comment = False
            elif quote is not None:
                if ch == '\\' and quote != '`':
                    pos += 2  # skip the escaped character
                else:
                    if ch == quote:
                        quote = None
                    pos += 1
            elif line.startswith(delimiter, pos):
                fragments.append(line[seg_start:pos])
                flush()
                pos += len(delimiter)
                seg_start = pos
            elif ch in ('\'', '"', '`'):
                quote = ch
                has_code = True
                pos += 1
            elif ch == '#' or (
                line.startswith('--', pos)
                and (pos + 2 >= length or line[pos + 2].isspace())
            ):
                # Inline line comment: drop the rest of the line.
                fragments.append(line[seg_start:pos])
                seg_start = length
                pos = length
            elif line.startswith('/*', pos):
                in_block_comment = True
                # /*! ... */ and /*+ ... */ carry executable content.
                if line.startswith(('/*!', '/*+'), pos):
                    has_code = True
                pos += 2
            else:
                if not ch.isspace():
                    has_code = True
                pos += 1
        fragments.append(line[seg_start:])
        fragments.append('\n')

    # A final statement without a terminating delimiter is still executed.
    flush()
    return statements


async def _execute_sql_content(cursor, sql_content: str):
    """Execute every statement of a SQL script on ``cursor``.

    The script is split with ``_split_sql_statements`` (which understands
    ``DELIMITER`` directives used around stored routines) and each statement
    is awaited in order via ``cursor.execute``.

    Args:
        cursor: Object exposing an awaitable ``execute(sql)`` method.
        sql_content: Full text of the SQL script.
    """
    for statement in _split_sql_statements(sql_content):
        await cursor.execute(statement)