修复限流器问题

This commit is contained in:
2026-04-01 16:05:57 +08:00
parent 4e99ca2b21
commit 23319d8e10
8 changed files with 129 additions and 95 deletions

View File

@@ -23,7 +23,7 @@ RATE_LIMIT_REQUESTS=100
RATE_LIMIT_PERIOD=60 RATE_LIMIT_PERIOD=60
# CORS前端域名多个域名用逗号分隔 # CORS前端域名多个域名用逗号分隔
ALLOWED_ORIGINS=https://your-domain.com ALLOWED_ORIGINS=["https://your-domain.com", "http://localhost:3000"]
# 短信配置(阿里云接入) # 短信配置(阿里云接入)
ALIYUN_SMS_ACCESS_KEY_ID= ALIYUN_SMS_ACCESS_KEY_ID=

View File

@@ -11,19 +11,24 @@ from fastapi import Request, HTTPException
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.extension import RateLimitDecorator
from ..config import settings from ..config import settings
# 创建限流器
limiter = Limiter(key_func=get_remote_address) limiter = Limiter(key_func=get_remote_address)
def setup_rate_limit(app): def setup_rate_limit(app):
"""配置限流"""
if settings.RATE_LIMIT_ENABLED: if settings.RATE_LIMIT_ENABLED:
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
def rate_limit(requests: int = None, period: int = None): def rate_limit(requests: int = None, period: int = None):
"""限流装饰器工厂 - 修复版本"""
if not settings.RATE_LIMIT_ENABLED: if not settings.RATE_LIMIT_ENABLED:
return lambda func: func return lambda func: func
req = requests or settings.RATE_LIMIT_REQUESTS req = requests or settings.RATE_LIMIT_REQUESTS
per = period or settings.RATE_LIMIT_PERIOD per = period or settings.RATE_LIMIT_PERIOD
return limiter.limit(f"{req}/{per} seconds") return limiter.limit(f"{req}/{per} seconds")

View File

