diff --git a/.env.example b/.env.example index e71ea78..7a85ed7 100644 --- a/.env.example +++ b/.env.example @@ -23,7 +23,7 @@ RATE_LIMIT_REQUESTS=100 RATE_LIMIT_PERIOD=60 # CORS(前端域名,多个域名用逗号分隔) -ALLOWED_ORIGINS=https://your-domain.com +ALLOWED_ORIGINS=["https://your-domain.com", "http://localhost:3000"] # 短信配置(阿里云接入) ALIYUN_SMS_ACCESS_KEY_ID= diff --git a/backend/middleware/rate_limit.py b/backend/middleware/rate_limit.py index 6e3a18b..0c79bba 100644 --- a/backend/middleware/rate_limit.py +++ b/backend/middleware/rate_limit.py @@ -11,19 +11,24 @@ from fastapi import Request, HTTPException from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded +from slowapi.extension import RateLimitDecorator from ..config import settings +# 创建限流器 limiter = Limiter(key_func=get_remote_address) def setup_rate_limit(app): + """配置限流""" if settings.RATE_LIMIT_ENABLED: app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) def rate_limit(requests: int = None, period: int = None): + """限流装饰器工厂 - 修复版本""" if not settings.RATE_LIMIT_ENABLED: return lambda func: func - + req = requests or settings.RATE_LIMIT_REQUESTS per = period or settings.RATE_LIMIT_PERIOD + return limiter.limit(f"{req}/{per} seconds") \ No newline at end of file diff --git a/backend/routers/v1/auth.py b/backend/routers/v1/auth.py index 498dcad..304684a 100644 --- a/backend/routers/v1/auth.py +++ b/backend/routers/v1/auth.py @@ -8,7 +8,7 @@ License: AGPL v3 """ import re -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, status, Request from sqlalchemy.orm import Session from ...dependencies import DbDependency from ...models import User @@ -31,29 +31,25 @@ def is_email(account: str) -> bool: return re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', account) is not None @router.post("/send-code") -@rate_limit(requests=1, period=60) # 同一账号每分钟1次 -async def send_verify_code(request: SendCodeRequest, db: DbDependency): +@rate_limit(requests=1, period=60) +async def send_verify_code( + request: Request, # 添加 request 参数 + req: SendCodeRequest, + db: DbDependency +): """发送验证码""" - account = request.account - account_type = request.type + account = req.account + account_type = req.type - # 验证账号格式 if account_type == "phone" and not is_phone(account): raise HTTPException(status_code=400, detail="手机号格式错误") if account_type == "email" and not is_email(account): raise HTTPException(status_code=400, detail="邮箱格式错误") - # 检查是否已存在(注册时使用) - # 这里只负责发送,不检查是否存在 - - # 生成验证码 code = generate_verify_code() - - # 存储到 Redis(5分钟有效) redis_key = f"verify:code:{account}" redis_client.set(redis_key, code, expire=300) - # 发送验证码 success = False if account_type == "phone": success = send_sms(account, code) @@ -68,19 +64,21 @@ async def send_verify_code(request: SendCodeRequest, db: DbDependency): @router.post("/register") @rate_limit(requests=5, period=60) -async def register(request: RegisterRequest, db: DbDependency): - """注册(手机号/邮箱 + 验证码)""" - account = request.account - code = request.code - password = request.password +async def register( + request: Request, # 添加 request 参数 + req: RegisterRequest, + db: DbDependency +): + """注册""" + account = req.account + code = req.code + password = req.password - # 验证验证码 redis_key = f"verify:code:{account}" saved_code = redis_client.get(redis_key) if not saved_code or saved_code != code: raise HTTPException(status_code=400, detail="验证码错误或已过期") - # 检查用户是否已存在 if is_phone(account): existing = db.query(User).filter(User.phone == account).first() if existing: @@ -98,10 +96,7 @@ async def register(request: RegisterRequest, db: DbDependency): db.commit() db.refresh(user) - # 删除已使用的验证码 redis_client.delete(redis_key) - - # 生成 token token = create_access_token({"sub": str(user.id)}) logger.info(f"新用户注册: {account}") @@ -113,13 +108,16 @@ async def register(request: RegisterRequest, db: DbDependency): @router.post("/login") @rate_limit(requests=10, period=60) -async def login(request: LoginRequest, db: DbDependency): - """登录(支持验证码登录或密码登录)""" - account = request.account - password = request.password - code = request.code +async def login( + request: Request, # 添加 request 参数 + req: LoginRequest, + db: DbDependency +): + """登录""" + account = req.account + password = req.password + code = req.code - # 查找用户 user = None if is_phone(account): user = db.query(User).filter(User.phone == account).first() @@ -131,22 +129,18 @@ async def login(request: LoginRequest, db: DbDependency): if not user: raise HTTPException(status_code=400, detail="用户不存在") - # 验证码登录 if code: redis_key = f"verify:code:{account}" saved_code = redis_client.get(redis_key) if not saved_code or saved_code != code: raise HTTPException(status_code=400, detail="验证码错误或已过期") redis_client.delete(redis_key) - - # 密码登录 elif password: if not verify_password(password, user.password_hash): raise HTTPException(status_code=400, detail="密码错误") else: raise HTTPException(status_code=400, detail="请提供验证码或密码") - # 更新最后登录时间 from datetime import datetime user.last_login = datetime.now() db.commit() @@ -161,7 +155,6 @@ async def login(request: LoginRequest, db: DbDependency): } @router.post("/wechat/login") -async def wechat_login(): +async def wechat_login(request: Request): # 添加 request 参数 """微信公众号登录(预留)""" - # TODO: 实现公众号 OAuth2 登录 raise HTTPException(status_code=501, detail="功能开发中") \ No newline at end of file diff --git a/backend/routers/v1/notes.py b/backend/routers/v1/notes.py index 06330c0..4f84e94 100644 --- a/backend/routers/v1/notes.py +++ b/backend/routers/v1/notes.py @@ -7,7 +7,7 @@ Author: Canglan License: AGPL v3 """ -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Request from typing import Optional, List from ...dependencies import CurrentUserDependency, DbDependency from ...models import Note @@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/notes", tags=["notes"]) @router.get("/", response_model=List[NoteResponse]) @rate_limit(requests=100, period=60) async def get_notes( + request: Request, # 添加 request 参数 current_user: CurrentUserDependency, db: DbDependency, skip: int = Query(0, ge=0), @@ -33,6 +34,7 @@ async def get_notes( @router.post("/", response_model=NoteResponse, status_code=201) @rate_limit(requests=50, period=60) async def create_note( + request: Request, # 添加 request 参数 data: NoteCreate, current_user: CurrentUserDependency, db: DbDependency @@ -46,6 +48,7 @@ async def create_note( @router.put("/{note_id}", response_model=NoteResponse) @rate_limit(requests=50, period=60) async def update_note( + request: Request, # 添加 request 参数 note_id: int, data: NoteUpdate, current_user: CurrentUserDependency, @@ -66,6 +69,7 @@ async def update_note( @router.delete("/{note_id}") @rate_limit(requests=30, period=60) async def delete_note( + request: Request, # 添加 request 参数 note_id: int, current_user: CurrentUserDependency, db: DbDependency diff --git a/backend/routers/v1/stats.py b/backend/routers/v1/stats.py index 4c88b02..7f3c31a 100644 --- a/backend/routers/v1/stats.py +++ b/backend/routers/v1/stats.py @@ -8,7 +8,7 @@ License: AGPL v3 """ from datetime import datetime -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException, Request from ...utils.redis_client import redis_client from ...models import ToolStatsTotal from ...dependencies import DbDependency @@ -16,7 +16,6 @@ from ...middleware.rate_limit import rate_limit router = APIRouter(prefix="/api/v1", tags=["stats"]) -# 预定义工具名称(对应前端页面) TOOL_NAMES = [ "todos", "notes", "password", "qrcode", "crypto_hash", "crypto_base64", "crypto_url", "crypto_aes", "json" @@ -24,7 +23,11 @@ TOOL_NAMES = [ @router.post("/tool/usage") @rate_limit(requests=20, period=60) -async def record_usage(tool_name: str, db: DbDependency): +async def record_usage( + request: Request, # 添加 request 参数 + tool_name: str, + db: DbDependency +): """记录页面访问次数(热度)""" if tool_name not in TOOL_NAMES: raise HTTPException(status_code=400, detail="无效的工具名") @@ -33,15 +36,10 @@ async def record_usage(tool_name: str, db: DbDependency): today_key = f"tool:stats:today:{tool_name}:{today}" total_key = f"tool:stats:total:{tool_name}" - # 增加今日计数(设置48小时过期) today_count = redis_client.incr(today_key) redis_client.expire(today_key, 48 * 3600) - - # 增加总计数 total_count = redis_client.incr(total_key) - # 异步更新 MySQL(可选,这里简单处理) - # 实际可改为定时任务同步,此处为简化,直接更新 stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first() if stats: stats.total_count = total_count @@ -54,7 +52,10 @@ async def record_usage(tool_name: str, db: DbDependency): @router.get("/tool/stats") @rate_limit(requests=100, period=60) -async def get_stats(db: DbDependency): +async def get_stats( + request: Request, # 添加 request 参数 + db: DbDependency +): """获取所有工具的今日/总访问次数""" today = datetime.now().strftime("%Y-%m-%d") result = {} @@ -67,7 +68,6 @@ async def get_stats(db: DbDependency): total_count = redis_client.get(total_key) if total_count is None: - # 从 MySQL 读取 stats = db.query(ToolStatsTotal).filter(ToolStatsTotal.tool_name == tool_name).first() total_count = stats.total_count if stats else 0 else: diff --git a/backend/routers/v1/todos.py b/backend/routers/v1/todos.py index d72ef87..f7baea2 100644 --- a/backend/routers/v1/todos.py +++ b/backend/routers/v1/todos.py @@ -7,7 +7,7 @@ Author: Canglan License: AGPL v3 """ -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Request from typing import Optional, List from ...dependencies import CurrentUserDependency, DbDependency from ...models import Todo @@ -19,6 +19,7 @@ router = APIRouter(prefix="/api/v1/todos", tags=["todos"]) @router.get("/", response_model=List[TodoResponse]) @rate_limit(requests=100, period=60) async def get_todos( + request: Request, # 添加 request 参数 current_user: CurrentUserDependency, db: DbDependency, skip: int = Query(0, ge=0), @@ -38,6 +39,7 @@ async def get_todos( @router.post("/", response_model=TodoResponse, status_code=201) @rate_limit(requests=50, period=60) async def create_todo( + request: Request, # 添加 request 参数 data: TodoCreate, current_user: CurrentUserDependency, db: DbDependency @@ -51,6 +53,7 @@ async def create_todo( @router.put("/{todo_id}", response_model=TodoResponse) @rate_limit(requests=50, period=60) async def update_todo( + request: Request, # 添加 request 参数 todo_id: int, data: TodoUpdate, current_user: CurrentUserDependency, @@ -71,6 +74,7 @@ async def update_todo( @router.delete("/{todo_id}") @rate_limit(requests=30, period=60) async def delete_todo( + request: Request, # 添加 request 参数 todo_id: int, current_user: CurrentUserDependency, db: DbDependency diff --git a/backend/routers/v1/tools.py b/backend/routers/v1/tools.py index 63797bc..5434bf4 100644 --- a/backend/routers/v1/tools.py +++ b/backend/routers/v1/tools.py @@ -16,7 +16,7 @@ from io import BytesIO import base64 as b64 from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request, Query from ...schemas import ( HashRequest, Base64Request, URLRequest, AESRequest, JSONValidateRequest, JSONValidateResponse @@ -29,12 +29,13 @@ router = APIRouter(prefix="/api/v1", tags=["tools"]) @router.get("/password/generate") @rate_limit(requests=50, period=60) async def generate_password( - length: int = 12, - upper: bool = True, - lower: bool = True, - digits: bool = True, - symbols: bool = True, - count: int = 1 + request: Request, # 添加 request 参数 + length: int = Query(12, ge=4, le=64), + upper: bool = Query(True), + lower: bool = Query(True), + digits: bool = Query(True), + symbols: bool = Query(True), + count: int = Query(1, ge=1, le=10) ): import random import string @@ -58,7 +59,11 @@ async def generate_password( # ========== 二维码 ========== @router.post("/qrcode/generate") @rate_limit(requests=30, period=60) -async def generate_qrcode(content: str, size: int = 10): +async def generate_qrcode( + request: Request, # 添加 request 参数 + content: str, + size: int = 10 +): if not content: raise HTTPException(status_code=400, detail="内容不能为空") @@ -76,9 +81,12 @@ async def generate_qrcode(content: str, size: int = 10): # ========== 哈希 ========== @router.post("/crypto/hash") @rate_limit(requests=100, period=60) -async def compute_hash(request: HashRequest): - text = request.text.encode('utf-8') - algo = request.algorithm.lower() +async def compute_hash( + request: Request, # 添加 request 参数 + req: HashRequest +): + text = req.text.encode('utf-8') + algo = req.algorithm.lower() if algo == "md5": result = hashlib.md5(text).hexdigest() @@ -96,61 +104,69 @@ async def compute_hash(request: HashRequest): # ========== Base64 ========== @router.post("/crypto/base64") @rate_limit(requests=100, period=60) -async def base64_process(request: Base64Request): - if request.action == "encode": - result = base64.b64encode(request.text.encode('utf-8')).decode('utf-8') - elif request.action == "decode": +async def base64_process( + request: Request, # 添加 request 参数 + req: Base64Request +): + if req.action == "encode": + result = base64.b64encode(req.text.encode('utf-8')).decode('utf-8') + elif req.action == "decode": try: - result = base64.b64decode(request.text).decode('utf-8') + result = base64.b64decode(req.text).decode('utf-8') except Exception: raise HTTPException(status_code=400, detail="Base64 解码失败") else: raise HTTPException(status_code=400, detail="无效的 action") - return {"action": request.action, "result": result} + return {"action": req.action, "result": result} # ========== URL 编解码 ========== @router.post("/crypto/url") @rate_limit(requests=100, period=60) -async def url_process(request: URLRequest): - if request.action == "encode": - result = urllib.parse.quote(request.text, safe='') - elif request.action == "decode": - result = urllib.parse.unquote(request.text) +async def url_process( + request: Request, # 添加 request 参数 + req: URLRequest +): + if req.action == "encode": + result = urllib.parse.quote(req.text, safe='') + elif req.action == "decode": + result = urllib.parse.unquote(req.text) else: raise HTTPException(status_code=400, detail="无效的 action") - return {"action": request.action, "result": result} + return {"action": req.action, "result": result} # ========== AES 加解密 ========== @router.post("/crypto/aes") @rate_limit(requests=50, period=60) -async def aes_process(request: AESRequest): +async def aes_process( + request: Request, # 添加 request 参数 + req: AESRequest +): try: - key = request.key.encode('utf-8') + key = req.key.encode('utf-8') mode_map = {"ECB": AES.MODE_ECB, "CBC": AES.MODE_CBC, "GCM": AES.MODE_GCM} - mode = mode_map.get(request.mode) + mode = mode_map.get(req.mode) if not mode: raise HTTPException(status_code=400, detail="不支持的 AES 模式") - # 密钥长度处理 if len(key) not in [16, 24, 32]: raise HTTPException(status_code=400, detail="密钥长度必须为 16/24/32 字节") - if request.action == "encrypt": - cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None) - if request.mode == "GCM": - ciphertext = cipher.encrypt(request.text.encode('utf-8')) + if req.action == "encrypt": + cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None) + if req.mode == "GCM": + ciphertext = cipher.encrypt(req.text.encode('utf-8')) result = b64.b64encode(ciphertext).decode('utf-8') else: - padded = pad(request.text.encode('utf-8'), AES.block_size) + padded = pad(req.text.encode('utf-8'), AES.block_size) ciphertext = cipher.encrypt(padded) result = b64.b64encode(ciphertext).decode('utf-8') - elif request.action == "decrypt": - ciphertext = b64.b64decode(request.text) - cipher = AES.new(key, mode, iv=request.iv.encode('utf-8') if request.iv else None) - if request.mode == "GCM": + elif req.action == "decrypt": + ciphertext = b64.b64decode(req.text) + cipher = AES.new(key, mode, iv=req.iv.encode('utf-8') if req.iv else None) + if req.mode == "GCM": plaintext = cipher.decrypt(ciphertext).decode('utf-8') result = plaintext else: @@ -159,7 +175,7 @@ async def aes_process(request: AESRequest): else: raise HTTPException(status_code=400, detail="无效的 action") - return {"mode": request.mode, "action": request.action, "result": result} + return {"mode": req.mode, "action": req.action, "result": result} except Exception as e: raise HTTPException(status_code=400, detail=f"AES 操作失败: {str(e)}") @@ -167,9 +183,12 @@ async def aes_process(request: AESRequest): # ========== JSON 校验 ========== @router.post("/json/validate", response_model=JSONValidateResponse) @rate_limit(requests=100, period=60) -async def validate_json(request: JSONValidateRequest): +async def validate_json( + request: Request, # 添加 request 参数 + req: JSONValidateRequest +): try: - parsed = json.loads(request.json_string) + parsed = json.loads(req.json_string) formatted = json.dumps(parsed, indent=2, ensure_ascii=False) return JSONValidateResponse(valid=True, formatted=formatted) except json.JSONDecodeError as e: diff --git a/backend/routers/v1/user.py b/backend/routers/v1/user.py index 8ed5583..9292995 100644 --- a/backend/routers/v1/user.py +++ b/backend/routers/v1/user.py @@ -7,31 +7,38 @@ Author: Canglan License: AGPL v3 """ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from ...dependencies import CurrentUserDependency, DbDependency from ...models import User from ...schemas import UserResponse, UserUpdateRequest from ...utils.security import hash_password, verify_password from ...utils.logger import logger +from ...middleware.rate_limit import rate_limit router = APIRouter(prefix="/api/v1/user", tags=["user"]) @router.get("/profile", response_model=UserResponse) -async def get_profile(current_user: CurrentUserDependency): +@rate_limit(requests=50, period=60) +async def get_profile( + request: Request, # 添加 request 参数 + current_user: CurrentUserDependency +): """获取当前用户信息""" return current_user @router.put("/profile", response_model=UserResponse) +@rate_limit(requests=20, period=60) async def update_profile( - request: UserUpdateRequest, + request: Request, # 添加 request 参数 + req: UserUpdateRequest, current_user: CurrentUserDependency, db: DbDependency ): """更新用户信息""" - if request.username: - current_user.username = request.username - if request.avatar: - current_user.avatar = request.avatar + if req.username: + current_user.username = req.username + if req.avatar: + current_user.avatar = req.avatar db.commit() db.refresh(current_user) @@ -40,7 +47,9 @@ async def update_profile( return current_user @router.post("/change-password") +@rate_limit(requests=10, period=60) async def change_password( + request: Request, # 添加 request 参数 old_password: str, new_password: str, current_user: CurrentUserDependency,