v2.2更新
This commit is contained in:
231
backend/routes/upgrade.py
Normal file
231
backend/routes/upgrade.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# ===========================================
|
||||
# 班级操行分管理系统 - 升级管理路由
|
||||
#
|
||||
# 开发者: Canglan
|
||||
# 版权归属: Sea Network Technology Studio
|
||||
#
|
||||
# 版权所有 © Sea Network Technology Studio
|
||||
# ===========================================
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from utils.database import execute_query, execute_update, get_pool
|
||||
from utils.response import success_response, error_response
|
||||
from utils.logger import setup_logger
|
||||
from middleware.permission import PermissionChecker
|
||||
import os
|
||||
import re
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter()
|
||||
|
||||
# 版本列表(按顺序)
|
||||
ALL_VERSIONS = {
|
||||
'1.7': 'v1.7.sql',
|
||||
'1.8': 'v1.8.sql',
|
||||
'2.0': 'v2.0.sql',
|
||||
'2.0.1': 'v2.0.1.sql',
|
||||
'2.1': 'v2.1.sql',
|
||||
'2.2': 'v2.2.sql',
|
||||
}
|
||||
|
||||
|
||||
@router.get("/check")
|
||||
async def check_upgrade(request: Request):
|
||||
"""检查数据库版本是否需要升级"""
|
||||
# 权限检查:仅班主任可执行升级操作
|
||||
user_type = getattr(request.state, 'user_type', None)
|
||||
if user_type != 'admin':
|
||||
return error_response(message="仅管理员可执行升级操作", code=403)
|
||||
|
||||
is_teacher = await PermissionChecker.check_is_teacher(
|
||||
getattr(request.state, 'user_id', 0)
|
||||
)
|
||||
if not is_teacher:
|
||||
return error_response(message="仅班主任可执行升级操作", code=403)
|
||||
|
||||
user_id = request.state.user.get('user_id') if hasattr(request.state, 'user') else getattr(request.state, 'user_id', None)
|
||||
|
||||
# 检测当前数据库版本
|
||||
current_version = '0.0.0'
|
||||
try:
|
||||
row = await execute_query(
|
||||
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
|
||||
)
|
||||
if row:
|
||||
current_version = row[0]['setting_value']
|
||||
except Exception:
|
||||
pass # 表不存在时使用默认值
|
||||
|
||||
# 读取目标版本(从 VERSION 文件)
|
||||
version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION')
|
||||
version_file = os.path.normpath(version_file)
|
||||
target_version = '0.0.0'
|
||||
try:
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, 'r') as f:
|
||||
target_version = f.read().strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 计算需要升级的步骤
|
||||
needs_upgrade = _compare_versions(target_version, current_version) > 0
|
||||
|
||||
steps = []
|
||||
for version, file_name in sorted(ALL_VERSIONS.items(), key=lambda x: _version_tuple(x[0])):
|
||||
if _compare_versions(version, current_version) > 0 and _compare_versions(version, target_version) <= 0:
|
||||
steps.append({'version': version, 'file': file_name})
|
||||
|
||||
return success_response(data={
|
||||
'needs_upgrade': needs_upgrade,
|
||||
'current': current_version,
|
||||
'target': target_version,
|
||||
'steps': steps
|
||||
})
|
||||
|
||||
|
||||
@router.post("/step")
|
||||
async def execute_upgrade_step(request: Request):
|
||||
"""执行单个升级步骤"""
|
||||
# 权限检查:仅班主任可执行升级操作
|
||||
user_type = getattr(request.state, 'user_type', None)
|
||||
if user_type != 'admin':
|
||||
return error_response(message="仅管理员可执行升级操作", code=403)
|
||||
|
||||
is_teacher = await PermissionChecker.check_is_teacher(
|
||||
getattr(request.state, 'user_id', 0)
|
||||
)
|
||||
if not is_teacher:
|
||||
return error_response(message="仅班主任可执行升级操作", code=403)
|
||||
|
||||
user_id = request.state.user.get('user_id') if hasattr(request.state, 'user') else getattr(request.state, 'user_id', None)
|
||||
|
||||
body = await request.json()
|
||||
version = body.get('version', '')
|
||||
|
||||
if not version:
|
||||
return error_response(message='缺少版本号参数', code=400)
|
||||
|
||||
if version not in ALL_VERSIONS:
|
||||
return error_response(message=f'未知版本: {version}', code=400)
|
||||
|
||||
# SQL 文件路径
|
||||
sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades')
|
||||
sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version]))
|
||||
|
||||
if not os.path.exists(sql_file):
|
||||
return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)
|
||||
|
||||
try:
|
||||
# 读取并执行 SQL
|
||||
with open(sql_file, 'r', encoding='utf-8') as f:
|
||||
sql_content = f.read().strip()
|
||||
|
||||
if sql_content and sql_content != '--':
|
||||
# 使用 aiomysql 直接执行多条 SQL
|
||||
pool = get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# 分割 SQL 语句(按 DELIMITER 处理存储过程)
|
||||
await _execute_sql_content(cursor, sql_content)
|
||||
await conn.commit()
|
||||
|
||||
# 更新版本号
|
||||
await execute_update(
|
||||
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
|
||||
"ON DUPLICATE KEY UPDATE setting_value = %s",
|
||||
(version, version)
|
||||
)
|
||||
|
||||
# 重新检测版本
|
||||
new_version = '0.0.0'
|
||||
try:
|
||||
row = await execute_query(
|
||||
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
|
||||
)
|
||||
if row:
|
||||
new_version = row[0]['setting_value']
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
|
||||
|
||||
return success_response(data={
|
||||
'success': True,
|
||||
'version': version,
|
||||
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
|
||||
'current': new_version
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库升级失败: v{version} - {str(e)}")
|
||||
return error_response(message=f"升级至 v{version} 失败: {str(e)}", code=500)
|
||||
|
||||
|
||||
def _compare_versions(v1: str, v2: str) -> int:
|
||||
"""比较两个版本号,返回 1/0/-1"""
|
||||
t1 = _version_tuple(v1)
|
||||
t2 = _version_tuple(v2)
|
||||
if t1 > t2:
|
||||
return 1
|
||||
elif t1 < t2:
|
||||
return -1
|
||||
return 0
|
||||
|
||||
|
||||
def _version_tuple(v: str) -> tuple:
|
||||
"""将版本字符串转为可比较的元组"""
|
||||
parts = []
|
||||
for p in v.split('.'):
|
||||
try:
|
||||
parts.append(int(p))
|
||||
except ValueError:
|
||||
parts.append(0)
|
||||
return tuple(parts)
|
||||
|
||||
|
||||
async def _execute_sql_content(cursor, sql_content: str):
|
||||
"""执行 SQL 内容,处理存储过程中的 DELIMITER"""
|
||||
# 如果包含 DELIMITER,需要特殊处理
|
||||
if 'DELIMITER' in sql_content:
|
||||
# 移除 DELIMITER 行,按 $$ 分割存储过程
|
||||
lines = sql_content.split('\n')
|
||||
current_block = []
|
||||
in_procedure = False
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.upper().startswith('DELIMITER $$'):
|
||||
in_procedure = True
|
||||
current_block = []
|
||||
continue
|
||||
elif stripped.upper() == 'DELIMITER ;':
|
||||
# 执行累积的存储过程块
|
||||
if current_block:
|
||||
proc_sql = '\n'.join(current_block).strip()
|
||||
if proc_sql:
|
||||
await cursor.execute(proc_sql)
|
||||
in_procedure = False
|
||||
current_block = []
|
||||
continue
|
||||
elif stripped.upper().startswith('DELIMITER'):
|
||||
continue
|
||||
|
||||
if in_procedure:
|
||||
current_block.append(line)
|
||||
else:
|
||||
# 普通SQL,按分号分割执行
|
||||
if stripped and not stripped.startswith('--'):
|
||||
# 简单的按分号分割
|
||||
for stmt in stripped.split(';'):
|
||||
stmt = stmt.strip()
|
||||
if stmt:
|
||||
await cursor.execute(stmt)
|
||||
else:
|
||||
# 无 DELIMITER,简单执行
|
||||
# 按 CREATE 分割以支持多语句
|
||||
# 分割 SQL 语句
|
||||
statements = re.split(r';\s*\n', sql_content)
|
||||
for stmt in statements:
|
||||
stmt = stmt.strip()
|
||||
if stmt and stmt != '--':
|
||||
await cursor.execute(stmt)
|
||||
Reference in New Issue
Block a user