@@ -8,7 +8,7 @@ License: AGPL v3
""" """
import re import re
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ...dependencies import DbDependency from ...dependencies import DbDependency
from ...models import User from ...models import User
@@ -31,29 +31,25 @@ def is_email(account: str) -> bool:
return re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', account) is not None return re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', account) is not None
@router.post("/send-code") @router.post("/send-code")
@rate_limit(requests=1, period=60) # 同一账号每分钟1次 @rate_limit(requests=1, period=60)
async def send_verify_code(request: SendCodeRequest, db: DbDependency): async def send_verify_code(
request: Request, # 添加 request 参数
req: SendCodeRequest,
db: DbDependency
):
"""发送验证码""" """发送验证码"""
account = request.account account = req.account
account_type = request.type account_type = req.type
# 验证账号格式
if account_type == "phone" and not is_phone(account): if account_type == "phone" and not is_phone(account):
raise HTTPException(status_code=400, detail="手机号格式错误") raise HTTPException(status_code=400, detail="手机号格式错误")
if account_type == "email" and not is_email(account): if account_type == "email" and not is_email(account):
raise HTTPException(status_code=400, detail="邮箱格式错误") raise HTTPException(status_code=400, detail="邮箱格式错误")
# 检查是否已存在(注册时使用)
# 这里只负责发送,不检查是否存在
# 生成验证码
code = generate_verify_code() code = generate_verify_code()
# 存储到 Redis5分钟有效
redis_key = f"verify:code:{account}" redis_key = f"verify:code:{account}"
redis_client.set(redis_key, code, expire=300) redis_client.set(redis_key, code, expire=300)
# 发送验证码
success = False success = False
if account_type == "phone": if account_type == "phone":
success = send_sms(account, code) success = send_sms(account, code)
@@ -68,19 +64,21 @@ async def send_verify_code(request: SendCodeRequest, db: DbDependency):
@router.post("/register") @router.post("/register")
@rate_limit(requests=5, period=60) @rate_limit(requests=5, period=60)
async def register(request: RegisterRequest, db: DbDependency): async def register(
"""注册(手机号/邮箱 + 验证码)""" request: Request, # 添加 request 参数
account = request.account req: RegisterRequest,
code = request.code db: DbDependency
password = request.password ):
"""注册"""
account = req.account
code = req.code
password = req.password
# 验证验证码
redis_key = f"verify:code:{account}" redis_key = f"verify:code:{account}"
saved_code = redis_client.get(redis_key) saved_code = redis_client.get(redis_key)
if not saved_code or saved_code != code: if not saved_code or saved_code != code:
raise HTTPException(status_code=400, detail="验证码错误或已过期") raise HTTPException(status_code=400, detail="验证码错误或已过期")
# 检查用户是否已存在
if is_phone(account): if is_phone(account):
existing = db.query(User).filter(User.phone == account).first() existing = db.query(User).filter(User.phone == account).first()
if existing: if existing:
@@ -98,10 +96,7 @@ async def register(request: RegisterRequest, db: DbDependency):
db.commit() db.commit()
db.refresh(user) db.refresh(user)
# 删除已使用的验证码
redis_client.delete(redis_key) redis_client.delete(redis_key)
# 生成 token
token = create_access_token({"sub": str(user.id)}) token = create_access_token({"sub": str(user.id)})
logger.info(f"新用户注册: {account}") logger.info(f"新用户注册: {account}")
@@ -113,13 +108,16 @@ async def register(request: RegisterRequest, db: DbDependency):
@router.post("/login") @router.post("/login")
@rate_limit(requests=10, period=60) @rate_limit(requests=10, period=60)
async def login(request: LoginRequest, db: DbDependency): async def login(
"""登录(支持验证码登录或密码登录)""" request: Request, # 添加 request 参数
account = request.account req: LoginRequest,
password = request.password db: DbDependency
code = request.code ):
"""登录"""
account = req.account
password = req.password
code = req.code
# 查找用户
user = None user = None
if is_phone(account): if is_phone(account):
user = db.query(User).filter(User.phone == account).first() user = db.query(User).filter(User.phone == account).first()
@@ -131,22 +129,18 @@ async def login(request: LoginRequest, db: DbDependency):
if not user: if not user:
raise HTTPException(status_code=400, detail="用户不存在") raise HTTPException(status_code=400, detail="用户不存在")
# 验证码登录
if code: if code:
redis_key = f"verify:code:{account}" redis_key = f"verify:code:{account}"
saved_code = redis_client.get(redis_key) saved_code = redis_client.get(redis_key)
if not saved_code or saved_code != code: if not saved_code or saved_code != code:
raise HTTPException(status_code=400, detail="验证码错误或已过期") raise HTTPException(status_code=400, detail="验证码错误或已过期")
redis_client.delete(redis_key) redis_client.delete(redis_key)
# 密码登录
elif password: elif password:
if not verify_password(password, user.password_hash): if not verify_password(password, user.password_hash):
raise HTTPException(status_code=400, detail="密码错误") raise HTTPException(status_code=400, detail="密码错误")
else: else:
raise HTTPException(status_code=400, detail="请提供验证码或密码") raise HTTPException(status_code=400, detail="请提供验证码或密码")
# 更新最后登录时间
from datetime import datetime from datetime import datetime
user.last_login = datetime.now() user.last_login = datetime.now()
db.commit() db.commit()
@@ -161,7 +155,6 @@ async def login(request: LoginRequest, db: DbDependency):
} }
@router.post("/wechat/login") @router.post("/wechat/login")
async def wechat_login(): async def wechat_login(request: Request): # 添加 request 参数
"""微信公众号登录(预留)""" """微信公众号登录(预留)"""
# TODO: 实现公众号 OAuth2 登录
raise HTTPException(status_code=501, detail="功能开发中") raise HTTPException(status_code=501, detail="功能开发中")

View File

