feat: implement concurrency limiting decorator

This commit is contained in:
Saqib Ansari 2026-04-10 22:22:23 +05:30
parent 115d3cead0
commit 2f30dac5d8
5 changed files with 347 additions and 0 deletions

View file

@ -34,6 +34,7 @@ import orjson
from werkzeug.datastructures import Headers
import frappe
from frappe.concurrency_limiter import concurrent_limit
from frappe.query_builder.utils import (
get_query,
get_query_builder,

View file

@ -394,6 +394,12 @@ def handle_exception(e):
elif http_status_code == 429:
response = frappe.rate_limiter.respond()
elif http_status_code == 503:
retry_after = getattr(e, "retry_after", 10)
response = frappe.utils.response.report_error(503)
if response:
response.headers["Retry-After"] = str(retry_after)
else:
response = ErrorPage(
http_status_code=http_status_code, title=_("Server Error"), message=_("Uncaught Exception")

View file

@ -0,0 +1,138 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
"""
Concurrency limiter for expensive whitelisted methods.
Provides a @frappe.concurrent_limit() decorator that limits the number of
simultaneous in-flight executions of a function across all gunicorn workers
using a Redis-backed atomic counter (semaphore).
Usage::
@frappe.whitelist(allow_guest=True)
@frappe.concurrent_limit(limit=3)
def download_pdf(...):
...
"""
import time
from collections.abc import Callable
from functools import wraps
import frappe
# Tuning knobs for the Redis-backed semaphore.

# Safety TTL (seconds) for the Redis key — prevents leaked semaphore slots if a
# worker crashes mid-request. Should be larger than any realistic execution time.
_SLOT_TTL = 120

# Default wait timeout (seconds) before returning 503 to the caller.
# Overridable per-site via the `concurrency_wait_timeout` config key.
_DEFAULT_WAIT_TIMEOUT = 10

# Polling interval (seconds) while waiting for a slot to open.
_POLL_INTERVAL = 0.25
def _default_limit() -> int:
	"""Derive a sensible default concurrency limit from the number of gunicorn workers."""
	import multiprocessing

	configured_workers = frappe.conf.get("gunicorn_workers")
	if not configured_workers:
		# Same heuristic gunicorn's docs recommend for worker count.
		configured_workers = multiprocessing.cpu_count() * 2 + 1
	# Half the workers may run a limited endpoint at once, but never fewer than one.
	return max(1, int(configured_workers) // 2)
def concurrent_limit(limit: int | None = None, wait_timeout: int | None = None):
	"""Decorator that limits simultaneous in-flight executions of the wrapped function.

	Concurrency is enforced across all workers via a Redis-backed counter keyed
	on the function's dotted path (see :func:`_acquire` / :func:`_release`).

	:param limit: Maximum number of concurrent executions. Defaults to
		``gunicorn_workers // 2`` (see :func:`_default_limit`).
	:param wait_timeout: Seconds to wait for a free slot before raising
		``ServiceUnavailableError`` (HTTP 503). Defaults to the
		``concurrency_wait_timeout`` site config value, falling back to 10 s.
		Not applicable outside HTTP requests, where limiting is skipped.
	"""

	def decorator(fn: Callable) -> Callable:
		@wraps(fn)
		def wrapper(*args, **kwargs):
			# Skip concurrency limiting outside of HTTP requests (background jobs,
			# CLI commands, tests that call functions directly, etc.).
			if not getattr(frappe.local, "request", None):
				return fn(*args, **kwargs)

			# Resolve effective settings lazily, per request, so site config
			# changes take effect without a restart.
			effective_limit = int(limit) if limit is not None else _default_limit()
			effective_wait = (
				wait_timeout
				if wait_timeout is not None
				else frappe.conf.get("concurrency_wait_timeout", _DEFAULT_WAIT_TIMEOUT)
			)

			# One counter per function, shared by all workers on this site.
			cache_key = frappe.cache.make_key(f"concurrency:{fn.__module__}.{fn.__qualname__}")

			acquired = _acquire(cache_key, effective_limit, effective_wait)
			if not acquired:
				# Local import to avoid import cycles at module load time.
				from frappe.exceptions import ServiceUnavailableError

				retry_after = max(1, int(effective_wait))
				exc = ServiceUnavailableError(frappe._("Server is busy. Please try again in a few seconds."))
				# Picked up by the 503 handler and emitted as a Retry-After header.
				exc.retry_after = retry_after
				raise exc

			try:
				return fn(*args, **kwargs)
			finally:
				# Always return the slot, even when fn raises.
				_release(cache_key)

		return wrapper

	return decorator
def _acquire(cache_key: str, limit: int, wait_timeout: float) -> bool:
	"""Increment the counter and return True if we got a slot within *wait_timeout* seconds.

	The counter is incremented first; if the new value exceeds *limit* the
	increment is undone and we wait before retrying. Because INCRBY is atomic,
	this avoids a separate check-then-act race condition.
	"""
	deadline = time.monotonic() + wait_timeout
	while True:
		try:
			current = frappe.cache.incrby(cache_key, 1)
		except Exception:
			# Redis unavailable — fail open to avoid breaking the endpoint entirely.
			frappe.log_error("Concurrency limiter: Redis unavailable, skipping limit")
			return True

		# Refresh TTL on every successful increment so that a slow request
		# doesn't let the slot expire before it finishes.
		try:
			frappe.cache.expire(cache_key, _SLOT_TTL)
		except Exception:
			pass

		if current <= limit:
			return True

		# Over the limit — give back the slot and wait.
		try:
			frappe.cache.incrby(cache_key, -1)
		except Exception:
			# Best effort: a leaked slot is reclaimed once the key's TTL expires.
			pass

		remaining = deadline - time.monotonic()
		if remaining <= 0:
			return False
		time.sleep(min(_POLL_INTERVAL, remaining))
def _release(cache_key: str) -> None:
	"""Decrement the counter, clamping at 0 to guard against double-release."""
	try:
		remaining = frappe.cache.incrby(cache_key, -1)
		if remaining < 0:
			# Shouldn't happen, but clamp to prevent permanently negative counters.
			# NOTE(review): this read-then-correct pair is not atomic; concurrent
			# increments between the two INCRBYs could skew the correction — verify
			# whether that matters for the intended traffic levels.
			frappe.cache.incrby(cache_key, -remaining)
	except Exception:
		# Best effort — an undecremented slot is reclaimed when the key's TTL expires.
		pass

View file

@ -86,6 +86,17 @@ class TooManyRequestsError(Exception):
http_status_code = 429
class ServiceUnavailableError(Exception):
	"""Signals that an endpoint's concurrency limit was exceeded.

	Assign :attr:`retry_after` (in seconds) on the instance before raising so
	that the 503 response can carry a ``Retry-After`` header.
	"""

	# Rendered as an HTTP 503 response by the exception handler.
	http_status_code = 503
	# Default Retry-After value; callers may override per-instance.
	retry_after: int = 10
class ImproperDBConfigurationError(Exception):
"""
Used when frappe detects that database or tables are not properly

View file

@ -0,0 +1,191 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
import threading
import time
import frappe
from frappe.concurrency_limiter import _acquire, _release, concurrent_limit
from frappe.exceptions import ServiceUnavailableError
from frappe.tests import IntegrationTestCase
def _cache_key(fn):
	"""Return the Redis key the concurrency limiter uses for *fn* (mirrors the decorator's keying)."""
	dotted_path = f"{fn.__module__}.{fn.__qualname__}"
	return frappe.cache.make_key(f"concurrency:{dotted_path}")
class TestConcurrentLimit(IntegrationTestCase):
	# NOTE: tests fake an HTTP request context by assigning a bare
	# frappe._dict() to frappe.local.request — the decorator only checks
	# that the attribute is present and truthy, never its contents.

	def test_bypassed_outside_request_context(self):
		"""Decorator is a complete no-op when called outside an HTTP request context
		(background jobs, CLI, direct test calls). Even limit=0 must not reject."""
		calls = []

		@concurrent_limit(limit=0)
		def fn():
			calls.append(True)

		# Make sure no request is set on this thread
		saved = getattr(frappe.local, "request", None)
		if saved:
			del frappe.local.request
		try:
			fn()  # must not raise despite limit=0
		finally:
			if saved:
				frappe.local.request = saved
		self.assertEqual(calls, [True])
		# Counter must not have been touched
		self.assertIsNone(frappe.cache.get(_cache_key(fn)))

	def test_raises_immediately_when_limit_full(self):
		"""ServiceUnavailableError is raised at once when wait_timeout=0 and the
		slot counter is already at the limit."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		key = _cache_key(fn)
		frappe.cache.incrby(key, 1)  # simulate one in-flight request
		frappe.cache.expire(key, 60)  # TTL so a failed test can't leak the key
		try:
			frappe.local.request = frappe._dict()
			self.assertRaises(ServiceUnavailableError, fn)
		finally:
			del frappe.local.request
			frappe.cache.delete(key)

	def test_counter_released_after_successful_call(self):
		"""Slot counter returns to zero after the wrapped function completes normally."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		key = _cache_key(fn)
		try:
			frappe.local.request = frappe._dict()
			fn()
			self.assertEqual(int(frappe.cache.get(key) or 0), 0)
		finally:
			del frappe.local.request
			frappe.cache.delete(key)

	def test_counter_released_after_exception(self):
		"""Slot counter returns to zero even when the wrapped function raises.
		This verifies the finally-block release path."""

		@concurrent_limit(limit=2, wait_timeout=0)
		def fn():
			raise ValueError("boom")

		key = _cache_key(fn)
		try:
			frappe.local.request = frappe._dict()
			self.assertRaises(ValueError, fn)
			self.assertEqual(int(frappe.cache.get(key) or 0), 0)
		finally:
			del frappe.local.request
			frappe.cache.delete(key)

	def test_service_unavailable_has_correct_http_status_and_retry_after(self):
		"""The raised exception must carry http_status_code=503 and retry_after
		equal to the configured wait_timeout."""
		TIMEOUT = 1

		@concurrent_limit(limit=1, wait_timeout=TIMEOUT)
		def fn():
			pass

		key = _cache_key(fn)
		frappe.cache.incrby(key, 1)  # occupy the single slot
		frappe.cache.expire(key, 60)
		try:
			frappe.local.request = frappe._dict()
			with self.assertRaises(ServiceUnavailableError) as ctx:
				fn()
			exc = ctx.exception
			self.assertEqual(exc.http_status_code, 503)
			self.assertEqual(exc.retry_after, TIMEOUT)
		finally:
			del frappe.local.request
			frappe.cache.delete(key)

	def test_waiter_acquires_slot_when_released(self):
		"""A blocked _acquire call succeeds once a concurrent holder calls _release.
		Tests the polling loop without going through the decorator."""
		key = frappe.cache.make_key("concurrency:test.waiter_acquire")
		# Simulate one in-flight holder
		frappe.cache.incrby(key, 1)
		frappe.cache.expire(key, 60)
		acquired = []

		def release_after_short_delay():
			time.sleep(0.3)
			_release(key)

		releaser = threading.Thread(target=release_after_short_delay, daemon=True)
		releaser.start()
		# wait_timeout=2 — should succeed well within that window
		result = _acquire(key, limit=1, wait_timeout=2)
		acquired.append(result)
		releaser.join()
		frappe.cache.delete(key)
		self.assertTrue(acquired[0])

	def test_counter_clamped_at_zero_on_double_release(self):
		"""Calling _release more times than _acquire must never produce a negative
		counter (which would inflate the effective slot budget)."""
		key = frappe.cache.make_key("concurrency:test.clamp_release")
		frappe.cache.incrby(key, 1)
		_release(key)  # correct release → 0
		_release(key)  # spurious extra release
		counter = int(frappe.cache.get(key) or 0)
		frappe.cache.delete(key)
		self.assertGreaterEqual(counter, 0)

	def test_concurrent_threads_respect_limit(self):
		"""Exactly `limit` threads acquire concurrently; the rest are rejected when
		wait_timeout=0. This exercises the atomic INCRBY semaphore across threads."""
		LIMIT = 2
		TOTAL = 5
		key = frappe.cache.make_key("concurrency:test.thread_limit")
		successes = []
		rejections = []
		lock = threading.Lock()  # guards the result lists across threads
		barrier = threading.Barrier(TOTAL)

		def attempt():
			barrier.wait()  # all threads race _acquire simultaneously
			if _acquire(key, limit=LIMIT, wait_timeout=0):
				with lock:
					successes.append(1)
				time.sleep(0.05)  # hold the slot briefly
				_release(key)
			else:
				with lock:
					rejections.append(1)

		threads = [threading.Thread(target=attempt, daemon=True) for _ in range(TOTAL)]
		for t in threads:
			t.start()
		for t in threads:
			t.join()
		frappe.cache.delete(key)
		self.assertEqual(len(successes), LIMIT)
		self.assertEqual(len(rejections), TOTAL - LIMIT)