v2.3更新

This commit is contained in:
2026-05-28 20:48:29 +08:00
parent ca53fdc349
commit 7dbe98ee02
15 changed files with 749 additions and 86 deletions

View File

@@ -18,17 +18,25 @@ import re
logger = setup_logger()
router = APIRouter()
# 版本列表(按顺序)
# 版本列表(按顺序)
ALL_VERSIONS = {
'1.0': 'v1.0.sql',
'1.1': 'v1.1.sql',
'1.2': 'v1.2.sql',
'1.3': 'v1.3.sql',
'1.4': 'v1.4.sql',
'1.5': 'v1.5.sql',
'1.6': 'v1.6.sql',
'1.7': 'v1.7.sql',
'1.8': 'v1.8.sql',
'2.0': 'v2.0.sql',
'2.0.1': 'v2.0.1.sql',
'2.1': 'v2.1.sql',
'2.2': 'v2.2.sql',
'2.3': 'v2.3.sql',
}
@router.get("/check")
async def check_upgrade(request: Request):
"""检查数据库版本是否需要升级"""
@@ -43,8 +51,6 @@ async def check_upgrade(request: Request):
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
user_id = request.state.user.get('user_id') if hasattr(request.state, 'user') else getattr(request.state, 'user_id', None)
# 检测当前数据库版本
current_version = '0.0.0'
try:
@@ -83,9 +89,32 @@ async def check_upgrade(request: Request):
})
async def _verify_upgrade(expected_version: str) -> dict:
"""验证升级结果:检查版本号是否已正确更新
Returns:
{'ok': bool, 'message': str}
"""
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
)
if not row:
return {'ok': False, 'message': 'db_version 记录不存在'}
actual = row[0]['setting_value']
if actual != expected_version:
return {'ok': False, 'message': f'版本号不匹配:期望 {expected_version},实际 {actual}'}
return {'ok': True, 'message': '验证通过'}
except Exception as e:
return {'ok': False, 'message': f'验证查询失败: {str(e)}'}
MAX_RETRIES = 2
@router.post("/step")
async def execute_upgrade_step(request: Request):
"""执行单个升级步骤"""
"""执行单个升级步骤(含验证与重试)"""
# 权限检查:仅班主任可执行升级操作
user_type = getattr(request.state, 'user_type', None)
if user_type != 'admin':
@@ -97,8 +126,6 @@ async def execute_upgrade_step(request: Request):
if not is_teacher:
return error_response(message="仅班主任可执行升级操作", code=403)
user_id = request.state.user.get('user_id') if hasattr(request.state, 'user') else getattr(request.state, 'user_id', None)
body = await request.json()
version = body.get('version', '')
@@ -115,50 +142,58 @@ async def execute_upgrade_step(request: Request):
if not os.path.exists(sql_file):
return error_response(message=f'SQL 文件不存在: {ALL_VERSIONS[version]}', code=500)
try:
# 读取并执行 SQL
with open(sql_file, 'r', encoding='utf-8') as f:
sql_content = f.read().strip()
if sql_content and sql_content != '--':
# 使用 aiomysql 直接执行多条 SQL
pool = get_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
# 分割 SQL 语句(按 DELIMITER 处理存储过程)
await _execute_sql_content(cursor, sql_content)
await conn.commit()
# 更新版本号
await execute_update(
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
"ON DUPLICATE KEY UPDATE setting_value = %s",
(version, version)
)
# 重新检测版本
new_version = '0.0.0'
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
row = await execute_query(
"SELECT setting_value FROM system_settings WHERE setting_key = 'db_version'"
# 读取并执行 SQL
with open(sql_file, 'r', encoding='utf-8') as f:
sql_content = f.read().strip()
if sql_content and sql_content != '--':
pool = get_pool()
async with pool.acquire() as conn:
async with conn.cursor() as cursor:
await _execute_sql_content(cursor, sql_content)
await conn.commit()
# 更新版本号
await execute_update(
"INSERT INTO system_settings (setting_key, setting_value) VALUES ('db_version', %s) "
"ON DUPLICATE KEY UPDATE setting_value = %s",
(version, version)
)
if row:
new_version = row[0]['setting_value']
except Exception:
pass
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
return success_response(data={
'success': True,
'version': version,
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
'current': new_version
})
except Exception as e:
logger.error(f"数据库升级失败: v{version} - {str(e)}")
return error_response(message=f"升级至 v{version} 失败: {str(e)}", code=500)
# 验证版本号是否正确写入
verify = await _verify_upgrade(version)
if verify['ok']:
new_version = version
logger.info(f"数据库升级成功: v{version} ({ALL_VERSIONS[version]})")
return success_response(data={
'success': True,
'version': version,
'message': f"升级至 v{version} 成功 ({ALL_VERSIONS[version]})",
'current': new_version
})
# 验证失败,准备重试
last_error = f"升级验证失败: {verify['message']}"
if attempt < MAX_RETRIES:
logger.warning(f"v{version} 升级验证失败,准备第 {attempt + 1} 次重试: {last_error}")
continue
except Exception as e:
last_error = str(e)
logger.warning(f"v{version} 升级第 {attempt} 次失败: {last_error}")
if attempt < MAX_RETRIES:
continue
# 所有重试均失败
logger.error(f"数据库升级失败: v{version} (尝试 {MAX_RETRIES} 次) - {last_error}")
return error_response(
message=f"升级至 v{version} 失败 (尝试 {MAX_RETRIES} 次): {last_error}",
code=500
)
def _compare_versions(v1: str, v2: str) -> int:
@@ -185,16 +220,29 @@ def _version_tuple(v: str) -> tuple:
async def _execute_sql_content(cursor, sql_content: str):
"""执行 SQL 内容,处理存储过程中的 DELIMITER"""
sql_content = sql_content.strip()
if not sql_content or sql_content == '--':
return # 空文件或纯注释,无需执行
# 如果包含 DELIMITER需要特殊处理
if 'DELIMITER' in sql_content:
# 移除 DELIMITER 行,按 $$ 分割存储过程
if 'DELIMITER' in sql_content.upper():
lines = sql_content.split('\n')
current_block = []
in_procedure = False
buffer = '' # 使用局部变量而非函数属性,避免跨调用泄漏
for line in lines:
stripped = line.strip()
# 跳过纯注释行
if stripped.startswith('--') or stripped.startswith('#'):
if not in_procedure:
continue
else:
current_block.append(line)
continue
if stripped.upper().startswith('DELIMITER $$'):
# 开始存储过程定义
in_procedure = True
current_block = []
continue
@@ -203,27 +251,41 @@ async def _execute_sql_content(cursor, sql_content: str):
if current_block:
proc_sql = '\n'.join(current_block).strip()
if proc_sql:
# 移除存储过程结尾的 $$ 定界符(发送给 MySQL 服务器时不需要)
proc_sql = re.sub(r'\$\$\s*$', '', proc_sql)
await cursor.execute(proc_sql)
in_procedure = False
current_block = []
continue
elif stripped.upper().startswith('DELIMITER'):
# 其他 DELIMITER 指令,跳过
continue
if in_procedure:
current_block.append(line)
else:
# 普通SQL分号分割执行
if stripped and not stripped.startswith('--'):
# 简单的按分号分割
for stmt in stripped.split(';'):
stmt = stmt.strip()
# 普通SQL完整语句分割(以分号结尾)
if stripped:
# 累积多行直到遇到分号
if buffer:
buffer += ' ' + stripped
else:
buffer = stripped
# 如果以分号结尾,执行并清空缓冲区
if buffer.rstrip().endswith(';'):
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
buffer = ''
# 处理缓冲区中剩余的语句
if buffer:
stmt = buffer.rstrip(';').strip()
if stmt:
await cursor.execute(stmt)
else:
# 无 DELIMITER简单执行
# 按 CREATE 分割以支持多语句
# 分割 SQL 语句
# 无 DELIMITER按分号+换行分割语句
statements = re.split(r';\s*\n', sql_content)
for stmt in statements:
stmt = stmt.strip()