@@ -7,7 +7,7 @@ Author: Canglan <admin@sea-studio.top>
License: AGPL v3 License: AGPL v3
""" """
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query, Request
from typing import Optional, List from typing import Optional, List
from ...dependencies import CurrentUserDependency, DbDependency from ...dependencies import CurrentUserDependency, DbDependency
from ...models import Note from ...models import Note
@@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/notes", tags=["notes"])
@router.get("/", response_model=List[NoteResponse]) @router.get("/", response_model=List[NoteResponse])
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def get_notes( async def get_notes(
request: Request, # 添加 request 参数
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency, db: DbDependency,
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
@@ -33,6 +34,7 @@ async def get_notes(
@router.post("/", response_model=NoteResponse, status_code=201) @router.post("/", response_model=NoteResponse, status_code=201)
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def create_note( async def create_note(
request: Request, # 添加 request 参数
data: NoteCreate, data: NoteCreate,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency db: DbDependency
@@ -46,6 +48,7 @@ async def create_note(
@router.put("/{note_id}", response_model=NoteResponse) @router.put("/{note_id}", response_model=NoteResponse)
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def update_note( async def update_note(
request: Request, # 添加 request 参数
note_id: int, note_id: int,
data: NoteUpdate, data: NoteUpdate,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
@@ -66,6 +69,7 @@ async def update_note(
@router.delete("/{note_id}") @router.delete("/{note_id}")
@rate_limit(requests=30, period=60) @rate_limit(requests=30, period=60)
async def delete_note( async def delete_note(
request: Request, # 添加 request 参数
note_id: int, note_id: int,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency db: DbDependency

View File

@@ -8,7 +8,7 @@ License: AGPL v3
""" """
from datetime import datetime from datetime import datetime
from fastapi import APIRouter from fastapi import APIRouter, HTTPException, Request
from ...utils.redis_client import redis_client from ...utils.redis_client import redis_client
from ...models import ToolStatsTotal from ...models import ToolStatsTotal
from ...dependencies import DbDependency from ...dependencies import DbDependency
@@ -16,7 +16,6 @@ from ...middleware.rate_limit import rate_limit
router = APIRouter(prefix="/api/v1", tags=["stats"]) router = APIRouter(prefix="/api/v1", tags=["stats"])
# 预定义工具名称(对应前端页面)
TOOL_NAMES = [ TOOL_NAMES = [
"todos", "notes", "password", "qrcode", "todos", "notes", "password", "qrcode",
"crypto_hash", "crypto_base64", "crypto_url", "crypto_aes", "json" "crypto_hash", "crypto_base64", "crypto_url", "crypto_aes", "json"
@@ -24,7 +23,11 @@ TOOL_NAMES = [
@router.post("/tool/usage") @router.post("/tool/usage")
@rate_limit(requests=20, period=60) @rate_limit(requests=20, period=60)
async def record_usage(tool_name: str, db: DbDependency): async def record_usage(
request: Request, # 添加 request 参数
tool_name: str,
db: DbDependency
):
"""记录页面访问次数(热度)""" """记录页面访问次数(热度)"""
if tool_name not in TOOL_NAMES: if tool_name not in TOOL_NAMES:
raise HTTPException(status_code=400, detail="无效的工具名") raise HTTPException(status_code=400, detail="无效的工具名")
@@ -33,15 +36,10 @@ async def record_usage(tool_name: str, db: DbDependency):
today_key = f"tool:stats:today:{tool_name}:{today}" today_key = f"tool:stats:today:{tool_name}:{today}"
total_key = f"tool:stats:total:{tool_name}" total_key = f"tool:stats:total:{tool_name}"
# 增加今日计数设置48小时过期
today_count = redis_client.incr(today_key) today_count = redis_client.incr(today_key)
redis_client.expire(today_key, 48 * 3600) redis_client.expire(today_key, 48 * 3600)
# 增加总计数
total_count = redis_client.incr(total_key) total_count = redis_client.incr(total_key)
# 异步更新 MySQL可选这里简单处理
# 实际可改为定时任务同步,此处为简化,直接更新
stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first() stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first()
if stats: if stats:
stats.total_count = total_count stats.total_count = total_count
@@ -54,7 +52,10 @@ async def record_usage(tool_name: str, db: DbDependency):
@router.get("/tool/stats") @router.get("/tool/stats")
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def get_stats(db: DbDependency): async def get_stats(
request: Request, # 添加 request 参数
db: DbDependency
):
"""获取所有工具的今日/总访问次数""" """获取所有工具的今日/总访问次数"""
today = datetime.now().strftime("%Y-%m-%d") today = datetime.now().strftime("%Y-%m-%d")
result = {} result = {}
@@ -67,7 +68,6 @@ async def get_stats(db: DbDependency):
total_count = redis_client.get(total_key) total_count = redis_client.get(total_key)
if total_count is None: if total_count is None:
# 从 MySQL 读取
stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first() stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first()
total_count = stats.total_count if stats else 0 total_count = stats.total_count if stats else 0
else: else:

View File

@@ -7,7 +7,7 @@ Author: Canglan <admin@sea-studio.top>
License: AGPL v3 License: AGPL v3
""" """
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query, Request
from typing import Optional, List from typing import Optional, List
from ...dependencies import CurrentUserDependency, DbDependency from ...dependencies import CurrentUserDependency, DbDependency
from ...models import Todo from ...models import Todo
@@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/todos", tags=["todos"])
@router.get("/", response_model=List[TodoResponse]) @router.get("/", response_model=List[TodoResponse])
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def get_todos( async def get_todos(
request: Request, # 添加 request 参数
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency, db: DbDependency,
skip: int = Query(0, ge=0), skip: int = Query(0, ge=0),
@@ -38,6 +39,7 @@ async def get_todos(
@router.post("/", response_model=TodoResponse, status_code=201) @router.post("/", response_model=TodoResponse, status_code=201)
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def create_todo( async def create_todo(
request: Request, # 添加 request 参数
data: TodoCreate, data: TodoCreate,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency db: DbDependency
@@ -51,6 +53,7 @@ async def create_todo(
@router.put("/{todo_id}", response_model=TodoResponse) @router.put("/{todo_id}", response_model=TodoResponse)
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def update_todo( async def update_todo(
request: Request, # 添加 request 参数
todo_id: int, todo_id: int,
data: TodoUpdate, data: TodoUpdate,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
@@ -71,6 +74,7 @@ async def update_todo(
@router.delete("/{todo_id}") @router.delete("/{todo_id}")
@rate_limit(requests=30, period=60) @rate_limit(requests=30, period=60)
async def delete_todo( async def delete_todo(
request: Request, # 添加 request 参数
todo_id: int, todo_id: int,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency db: DbDependency

View File

@@ -16,7 +16,7 @@ from io import BytesIO
import base64 as b64 import base64 as b64
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad from Crypto.Util.Padding import pad, unpad
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request, Query
from ...schemas import ( from ...schemas import (
HashRequest, Base64Request, URLRequest, AESRequest, HashRequest, Base64Request, URLRequest, AESRequest,
JSONValidateRequest, JSONValidateResponse JSONValidateRequest, JSONValidateResponse
@@ -29,12 +29,13 @@ router = APIRouter(prefix="/api/v1", tags=["tools"])
@router.get("/password/generate") @router.get("/password/generate")
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def generate_password( async def generate_password(
length: int = 12, request: Request, # 添加 request 参数
upper: bool = True, length: int = Query(12, ge=4, le=64),
lower: bool = True, upper: bool = Query(True),
digits: bool = True, lower: bool = Query(True),
symbols: bool = True, digits: bool = Query(True),
count: int = 1 symbols: bool = Query(True),
count: int = Query(1, ge=1, le=10)
): ):
import random import random
import string import string
@@ -58,7 +59,11 @@ async def generate_password(
# ========== 二维码 ========== # ========== 二维码 ==========
@router.post("/qrcode/generate") @router.post("/qrcode/generate")
@rate_limit(requests=30, period=60) @rate_limit(requests=30, period=60)
async def generate_qrcode(content: str, size: int = 10): async def generate_qrcode(
request: Request, # 添加 request 参数
content: str,
size: int = 10
):
if not content: if not content:
raise HTTPException(status_code=400, detail="内容不能为空") raise HTTPException(status_code=400, detail="内容不能为空")
@@ -76,9 +81,12 @@ async def generate_qrcode(content: str, size: int = 10):
# ========== 哈希 ========== # ========== 哈希 ==========
@router.post("/crypto/hash") @router.post("/crypto/hash")
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def compute_hash(request: HashRequest): async def compute_hash(
text = request.text.encode('utf-8') request: Request, # 添加 request 参数
algo = request.algorithm.lower() req: HashRequest
):
text = req.text.encode('utf-8')
algo = req.algorithm.lower()
if algo == "md5": if algo == "md5":
result = hashlib.md5(text).hexdigest() result = hashlib.md5(text).hexdigest()
@@ -96,61 +104,69 @@ async def compute_hash(request: HashRequest):
# ========== Base64 ========== # ========== Base64 ==========
@router.post("/crypto/base64") @router.post("/crypto/base64")
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def base64_process(request: Base64Request): async def base64_process(
if request.action == "encode": request: Request, # 添加 request 参数
result = base64.b64encode(request.text.encode('utf-8')).decode('utf-8') req: Base64Request
elif request.action == "decode": ):
if req.action == "encode":
result = base64.b64encode(req.text.encode('utf-8')).decode('utf-8')
elif req.action == "decode":
try: try:
result = base64.b64decode(request.text).decode('utf-8') result = base64.b64decode(req.text).decode('utf-8')
except Exception: except Exception:
raise HTTPException(status_code=400, detail="Base64 解码失败") raise HTTPException(status_code=400, detail="Base64 解码失败")
else: else:
raise HTTPException(status_code=400, detail="无效的 action") raise HTTPException(status_code=400, detail="无效的 action")
return {"action": request.action, "result": result} return {"action": req.action, "result": result}
# ========== URL 编解码 ========== # ========== URL 编解码 ==========
@router.post("/crypto/url") @router.post("/crypto/url")
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def url_process(request: URLRequest): async def url_process(
if request.action == "encode": request: Request, # 添加 request 参数
result = urllib.parse.quote(request.text, safe='') req: URLRequest
elif request.action == "decode": ):
result = urllib.parse.unquote(request.text) if req.action == "encode":
result = urllib.parse.quote(req.text, safe='')
elif req.action == "decode":
result = urllib.parse.unquote(req.text)
else: else:
raise HTTPException(status_code=400, detail="无效的 action") raise HTTPException(status_code=400, detail="无效的 action")
return {"action": request.action, "result": result} return {"action": req.action, "result": result}
# ========== AES 加解密 ========== # ========== AES 加解密 ==========
@router.post("/crypto/aes") @router.post("/crypto/aes")
@rate_limit(requests=50, period=60) @rate_limit(requests=50, period=60)
async def aes_process(request: AESRequest): async def aes_process(
request: Request, # 添加 request 参数
req: AESRequest
):
try: try:
key = request.key.encode('utf-8') key = req.key.encode('utf-8')
mode_map = {"ECB": AES.MODE_ECB, "CBC": AES.MODE_CBC, "GCM": AES.MODE_GCM} mode_map = {"ECB": AES.MODE_ECB, "CBC": AES.MODE_CBC, "GCM": AES.MODE_GCM}
mode = mode_map.get(request.mode) mode = mode_map.get(req.mode)
if not mode: if not mode:
raise HTTPException(status_code=400, detail="不支持的 AES 模式") raise HTTPException(status_code=400, detail="不支持的 AES 模式")
# 密钥长度处理
if len(key) not in [16, 24, 32]: if len(key) not in [16, 24, 32]:
raise HTTPException(status_code=400, detail="密钥长度必须为 16/24/32 字节") raise HTTPException(status_code=400, detail="密钥长度必须为 16/24/32 字节")
if request.action == "encrypt": if req.action == "encrypt":
cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None) cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None)
if request.mode == "GCM": if req.mode == "GCM":
ciphertext = cipher.encrypt(request.text.encode('utf-8')) ciphertext = cipher.encrypt(req.text.encode('utf-8'))
result = b64.b64encode(ciphertext).decode('utf-8') result = b64.b64encode(ciphertext).decode('utf-8')
else: else:
padded = pad(request.text.encode('utf-8'), AES.block_size) padded = pad(req.text.encode('utf-8'), AES.block_size)
ciphertext = cipher.encrypt(padded) ciphertext = cipher.encrypt(padded)
result = b64.b64encode(ciphertext).decode('utf-8') result = b64.b64encode(ciphertext).decode('utf-8')
elif request.action == "decrypt": elif req.action == "decrypt":
ciphertext = b64.b64decode(request.text) ciphertext = b64.b64decode(req.text)
cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None) cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None)
if request.mode == "GCM": if req.mode == "GCM":
plaintext = cipher.decrypt(ciphertext).decode('utf-8') plaintext = cipher.decrypt(ciphertext).decode('utf-8')
result = plaintext result = plaintext
else: else:
@@ -159,7 +175,7 @@ async def aes_process(request: AESRequest):
else: else:
raise HTTPException(status_code=400, detail="无效的 action") raise HTTPException(status_code=400, detail="无效的 action")
return {"mode": request.mode, "action": request.action, "result": result} return {"mode": req.mode, "action": req.action, "result": result}
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"AES 操作失败: {str(e)}") raise HTTPException(status_code=400, detail=f"AES 操作失败: {str(e)}")
@@ -167,9 +183,12 @@ async def aes_process(request: AESRequest):
# ========== JSON 校验 ========== # ========== JSON 校验 ==========
@router.post("/json/validate", response_model=JSONValidateResponse) @router.post("/json/validate", response_model=JSONValidateResponse)
@rate_limit(requests=100, period=60) @rate_limit(requests=100, period=60)
async def validate_json(request: JSONValidateRequest): async def validate_json(
request: Request, # 添加 request 参数
req: JSONValidateRequest
):
try: try:
parsed = json.loads(request.json_string) parsed = json.loads(req.json_string)
formatted = json.dumps(parsed, indent=2, ensure_ascii=False) formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
return JSONValidateResponse(valid=True, formatted=formatted) return JSONValidateResponse(valid=True, formatted=formatted)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@@ -7,31 +7,38 @@ Author: Canglan <admin@sea-studio.top>
License: AGPL v3 License: AGPL v3
""" """
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from ...dependencies import CurrentUserDependency, DbDependency from ...dependencies import CurrentUserDependency, DbDependency
from ...models import User from ...models import User
from ...schemas import UserResponse, UserUpdateRequest from ...schemas import UserResponse, UserUpdateRequest
from ...utils.security import hash_password, verify_password from ...utils.security import hash_password, verify_password
from ...utils.logger import logger from ...utils.logger import logger
from ...middleware.rate_limit import rate_limit
router = APIRouter(prefix="/api/v1/user", tags=["user"]) router = APIRouter(prefix="/api/v1/user", tags=["user"])
@router.get("/profile", response_model=UserResponse) @router.get("/profile", response_model=UserResponse)
async def get_profile(current_user: CurrentUserDependency): @rate_limit(requests=50, period=60)
async def get_profile(
request: Request, # 添加 request 参数
current_user: CurrentUserDependency
):
"""获取当前用户信息""" """获取当前用户信息"""
return current_user return current_user
@router.put("/profile", response_model=UserResponse) @router.put("/profile", response_model=UserResponse)
@rate_limit(requests=20, period=60)
async def update_profile( async def update_profile(
request: UserUpdateRequest, request: Request, # 添加 request 参数
req: UserUpdateRequest,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,
db: DbDependency db: DbDependency
): ):
"""更新用户信息""" """更新用户信息"""
if request.username: if req.username:
current_user.username = request.username current_user.username = req.username
if request.avatar: if req.avatar:
current_user.avatar = request.avatar current_user.avatar = req.avatar
db.commit() db.commit()
db.refresh(current_user) db.refresh(current_user)
@@ -40,7 +47,9 @@ async def update_profile(
return current_user return current_user
@router.post("/change-password") @router.post("/change-password")
@rate_limit(requests=10, period=60)
async def change_password( async def change_password(
request: Request, # 添加 request 参数
old_password: str, old_password: str,
new_password: str, new_password: str,
current_user: CurrentUserDependency, current_user: CurrentUserDependency,