修复限流器问题
This commit is contained in:
@@ -11,19 +11,24 @@ from fastapi import Request, HTTPException
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.extension import RateLimitDecorator
|
||||
from ..config import settings
|
||||
|
||||
# 创建限流器
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
def setup_rate_limit(app):
|
||||
"""配置限流"""
|
||||
if settings.RATE_LIMIT_ENABLED:
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
def rate_limit(requests: int = None, period: int = None):
|
||||
"""限流装饰器工厂 - 修复版本"""
|
||||
if not settings.RATE_LIMIT_ENABLED:
|
||||
return lambda func: func
|
||||
|
||||
|
||||
req = requests or settings.RATE_LIMIT_REQUESTS
|
||||
per = period or settings.RATE_LIMIT_PERIOD
|
||||
|
||||
return limiter.limit(f"{req}/{per} seconds")
|
||||
@@ -8,7 +8,7 @@ License: AGPL v3
|
||||
"""
|
||||
|
||||
import re
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi import APIRouter, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from ...dependencies import DbDependency
|
||||
from ...models import User
|
||||
@@ -31,29 +31,25 @@ def is_email(account: str) -> bool:
|
||||
return re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', account) is not None
|
||||
|
||||
@router.post("/send-code")
|
||||
@rate_limit(requests=1, period=60) # 同一账号每分钟1次
|
||||
async def send_verify_code(request: SendCodeRequest, db: DbDependency):
|
||||
@rate_limit(requests=1, period=60)
|
||||
async def send_verify_code(
|
||||
request: Request, # 添加 request 参数
|
||||
req: SendCodeRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""发送验证码"""
|
||||
account = request.account
|
||||
account_type = request.type
|
||||
account = req.account
|
||||
account_type = req.type
|
||||
|
||||
# 验证账号格式
|
||||
if account_type == "phone" and not is_phone(account):
|
||||
raise HTTPException(status_code=400, detail="手机号格式错误")
|
||||
if account_type == "email" and not is_email(account):
|
||||
raise HTTPException(status_code=400, detail="邮箱格式错误")
|
||||
|
||||
# 检查是否已存在(注册时使用)
|
||||
# 这里只负责发送,不检查是否存在
|
||||
|
||||
# 生成验证码
|
||||
code = generate_verify_code()
|
||||
|
||||
# 存储到 Redis(5分钟有效)
|
||||
redis_key = f"verify:code:{account}"
|
||||
redis_client.set(redis_key, code, expire=300)
|
||||
|
||||
# 发送验证码
|
||||
success = False
|
||||
if account_type == "phone":
|
||||
success = send_sms(account, code)
|
||||
@@ -68,19 +64,21 @@ async def send_verify_code(request: SendCodeRequest, db: DbDependency):
|
||||
|
||||
@router.post("/register")
|
||||
@rate_limit(requests=5, period=60)
|
||||
async def register(request: RegisterRequest, db: DbDependency):
|
||||
"""注册(手机号/邮箱 + 验证码)"""
|
||||
account = request.account
|
||||
code = request.code
|
||||
password = request.password
|
||||
async def register(
|
||||
request: Request, # 添加 request 参数
|
||||
req: RegisterRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""注册"""
|
||||
account = req.account
|
||||
code = req.code
|
||||
password = req.password
|
||||
|
||||
# 验证验证码
|
||||
redis_key = f"verify:code:{account}"
|
||||
saved_code = redis_client.get(redis_key)
|
||||
if not saved_code or saved_code != code:
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
|
||||
# 检查用户是否已存在
|
||||
if is_phone(account):
|
||||
existing = db.query(User).filter(User.phone == account).first()
|
||||
if existing:
|
||||
@@ -98,10 +96,7 @@ async def register(request: RegisterRequest, db: DbDependency):
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# 删除已使用的验证码
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# 生成 token
|
||||
token = create_access_token({"sub": str(user.id)})
|
||||
|
||||
logger.info(f"新用户注册: {account}")
|
||||
@@ -113,13 +108,16 @@ async def register(request: RegisterRequest, db: DbDependency):
|
||||
|
||||
@router.post("/login")
|
||||
@rate_limit(requests=10, period=60)
|
||||
async def login(request: LoginRequest, db: DbDependency):
|
||||
"""登录(支持验证码登录或密码登录)"""
|
||||
account = request.account
|
||||
password = request.password
|
||||
code = request.code
|
||||
async def login(
|
||||
request: Request, # 添加 request 参数
|
||||
req: LoginRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""登录"""
|
||||
account = req.account
|
||||
password = req.password
|
||||
code = req.code
|
||||
|
||||
# 查找用户
|
||||
user = None
|
||||
if is_phone(account):
|
||||
user = db.query(User).filter(User.phone == account).first()
|
||||
@@ -131,22 +129,18 @@ async def login(request: LoginRequest, db: DbDependency):
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="用户不存在")
|
||||
|
||||
# 验证码登录
|
||||
if code:
|
||||
redis_key = f"verify:code:{account}"
|
||||
saved_code = redis_client.get(redis_key)
|
||||
if not saved_code or saved_code != code:
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# 密码登录
|
||||
elif password:
|
||||
if not verify_password(password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="密码错误")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="请提供验证码或密码")
|
||||
|
||||
# 更新最后登录时间
|
||||
from datetime import datetime
|
||||
user.last_login = datetime.now()
|
||||
db.commit()
|
||||
@@ -161,7 +155,6 @@ async def login(request: LoginRequest, db: DbDependency):
|
||||
}
|
||||
|
||||
@router.post("/wechat/login")
|
||||
async def wechat_login():
|
||||
async def wechat_login(request: Request): # 添加 request 参数
|
||||
"""微信公众号登录(预留)"""
|
||||
# TODO: 实现公众号 OAuth2 登录
|
||||
raise HTTPException(status_code=501, detail="功能开发中")
|
||||
@@ -7,7 +7,7 @@ Author: Canglan <admin@sea-studio.top>
|
||||
License: AGPL v3
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from typing import Optional, List
|
||||
from ...dependencies import CurrentUserDependency, DbDependency
|
||||
from ...models import Note
|
||||
@@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/notes", tags=["notes"])
|
||||
@router.get("/", response_model=List[NoteResponse])
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def get_notes(
|
||||
request: Request, # 添加 request 参数
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency,
|
||||
skip: int = Query(0, ge=0),
|
||||
@@ -33,6 +34,7 @@ async def get_notes(
|
||||
@router.post("/", response_model=NoteResponse, status_code=201)
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def create_note(
|
||||
request: Request, # 添加 request 参数
|
||||
data: NoteCreate,
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency
|
||||
@@ -46,6 +48,7 @@ async def create_note(
|
||||
@router.put("/{note_id}", response_model=NoteResponse)
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def update_note(
|
||||
request: Request, # 添加 request 参数
|
||||
note_id: int,
|
||||
data: NoteUpdate,
|
||||
current_user: CurrentUserDependency,
|
||||
@@ -66,6 +69,7 @@ async def update_note(
|
||||
@router.delete("/{note_id}")
|
||||
@rate_limit(requests=30, period=60)
|
||||
async def delete_note(
|
||||
request: Request, # 添加 request 参数
|
||||
note_id: int,
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency
|
||||
|
||||
@@ -8,7 +8,7 @@ License: AGPL v3
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from ...utils.redis_client import redis_client
|
||||
from ...models import ToolStatsTotal
|
||||
from ...dependencies import DbDependency
|
||||
@@ -16,7 +16,6 @@ from ...middleware.rate_limit import rate_limit
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["stats"])
|
||||
|
||||
# 预定义工具名称(对应前端页面)
|
||||
TOOL_NAMES = [
|
||||
"todos", "notes", "password", "qrcode",
|
||||
"crypto_hash", "crypto_base64", "crypto_url", "crypto_aes", "json"
|
||||
@@ -24,7 +23,11 @@ TOOL_NAMES = [
|
||||
|
||||
@router.post("/tool/usage")
|
||||
@rate_limit(requests=20, period=60)
|
||||
async def record_usage(tool_name: str, db: DbDependency):
|
||||
async def record_usage(
|
||||
request: Request, # 添加 request 参数
|
||||
tool_name: str,
|
||||
db: DbDependency
|
||||
):
|
||||
"""记录页面访问次数(热度)"""
|
||||
if tool_name not in TOOL_NAMES:
|
||||
raise HTTPException(status_code=400, detail="无效的工具名")
|
||||
@@ -33,15 +36,10 @@ async def record_usage(tool_name: str, db: DbDependency):
|
||||
today_key = f"tool:stats:today:{tool_name}:{today}"
|
||||
total_key = f"tool:stats:total:{tool_name}"
|
||||
|
||||
# 增加今日计数(设置48小时过期)
|
||||
today_count = redis_client.incr(today_key)
|
||||
redis_client.expire(today_key, 48 * 3600)
|
||||
|
||||
# 增加总计数
|
||||
total_count = redis_client.incr(total_key)
|
||||
|
||||
# 异步更新 MySQL(可选,这里简单处理)
|
||||
# 实际可改为定时任务同步,此处为简化,直接更新
|
||||
stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first()
|
||||
if stats:
|
||||
stats.total_count = total_count
|
||||
@@ -54,7 +52,10 @@ async def record_usage(tool_name: str, db: DbDependency):
|
||||
|
||||
@router.get("/tool/stats")
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def get_stats(db: DbDependency):
|
||||
async def get_stats(
|
||||
request: Request, # 添加 request 参数
|
||||
db: DbDependency
|
||||
):
|
||||
"""获取所有工具的今日/总访问次数"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
result = {}
|
||||
@@ -67,7 +68,6 @@ async def get_stats(db: DbDependency):
|
||||
total_count = redis_client.get(total_key)
|
||||
|
||||
if total_count is None:
|
||||
# 从 MySQL 读取
|
||||
stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first()
|
||||
total_count = stats.total_count if stats else 0
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ Author: Canglan <admin@sea-studio.top>
|
||||
License: AGPL v3
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from typing import Optional, List
|
||||
from ...dependencies import CurrentUserDependency, DbDependency
|
||||
from ...models import Todo
|
||||
@@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/todos", tags=["todos"])
|
||||
@router.get("/", response_model=List[TodoResponse])
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def get_todos(
|
||||
request: Request, # 添加 request 参数
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency,
|
||||
skip: int = Query(0, ge=0),
|
||||
@@ -38,6 +39,7 @@ async def get_todos(
|
||||
@router.post("/", response_model=TodoResponse, status_code=201)
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def create_todo(
|
||||
request: Request, # 添加 request 参数
|
||||
data: TodoCreate,
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency
|
||||
@@ -51,6 +53,7 @@ async def create_todo(
|
||||
@router.put("/{todo_id}", response_model=TodoResponse)
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def update_todo(
|
||||
request: Request, # 添加 request 参数
|
||||
todo_id: int,
|
||||
data: TodoUpdate,
|
||||
current_user: CurrentUserDependency,
|
||||
@@ -71,6 +74,7 @@ async def update_todo(
|
||||
@router.delete("/{todo_id}")
|
||||
@rate_limit(requests=30, period=60)
|
||||
async def delete_todo(
|
||||
request: Request, # 添加 request 参数
|
||||
todo_id: int,
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency
|
||||
|
||||
@@ -16,7 +16,7 @@ from io import BytesIO
|
||||
import base64 as b64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request, Query
|
||||
from ...schemas import (
|
||||
HashRequest, Base64Request, URLRequest, AESRequest,
|
||||
JSONValidateRequest, JSONValidateResponse
|
||||
@@ -29,12 +29,13 @@ router = APIRouter(prefix="/api/v1", tags=["tools"])
|
||||
@router.get("/password/generate")
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def generate_password(
|
||||
length: int = 12,
|
||||
upper: bool = True,
|
||||
lower: bool = True,
|
||||
digits: bool = True,
|
||||
symbols: bool = True,
|
||||
count: int = 1
|
||||
request: Request, # 添加 request 参数
|
||||
length: int = Query(12, ge=4, le=64),
|
||||
upper: bool = Query(True),
|
||||
lower: bool = Query(True),
|
||||
digits: bool = Query(True),
|
||||
symbols: bool = Query(True),
|
||||
count: int = Query(1, ge=1, le=10)
|
||||
):
|
||||
import random
|
||||
import string
|
||||
@@ -58,7 +59,11 @@ async def generate_password(
|
||||
# ========== 二维码 ==========
|
||||
@router.post("/qrcode/generate")
|
||||
@rate_limit(requests=30, period=60)
|
||||
async def generate_qrcode(content: str, size: int = 10):
|
||||
async def generate_qrcode(
|
||||
request: Request, # 添加 request 参数
|
||||
content: str,
|
||||
size: int = 10
|
||||
):
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="内容不能为空")
|
||||
|
||||
@@ -76,9 +81,12 @@ async def generate_qrcode(content: str, size: int = 10):
|
||||
# ========== 哈希 ==========
|
||||
@router.post("/crypto/hash")
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def compute_hash(request: HashRequest):
|
||||
text = request.text.encode('utf-8')
|
||||
algo = request.algorithm.lower()
|
||||
async def compute_hash(
|
||||
request: Request, # 添加 request 参数
|
||||
req: HashRequest
|
||||
):
|
||||
text = req.text.encode('utf-8')
|
||||
algo = req.algorithm.lower()
|
||||
|
||||
if algo == "md5":
|
||||
result = hashlib.md5(text).hexdigest()
|
||||
@@ -96,61 +104,69 @@ async def compute_hash(request: HashRequest):
|
||||
# ========== Base64 ==========
|
||||
@router.post("/crypto/base64")
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def base64_process(request: Base64Request):
|
||||
if request.action == "encode":
|
||||
result = base64.b64encode(request.text.encode('utf-8')).decode('utf-8')
|
||||
elif request.action == "decode":
|
||||
async def base64_process(
|
||||
request: Request, # 添加 request 参数
|
||||
req: Base64Request
|
||||
):
|
||||
if req.action == "encode":
|
||||
result = base64.b64encode(req.text.encode('utf-8')).decode('utf-8')
|
||||
elif req.action == "decode":
|
||||
try:
|
||||
result = base64.b64decode(request.text).decode('utf-8')
|
||||
result = base64.b64decode(req.text).decode('utf-8')
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Base64 解码失败")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="无效的 action")
|
||||
|
||||
return {"action": request.action, "result": result}
|
||||
return {"action": req.action, "result": result}
|
||||
|
||||
# ========== URL 编解码 ==========
|
||||
@router.post("/crypto/url")
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def url_process(request: URLRequest):
|
||||
if request.action == "encode":
|
||||
result = urllib.parse.quote(request.text, safe='')
|
||||
elif request.action == "decode":
|
||||
result = urllib.parse.unquote(request.text)
|
||||
async def url_process(
|
||||
request: Request, # 添加 request 参数
|
||||
req: URLRequest
|
||||
):
|
||||
if req.action == "encode":
|
||||
result = urllib.parse.quote(req.text, safe='')
|
||||
elif req.action == "decode":
|
||||
result = urllib.parse.unquote(req.text)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="无效的 action")
|
||||
|
||||
return {"action": request.action, "result": result}
|
||||
return {"action": req.action, "result": result}
|
||||
|
||||
# ========== AES 加解密 ==========
|
||||
@router.post("/crypto/aes")
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def aes_process(request: AESRequest):
|
||||
async def aes_process(
|
||||
request: Request, # 添加 request 参数
|
||||
req: AESRequest
|
||||
):
|
||||
try:
|
||||
key = request.key.encode('utf-8')
|
||||
key = req.key.encode('utf-8')
|
||||
mode_map = {"ECB": AES.MODE_ECB, "CBC": AES.MODE_CBC, "GCM": AES.MODE_GCM}
|
||||
mode = mode_map.get(request.mode)
|
||||
mode = mode_map.get(req.mode)
|
||||
|
||||
if not mode:
|
||||
raise HTTPException(status_code=400, detail="不支持的 AES 模式")
|
||||
|
||||
# 密钥长度处理
|
||||
if len(key) not in [16, 24, 32]:
|
||||
raise HTTPException(status_code=400, detail="密钥长度必须为 16/24/32 字节")
|
||||
|
||||
if request.action == "encrypt":
|
||||
cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None)
|
||||
if request.mode == "GCM":
|
||||
ciphertext = cipher.encrypt(request.text.encode('utf-8'))
|
||||
if req.action == "encrypt":
|
||||
cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None)
|
||||
if req.mode == "GCM":
|
||||
ciphertext = cipher.encrypt(req.text.encode('utf-8'))
|
||||
result = b64.b64encode(ciphertext).decode('utf-8')
|
||||
else:
|
||||
padded = pad(request.text.encode('utf-8'), AES.block_size)
|
||||
padded = pad(req.text.encode('utf-8'), AES.block_size)
|
||||
ciphertext = cipher.encrypt(padded)
|
||||
result = b64.b64encode(ciphertext).decode('utf-8')
|
||||
elif request.action == "decrypt":
|
||||
ciphertext = b64.b64decode(request.text)
|
||||
cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None)
|
||||
if request.mode == "GCM":
|
||||
elif req.action == "decrypt":
|
||||
ciphertext = b64.b64decode(req.text)
|
||||
cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None)
|
||||
if req.mode == "GCM":
|
||||
plaintext = cipher.decrypt(ciphertext).decode('utf-8')
|
||||
result = plaintext
|
||||
else:
|
||||
@@ -159,7 +175,7 @@ async def aes_process(request: AESRequest):
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="无效的 action")
|
||||
|
||||
return {"mode": request.mode, "action": request.action, "result": result}
|
||||
return {"mode": req.mode, "action": req.action, "result": result}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"AES 操作失败: {str(e)}")
|
||||
@@ -167,9 +183,12 @@ async def aes_process(request: AESRequest):
|
||||
# ========== JSON 校验 ==========
|
||||
@router.post("/json/validate", response_model=JSONValidateResponse)
|
||||
@rate_limit(requests=100, period=60)
|
||||
async def validate_json(request: JSONValidateRequest):
|
||||
async def validate_json(
|
||||
request: Request, # 添加 request 参数
|
||||
req: JSONValidateRequest
|
||||
):
|
||||
try:
|
||||
parsed = json.loads(request.json_string)
|
||||
parsed = json.loads(req.json_string)
|
||||
formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
|
||||
return JSONValidateResponse(valid=True, formatted=formatted)
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
@@ -7,31 +7,38 @@ Author: Canglan <admin@sea-studio.top>
|
||||
License: AGPL v3
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from ...dependencies import CurrentUserDependency, DbDependency
|
||||
from ...models import User
|
||||
from ...schemas import UserResponse, UserUpdateRequest
|
||||
from ...utils.security import hash_password, verify_password
|
||||
from ...utils.logger import logger
|
||||
from ...middleware.rate_limit import rate_limit
|
||||
|
||||
router = APIRouter(prefix="/api/v1/user", tags=["user"])
|
||||
|
||||
@router.get("/profile", response_model=UserResponse)
|
||||
async def get_profile(current_user: CurrentUserDependency):
|
||||
@rate_limit(requests=50, period=60)
|
||||
async def get_profile(
|
||||
request: Request, # 添加 request 参数
|
||||
current_user: CurrentUserDependency
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
return current_user
|
||||
|
||||
@router.put("/profile", response_model=UserResponse)
|
||||
@rate_limit(requests=20, period=60)
|
||||
async def update_profile(
|
||||
request: UserUpdateRequest,
|
||||
request: Request, # 添加 request 参数
|
||||
req: UserUpdateRequest,
|
||||
current_user: CurrentUserDependency,
|
||||
db: DbDependency
|
||||
):
|
||||
"""更新用户信息"""
|
||||
if request.username:
|
||||
current_user.username = request.username
|
||||
if request.avatar:
|
||||
current_user.avatar = request.avatar
|
||||
if req.username:
|
||||
current_user.username = req.username
|
||||
if req.avatar:
|
||||
current_user.avatar = req.avatar
|
||||
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
@@ -40,7 +47,9 @@ async def update_profile(
|
||||
return current_user
|
||||
|
||||
@router.post("/change-password")
|
||||
@rate_limit(requests=10, period=60)
|
||||
async def change_password(
|
||||
request: Request, # 添加 request 参数
|
||||
old_password: str,
|
||||
new_password: str,
|
||||
current_user: CurrentUserDependency,
|
||||
|
||||
Reference in New Issue
Block a user