160 lines
5.1 KiB
Python
160 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
PerToolBox Server - 认证路由
|
|
Copyright (C) 2024 Sea Network Technology Studio
|
|
Author: Canglan <admin@sea-studio.top>
|
|
License: AGPL v3
|
|
"""
|
|
|
|
import re
|
|
from fastapi import APIRouter, HTTPException, status, Request
|
|
from sqlalchemy.orm import Session
|
|
from ...dependencies import DbDependency
|
|
from ...models import User
|
|
from ...schemas import (
|
|
SendCodeRequest, RegisterRequest, LoginRequest, UserResponse
|
|
)
|
|
from ...utils.redis_client import redis_client
|
|
from ...utils.security import generate_verify_code, hash_password, verify_password, create_access_token
|
|
from ...utils.sms import send_sms
|
|
from ...utils.email import send_email
|
|
from ...utils.logger import logger
|
|
from ...middleware.rate_limit import rate_limit
|
|
|
|
# Authentication router: every endpoint registered on it in this module is
# mounted under the /api/v1/auth prefix and carries the "auth" tag.
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
|
|
|
def is_phone(account: str) -> bool:
    """Return True if *account* is an 11-digit mainland-China mobile number.

    The number must start with ``1``, followed by a digit 3-9, then nine more
    digits, and nothing else.

    ``re.fullmatch`` is used instead of ``re.match(...$)`` because ``$`` also
    matches just before a trailing newline, which let ``"13800138000\\n"``
    through. ``[0-9]`` is used instead of ``\\d`` because ``\\d`` accepts any
    Unicode decimal digit (e.g. Arabic-Indic digits) on ``str`` patterns.
    """
    return re.fullmatch(r'1[3-9][0-9]{9}', account) is not None
|
|
|
|
def is_email(account: str) -> bool:
    """Return True if *account* looks like an email address.

    This is a loose shape check: word characters, dots and hyphens; then
    ``@``; then a domain containing at least one dot; then a word-character
    suffix. It is not a full RFC 5322 validator.

    ``re.fullmatch`` is used instead of ``re.match(...$)`` because ``$`` also
    matches just before a trailing newline. That let ``"a@b.com\\n"`` be
    accepted and later used verbatim as a Redis key and database value.
    """
    return re.fullmatch(r'[\w.-]+@[\w.-]+\.\w+', account) is not None
|
|
|
|
@router.post("/send-code")
@rate_limit(requests=1, period=60)
async def send_verify_code(
    request: Request,  # not read here; presumably required by @rate_limit — confirm
    req: SendCodeRequest,
    db: DbDependency
):
    """Send a one-time verification code to a phone number or email address.

    The code is stored in Redis under ``verify:code:{account}`` with a
    300-second expiry. It is delivered via ``send_sms`` (phone) or
    ``send_email`` (email).

    Raises:
        HTTPException(400): unsupported account type, or the account does not
            match the expected phone/email format.
        HTTPException(500): the delivery helper reported failure.
    """
    account = req.account
    account_type = req.type

    if account_type == "phone":
        if not is_phone(account):
            raise HTTPException(status_code=400, detail="手机号格式错误")
    elif account_type == "email":
        if not is_email(account):
            raise HTTPException(status_code=400, detail="邮箱格式错误")
    else:
        # Any other type used to skip validation entirely and fall through to
        # the email sender with an unchecked account string.
        raise HTTPException(status_code=400, detail="账号类型错误")

    code = generate_verify_code()
    redis_key = f"verify:code:{account}"
    redis_client.set(redis_key, code, expire=300)

    if account_type == "phone":
        # NOTE(review): send_sms is called synchronously inside an async
        # handler; if it does blocking network I/O it stalls the event loop.
        success = send_sms(account, code)
    else:
        success = await send_email(account, code)

    if not success:
        # Don't leave behind a code the user never received.
        redis_client.delete(redis_key)
        raise HTTPException(status_code=500, detail="验证码发送失败")

    # Never log the code itself: anyone with log access could use it to
    # register or log in as this account.
    logger.info(f"验证码已发送: {account}")
    return {"success": True, "message": "验证码已发送"}
|
|
|
|
@router.post("/register")
@rate_limit(requests=5, period=60)
async def register(
    request: Request,  # not read here; presumably required by @rate_limit — confirm
    req: RegisterRequest,
    db: DbDependency
):
    """Create a new user after checking the verification code sent to it.

    The account is treated as a phone number or an email address depending
    on its format. On success, the cached code is deleted and an access token
    is issued for the new user.

    Returns:
        dict with ``success``, ``token`` and the serialized ``user``.

    Raises:
        HTTPException(400): wrong/expired code, malformed account, or the
            account is already registered (including when a concurrent
            request registers it first).
    """
    from sqlalchemy.exc import IntegrityError

    account = req.account
    code = req.code
    password = req.password

    redis_key = f"verify:code:{account}"
    saved_code = redis_client.get(redis_key)
    if not saved_code or saved_code != code:
        raise HTTPException(status_code=400, detail="验证码错误或已过期")

    if is_phone(account):
        duplicate_detail = "手机号已注册"
        existing = db.query(User).filter(User.phone == account).first()
        if existing:
            raise HTTPException(status_code=400, detail=duplicate_detail)
        user = User(phone=account, password_hash=hash_password(password))
    elif is_email(account):
        duplicate_detail = "邮箱已注册"
        existing = db.query(User).filter(User.email == account).first()
        if existing:
            raise HTTPException(status_code=400, detail=duplicate_detail)
        user = User(email=account, password_hash=hash_password(password))
    else:
        raise HTTPException(status_code=400, detail="账号格式错误")

    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # Two concurrent requests can both pass the existence check above.
        # The loser's INSERT is rejected here (assuming a unique constraint on
        # the phone/email column — confirm). Roll back so the session stays
        # usable, and report the duplicate instead of a 500.
        db.rollback()
        raise HTTPException(status_code=400, detail=duplicate_detail) from None
    db.refresh(user)

    # The code is single-use: drop it only once the user row is committed.
    redis_client.delete(redis_key)
    token = create_access_token({"sub": str(user.id)})

    logger.info(f"新用户注册: {account}")
    return {
        "success": True,
        "token": token,
        "user": UserResponse.model_validate(user)
    }
|
|
|
|
@router.post("/login")
@rate_limit(requests=10, period=60)
async def login(
    request: Request,  # not read here; presumably required by @rate_limit — confirm
    req: LoginRequest,
    db: DbDependency
):
    """Log a user in with either a verification code or a password.

    If ``req.code`` is provided, it is checked against the cached code, which
    is then deleted (single use). Otherwise ``req.password`` is verified
    against the stored hash. On success, ``last_login`` is updated and an
    access token is returned.

    Returns:
        dict with ``success``, ``token`` and the serialized ``user``.

    Raises:
        HTTPException(400): malformed account, unknown user, wrong/expired
            code, wrong password (or no password set), or neither code nor
            password supplied.
    """
    from datetime import datetime

    account = req.account
    password = req.password
    code = req.code

    if is_phone(account):
        user = db.query(User).filter(User.phone == account).first()
    elif is_email(account):
        user = db.query(User).filter(User.email == account).first()
    else:
        raise HTTPException(status_code=400, detail="账号格式错误")

    # NOTE(review): the distinct "user not found" vs "wrong password" messages
    # let callers probe which accounts exist; consider one generic message.
    if not user:
        raise HTTPException(status_code=400, detail="用户不存在")

    if code:
        redis_key = f"verify:code:{account}"
        saved_code = redis_client.get(redis_key)
        if not saved_code or saved_code != code:
            raise HTTPException(status_code=400, detail="验证码错误或已过期")
        redis_client.delete(redis_key)
    elif password:
        # An account with no stored hash cannot use password login. Reject it
        # as a wrong password instead of handing None/"" to verify_password,
        # which may raise and surface as a 500.
        if not user.password_hash or not verify_password(password, user.password_hash):
            raise HTTPException(status_code=400, detail="密码错误")
    else:
        raise HTTPException(status_code=400, detail="请提供验证码或密码")

    user.last_login = datetime.now()
    db.commit()

    token = create_access_token({"sub": str(user.id)})

    logger.info(f"用户登录: {account}")
    return {
        "success": True,
        "token": token,
        "user": UserResponse.model_validate(user)
    }
|
|
|
|
@router.post("/wechat/login")
async def wechat_login(request: Request):
    """WeChat Official Account login — placeholder endpoint.

    Not implemented yet; every call is rejected with HTTP 501.
    """
    not_implemented = HTTPException(status_code=501, detail="功能开发中")
    raise not_implemented