修复限流器问题
This commit is contained in:
@@ -8,7 +8,7 @@ License: AGPL v3
|
||||
"""
|
||||
|
||||
import re
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi import APIRouter, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from ...dependencies import DbDependency
|
||||
from ...models import User
|
||||
@@ -31,29 +31,25 @@ def is_email(account: str) -> bool:
|
||||
return re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', account) is not None
|
||||
|
||||
@router.post("/send-code")
|
||||
@rate_limit(requests=1, period=60) # 同一账号每分钟1次
|
||||
async def send_verify_code(request: SendCodeRequest, db: DbDependency):
|
||||
@rate_limit(requests=1, period=60)
|
||||
async def send_verify_code(
|
||||
request: Request, # 添加 request 参数
|
||||
req: SendCodeRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""发送验证码"""
|
||||
account = request.account
|
||||
account_type = request.type
|
||||
account = req.account
|
||||
account_type = req.type
|
||||
|
||||
# 验证账号格式
|
||||
if account_type == "phone" and not is_phone(account):
|
||||
raise HTTPException(status_code=400, detail="手机号格式错误")
|
||||
if account_type == "email" and not is_email(account):
|
||||
raise HTTPException(status_code=400, detail="邮箱格式错误")
|
||||
|
||||
# 检查是否已存在(注册时使用)
|
||||
# 这里只负责发送,不检查是否存在
|
||||
|
||||
# 生成验证码
|
||||
code = generate_verify_code()
|
||||
|
||||
# 存储到 Redis(5分钟有效)
|
||||
redis_key = f"verify:code:{account}"
|
||||
redis_client.set(redis_key, code, expire=300)
|
||||
|
||||
# 发送验证码
|
||||
success = False
|
||||
if account_type == "phone":
|
||||
success = send_sms(account, code)
|
||||
@@ -68,19 +64,21 @@ async def send_verify_code(request: SendCodeRequest, db: DbDependency):
|
||||
|
||||
@router.post("/register")
|
||||
@rate_limit(requests=5, period=60)
|
||||
async def register(request: RegisterRequest, db: DbDependency):
|
||||
"""注册(手机号/邮箱 + 验证码)"""
|
||||
account = request.account
|
||||
code = request.code
|
||||
password = request.password
|
||||
async def register(
|
||||
request: Request, # 添加 request 参数
|
||||
req: RegisterRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""注册"""
|
||||
account = req.account
|
||||
code = req.code
|
||||
password = req.password
|
||||
|
||||
# 验证验证码
|
||||
redis_key = f"verify:code:{account}"
|
||||
saved_code = redis_client.get(redis_key)
|
||||
if not saved_code or saved_code != code:
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
|
||||
# 检查用户是否已存在
|
||||
if is_phone(account):
|
||||
existing = db.query(User).filter(User.phone == account).first()
|
||||
if existing:
|
||||
@@ -98,10 +96,7 @@ async def register(request: RegisterRequest, db: DbDependency):
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# 删除已使用的验证码
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# 生成 token
|
||||
token = create_access_token({"sub": str(user.id)})
|
||||
|
||||
logger.info(f"新用户注册: {account}")
|
||||
@@ -113,13 +108,16 @@ async def register(request: RegisterRequest, db: DbDependency):
|
||||
|
||||
@router.post("/login")
|
||||
@rate_limit(requests=10, period=60)
|
||||
async def login(request: LoginRequest, db: DbDependency):
|
||||
"""登录(支持验证码登录或密码登录)"""
|
||||
account = request.account
|
||||
password = request.password
|
||||
code = request.code
|
||||
async def login(
|
||||
request: Request, # 添加 request 参数
|
||||
req: LoginRequest,
|
||||
db: DbDependency
|
||||
):
|
||||
"""登录"""
|
||||
account = req.account
|
||||
password = req.password
|
||||
code = req.code
|
||||
|
||||
# 查找用户
|
||||
user = None
|
||||
if is_phone(account):
|
||||
user = db.query(User).filter(User.phone == account).first()
|
||||
@@ -131,22 +129,18 @@ async def login(request: LoginRequest, db: DbDependency):
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="用户不存在")
|
||||
|
||||
# 验证码登录
|
||||
if code:
|
||||
redis_key = f"verify:code:{account}"
|
||||
saved_code = redis_client.get(redis_key)
|
||||
if not saved_code or saved_code != code:
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
# 密码登录
|
||||
elif password:
|
||||
if not verify_password(password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="密码错误")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="请提供验证码或密码")
|
||||
|
||||
# 更新最后登录时间
|
||||
from datetime import datetime
|
||||
user.last_login = datetime.now()
|
||||
db.commit()
|
||||
@@ -161,7 +155,6 @@ async def login(request: LoginRequest, db: DbDependency):
|
||||
}
|
||||
|
||||
@router.post("/wechat/login")
|
||||
async def wechat_login():
|
||||
async def wechat_login(request: Request): # 添加 request 参数
|
||||
"""微信公众号登录(预留)"""
|
||||
# TODO: 实现公众号 OAuth2 登录
|
||||
raise HTTPException(status_code=501, detail="功能开发中")
|
||||
Reference in New Issue
Block a user