# =========================================== # 班级操行分管理系统 - 升级管理路由 # # 开发者: Canglan # 版权归属: Sea Network Technology Studio # # 版权所有 © Sea Network Technology Studio # =========================================== from fastapi import APIRouter, Request from utils.database import execute_query, execute_update, get_pool from utils.response import success_response, error_response from utils.logger import setup_logger from middleware.permission import PermissionChecker import os import re logger = setup_logger() router = APIRouter() # 版本列表(按顺序) # 版本列表(按顺序) ALL_VERSIONS = { '1.0': 'v1.0.sql', '1.1': 'v1.1.sql', '1.2': 'v1.2.sql', '1.3': 'v1.3.sql', '1.4': 'v1.4.sql', '1.5': 'v1.5.sql', '1.6': 'v1.6.sql', '1.7': 'v1.7.sql', '1.8': 'v1.8.sql', '2.0': 'v2.0.sql', '2.0.1': 'v2.0.1.sql', '2.1': 'v2.1.sql', '2.2': 'v2.2.sql', '2.3': 'v2.3.sql', '2.4': 'v2.4.sql', '2.5': 'v2.5.sql', '2.5.1': 'v2.5.1.sql', '2.6': 'v2.6.sql', '2.7': 'v2.7.sql', } # 版本特征标记(按优先级从高到低) VERSION_MARKERS = [ ('2.0', 'students', 'dormitory_number'), ('1.8', 'conduct_records', 'related_type'), ('1.7', 'subjects', 'sort_order'), ] async def _detect_current_version() -> str: """检测当前数据库版本,优先从 system_settings 读取,否则通过列特征推断""" # 1. 尝试从 system_settings 读取 db_version try: row = await execute_query( "SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'" ) if row: return row[0]['setting_value'] except Exception as e: logger.warning(f"查询 system_settings 表失败,将通过列特征推断版本: {e}") # 2. 通过列特征推断版本 inferred_version = '1.0' for version, table, column in VERSION_MARKERS: try: result = await execute_query( "SELECT COUNT(*) as cnt FROM INFORMATION_SCHEMA.COLUMNS " "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s AND COLUMN_NAME = %s", (table, column) ) if result and result[0]['cnt'] > 0: inferred_version = version break except Exception as e: logger.warning(f"检查列特征失败 ({table}.{column}): {e}") logger.info(f"通过列特征推断数据库版本为: {inferred_version}") # 3. 确保 system_settings 表存在并写入推断版本 try: await execute_update( "CREATE TABLE IF NOT EXISTS `system_settings` (" "`setting_key` VARCHAR(50) PRIMARY KEY," "`setting_value` VARCHAR(255) NOT NULL," "`updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci" ) await execute_update( "INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) " "ON DUPLICATE KEY UPDATE setting_value = %s", (inferred_version, inferred_version) ) logger.info(f"已将推断版本 {inferred_version} 写入 system_settings") except Exception as e: logger.error(f"写入推断版本失败: {e}") return inferred_version @router.get("/check") async def check_upgrade(request: Request): """检查数据库版本是否需要升级""" # 权限检查:仅班主任可执行升级操作 user_type = getattr(request.state, 'user_type', None) if user_type != 'admin': return error_response(message="仅管理员可执行升级操作", code=403) is_teacher = await PermissionChecker.check_is_teacher( getattr(request.state, 'user_id', 0) ) if not is_teacher: return error_response(message="仅班主任可执行升级操作", code=403) # 检测当前数据库版本(支持自动推断) current_version = await _detect_current_version() # 读取目标版本(从 VERSION 文件) version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'VERSION') version_file = os.path.normpath(version_file) target_version = '0.0.0' try: if os.path.exists(version_file): with open(version_file, 'r') as f: target_version = f.read().strip() except Exception: pass # 计算需要升级的步骤 needs_upgrade = _compare_versions(target_version, current_version) > 0 steps = [] for version, file_name in sorted(ALL_VERSIONS.items(), key=lambda x: _version_tuple(x[0])): if _compare_versions(version, current_version) > 0 and _compare_versions(version, target_version) <= 0: steps.append({'version': version, 'file': file_name}) return success_response(data={ 'needs_upgrade': needs_upgrade, 'current': current_version, 'target': target_version, 'steps': steps }) async def _verify_upgrade(expected_version: str) -> dict: """验证升级结果:检查版本号是否已正确更新 Returns: {'ok': bool, 'message': str} """ try: row = await execute_query( "SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'" ) if not row: return {'ok': False, 'message': 'db_version 记录不存在'} actual = row[0]['setting_value'] if actual != expected_version: return {'ok': False, 'message': f'版本号不匹配:期望 {expected_version},实际 {actual}'} return {'ok': True, 'message': '验证通过'} except Exception as e: return {'ok': False, 'message': f'验证查询失败: {str(e)}'} MAX_RETRIES = 2 @router.post("/step") async def execute_upgrade_step(request: Request): """执行单个升级步骤(含验证与重试)""" # 权限检查:仅班主任可执行升级操作 user_type = getattr(request.state, 'user_type', None) if user_type != 'admin': return error_response(message="仅管理员可执行升级操作", code=403) is_teacher = await PermissionChecker.check_is_teacher( getattr(request.state, 'user_id', 0) ) if not is_teacher: return error_response(message="仅班主任可执行升级操作", code=403) body = await request.json() version = body.get('version', '') if not version: return error_response(message='缺少版本号参数', code=400) if version not in ALL_VERSIONS: return error_response(message=f'未知版本: {version}', code=400) # SQL 文件路径 sql_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', 'sql', 'upgrades') sql_file = os.path.normpath(os.path.join(sql_dir, ALL_VERSIONS[version])) if not os.path.exists(sql_file): return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500) last_error = None for attempt in range(1, MAX_RETRIES + 1): try: # 读取并执行 SQL with open(sql_file, 'r', encoding='utf-8') as f: sql_content = f.read().strip() if sql_content and sql_content != '--': pool = get_pool() async with pool.acquire() as conn: async with conn.cursor() as cursor: await _execute_sql_content(cursor, sql_content) await conn.commit() # 更新版本号 await execute_update( "INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) " "ON DUPLICATE KEY UPDATE setting_value = %s", (version, version) ) # 验证版本号是否正确写入 verify = await _verify_upgrade(version) if verify['ok']: new_version = version logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})") return success_response(data={ 'success': True, 'version': version, 'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})", 'current': new_version }) # 验证失败,准备重试 last_error = f"升级验证失败: {verify['message']}" if attempt < MAX_RETRIES: logger.warning(f"v{version} 升级验证失败,准备第 {attempt + 1} 次重试: {last_error}") continue except Exception as e: last_error = str(e) logger.warning(f"v{version} 升级第 {attempt} 次失败: {last_error}") if attempt < MAX_RETRIES: continue # 所有重试均失败 logger.error(f"数据库升级失败: v{version} (尝试 {MAX_RETRIES} 次) - {last_error}") return error_response( message=f"升级至 v{version} 失败 (尝试 {MAX_RETRIES} 次): {last_error}", code=500 ) def _compare_versions(v1: str, v2: str) -> int: """比较两个版本号,返回 1/0/-1""" t1 = _version_tuple(v1) t2 = _version_tuple(v2) if t1 > t2: return 1 elif t1 < t2: return -1 return 0 def _version_tuple(v: str) -> tuple: """将版本字符串转为可比较的元组""" parts = [] for p in v.split('.'): try: parts.append(int(p)) except ValueError: parts.append(0) return tuple(parts) async def _execute_sql_content(cursor, sql_content: str): """执行 SQL 内容,处理存储过程中的 DELIMITER""" sql_content = sql_content.strip() if not sql_content or sql_content == '--': return # 空文件或纯注释,无需执行 # 如果包含 DELIMITER,需要特殊处理 if 'DELIMITER' in sql_content.upper(): lines = sql_content.split('\n') current_block = [] in_procedure = False buffer = '' # 使用局部变量而非函数属性,避免跨调用泄漏 for line in lines: stripped = line.strip() # 跳过纯注释行 if stripped.startswith('--') or stripped.startswith('#'): if not in_procedure: continue else: current_block.append(line) continue if stripped.upper().startswith('DELIMITER $$'): # 开始存储过程定义 in_procedure = True current_block = [] continue elif stripped.upper() == 'DELIMITER ;': # 执行缓冲区中剩余的存储过程 if current_block: proc_sql = '\n'.join(current_block).strip() if proc_sql: proc_sql = re.sub(r'\$\$\s*$', '', proc_sql) if proc_sql: await cursor.execute(proc_sql) in_procedure = False current_block = [] continue elif stripped.upper().startswith('DELIMITER'): # 其他 DELIMITER 指令,跳过 continue if in_procedure: current_block.append(line) # 遇到 $$ 结尾的行,说明一个存储过程定义结束,立即执行 if stripped.endswith('$$'): proc_sql = '\n'.join(current_block).strip() if proc_sql: # 移除结尾的 $$ 定界符 proc_sql = re.sub(r'\$\$\s*$', '', proc_sql) if proc_sql: await cursor.execute(proc_sql) current_block = [] else: # 普通SQL,按完整语句分割(以分号结尾) if stripped: # 累积多行直到遇到分号 if buffer: buffer += ' ' + stripped else: buffer = stripped # 如果以分号结尾,执行并清空缓冲区 if buffer.rstrip().endswith(';'): stmt = buffer.rstrip(';').strip() if stmt: await cursor.execute(stmt) buffer = '' # 处理缓冲区中剩余的语句 if buffer: stmt = buffer.rstrip(';').strip() if stmt: await cursor.execute(stmt) else: # 无 DELIMITER,按分号+换行分割语句 statements = re.split(r';\s*\n', sql_content) for stmt in statements: stmt = stmt.strip() if stmt and stmt != '--': await cursor.execute(stmt)