51 lines
1.9 KiB
Python
51 lines
1.9 KiB
Python
import math
import time
from collections import defaultdict, deque

from fastapi import Request, Response
from fastapi.responses import ORJSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint

from app.core.config import Settings
from app.core.errors import error_payload
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client sliding-window rate limiter.

    Each client (keyed by ``request.client.host``, or ``"unknown"`` when the
    server reports no peer address) may make at most ``rate_limit_requests``
    requests within any ``rate_limit_window_seconds`` span. Excess requests are
    answered with a 429 response carrying a ``Retry-After`` header.

    Timestamps live in this instance's memory, so limits are not shared between
    processes. NOTE(review): behind a reverse proxy every request may report the
    proxy's address unless the server is configured to trust forwarded
    headers — confirm against the deployment.
    """

    def __init__(self, app: object, settings: Settings) -> None:
        super().__init__(app)
        # Maximum number of requests a client may make inside one window.
        self.limit = settings.rate_limit_requests
        # Length of the sliding window in seconds (compared against time.monotonic()).
        self.window_seconds = settings.rate_limit_window_seconds
        # client host -> monotonic timestamps of its admitted requests, oldest first.
        self.requests: defaultdict[str, deque[float]] = defaultdict(deque)
        # Monotonic time of the last pass that discarded idle clients' buckets.
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Admit the request if the client is under its limit, otherwise answer 429."""
        client = request.client.host if request.client else "unknown"
        now = time.monotonic()
        # Without this, a bucket (and its dict key) would be retained forever for
        # every address ever seen, growing memory without bound.
        self._sweep_idle_clients(now)

        bucket = self.requests[client]
        # Drop this client's timestamps that have fallen out of the window.
        while bucket and now - bucket[0] > self.window_seconds:
            bucket.popleft()

        if len(bucket) >= self.limit:
            return ORJSONResponse(
                status_code=429,
                content=error_payload("rate_limited", "Too many requests"),
                headers={"Retry-After": str(self._retry_after(bucket, now))},
            )

        bucket.append(now)
        return await call_next(request)

    def _sweep_idle_clients(self, now: float) -> None:
        """Delete buckets whose newest request has left the window; runs at most once per window."""
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        # Timestamps are appended in increasing order, so if the newest one has
        # expired, every entry in that bucket has.
        idle = [
            client
            for client, bucket in self.requests.items()
            if not bucket or now - bucket[-1] > self.window_seconds
        ]
        for client in idle:
            del self.requests[client]

    def _retry_after(self, bucket: "deque[float]", now: float) -> int:
        """Return whole seconds (at least 1) until the client may send another request.

        Retry-After takes an integer number of seconds, so the remaining time is
        rounded up instead of formatting ``window_seconds`` directly (which would
        yield e.g. ``"60.0"`` for a float setting).
        """
        if not bucket:
            # Only reachable when the limit is <= 0: nothing is ever admitted.
            return max(1, math.ceil(self.window_seconds))
        # Requests are only admitted while len(bucket) < limit, so a rejected
        # client's bucket holds exactly `limit` entries; once the oldest expires,
        # a slot frees up.
        return max(1, math.ceil(self.window_seconds - (now - bucket[0])))
|
|
|
|
|
|
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` exceeds a configured maximum.

    Only the ``Content-Length`` request header is inspected; the body itself is
    never read here. NOTE(review): requests sent without Content-Length (e.g.
    ``Transfer-Encoding: chunked``) are not limited by this middleware — confirm
    that the ASGI server or a proxy caps those.
    """

    def __init__(self, app: object, settings: Settings) -> None:
        super().__init__(app)
        # Largest accepted Content-Length value, in bytes.
        self.max_request_bytes = settings.max_request_bytes

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Answer 413 for oversize declared bodies, 400 for a malformed header, else pass through."""
        content_length = request.headers.get("content-length", "").strip(" \t")
        if not content_length:
            # No (or empty) Content-Length: nothing to check here.
            return await call_next(request)

        # Content-Length must be 1*DIGIT. Python's int() is more lenient (signs,
        # underscores, non-ASCII digits), so "-1" would otherwise slip past the
        # check, and outright garbage would raise ValueError and surface as a 500.
        if not (content_length.isascii() and content_length.isdigit()):
            return ORJSONResponse(
                status_code=400,
                content=error_payload("invalid_content_length", "Invalid Content-Length header"),
            )

        if int(content_length) > self.max_request_bytes:
            return ORJSONResponse(
                status_code=413,
                content=error_payload("payload_too_large", "Request body is too large"),
            )
        return await call_next(request)
|