Merge pull request #38576 from nextchamp-saqib/concurrent_limit

feat: add `@frappe.concurrent_limit()` decorator
This commit is contained in:
Saqib Ansari 2026-04-23 22:55:59 +05:30 committed by GitHub
commit a8c373a2f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 422 additions and 9 deletions

View file

@ -1595,6 +1595,7 @@ from frappe.utils.error import log_error
from frappe.utils.formatters import format_value
from frappe.utils.print_utils import get_print, attach_print
from frappe.email import sendmail
from frappe.concurrency_limiter import concurrent_limit
# for backwards compatibility
format = format_value

View file

@ -0,0 +1,125 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
"""
Concurrency limiter for expensive whitelisted methods.
Provides a @frappe.concurrent_limit() decorator that limits the number of
simultaneous in-flight executions of a function across all gunicorn workers
using a Redis-backed semaphore (LIST + BLPOP).
Usage::
@frappe.whitelist(allow_guest=True)
@frappe.concurrent_limit(limit=3)
def download_pdf(...):
...
"""
from collections.abc import Callable
from functools import wraps
import frappe
from frappe.exceptions import ServiceUnavailableError
from frappe.utils import cint
from frappe.utils.caching import redis_cache
from frappe.utils.redis_semaphore import RedisSemaphore
# Default wait timeout (seconds) before returning 503 to the caller.
_DEFAULT_WAIT_TIMEOUT = 10
@redis_cache(shared=True)
def _default_limit() -> int:
	"""Default concurrency cap: half of gunicorn's total request capacity, never below 1."""
	half_capacity = gunicorn_max_concurrency() // 2
	return half_capacity if half_capacity > 1 else 1
def gunicorn_max_concurrency() -> int:
	"""Detect max concurrent requests from the running gunicorn master's cmdline.

	Reads ``/proc/<ppid>/cmdline`` (Linux-only) and parses the ``-w/--workers``
	and ``--threads`` flags, returning ``workers * threads``.

	Returns the conservative fallback of 4 when the parent process is not
	gunicorn or detection fails (non-Linux host, unreadable procfs, vanished
	parent, undecodable or malformed cmdline values).
	"""
	import os

	fallback = 4
	try:
		ppid = os.getppid()
		with open(f"/proc/{ppid}/cmdline", "rb") as f:
			# cmdline is NUL-separated with a trailing NUL.
			args = f.read().rstrip(b"\0").decode().split("\0")
		if not any("gunicorn" in a for a in args):
			return fallback
		workers = _extract_cli_int(args, "-w", "--workers") or fallback
		threads = _extract_cli_int(args, "--threads") or 1
		return workers * threads
	except (OSError, ValueError):
		# OSError: no procfs / parent gone. ValueError: non-integer flag value
		# or undecodable cmdline (UnicodeDecodeError is a ValueError subclass).
		# Previously only OSError was caught, so `--workers=abc` would crash
		# the request instead of falling back.
		return fallback
def _extract_cli_int(args: list[str], *flags: str) -> int | None:
"""Return the integer value for a CLI flag from a split argument list.
Handles both ``--flag value`` and ``--flag=value`` forms.
"""
for i, arg in enumerate(args):
for flag in flags:
if arg == flag and i + 1 < len(args):
return int(args[i + 1])
if arg.startswith(f"{flag}="):
return int(arg.split("=", 1)[1])
return None
def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT_TIMEOUT):
	"""Decorator that limits simultaneous in-flight executions of the wrapped function.

	:param limit: Maximum number of concurrent executions. Defaults to half of
		``workers x threads`` as detected from the gunicorn master process.
	:param wait_timeout: Seconds to wait for a free slot before returning 503.
		Defaults to 10 s.

	The limiter is skipped entirely for background jobs, CLI commands, and
	tests that call functions directly (i.e. outside of an HTTP request).
	"""

	def decorator(fn: Callable) -> Callable:
		# The semaphore key depends only on the decorated function, so compute
		# it once at decoration time instead of on every request.
		key = f"concurrency:{fn.__module__}.{fn.__qualname__}"

		@wraps(fn)
		def wrapper(*args, **kwargs):
			# Skip concurrency limiting outside of HTTP requests (background jobs,
			# CLI commands, tests that call functions directly, etc.).
			if getattr(frappe.local, "request", None) is None:
				return fn(*args, **kwargs)

			# Resolved per request: the default limit is derived lazily (and
			# cached via redis_cache) from the gunicorn master's cmdline.
			_limit = cint(limit) if limit is not None else _default_limit()
			sem = RedisSemaphore(key, _limit, wait_timeout, shared=True)
			token = sem.acquire()
			if not token:
				# Pool exhausted within wait_timeout: tell the client when to
				# retry, then fail with HTTP 503.
				retry_after = max(1, int(wait_timeout))
				if (headers := getattr(frappe.local, "response_headers", None)) is not None:
					headers.set("Retry-After", str(retry_after))
				exc = ServiceUnavailableError(frappe._("Server is busy. Please try again in a few seconds."))
				exc.retry_after = retry_after
				raise exc
			try:
				return fn(*args, **kwargs)
			finally:
				# Always return the token, even when fn raises.
				sem.release(token)

		return wrapper

	return decorator
@frappe.whitelist()
def get_stats() -> dict:
	"""Expose the derived concurrency limits for inspection (System Managers only)."""
	frappe.only_for("System Manager")
	return {
		"cached_limit": _default_limit(),
		"gunicorn_limit": gunicorn_max_concurrency(),
	}

View file

@ -86,6 +86,10 @@ class TooManyRequestsError(Exception):
http_status_code = 429
class ServiceUnavailableError(Exception):
	# Mapped to an HTTP 503 response — presumably by the request handler's
	# generic exception-to-status translation; confirm against frappe.app.
	http_status_code = 503
class ImproperDBConfigurationError(Exception):
"""
Used when frappe detects that database or tables are not properly

View file

@ -530,6 +530,7 @@ persistent_cache_keys = [
"monitor-transactions",
"rate-limit-counter-*",
"rl:*",
"concurrency:*",
]
user_invitation = {

View file

@ -0,0 +1,163 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
import contextvars
import threading
from unittest.mock import MagicMock, patch
import frappe
from frappe.concurrency_limiter import concurrent_limit
from frappe.exceptions import ServiceUnavailableError
from frappe.tests import IntegrationTestCase
def _key(fn):
"""Reconstruct the Redis key that concurrent_limit uses for a decorated function."""
return f"concurrency:{fn.__module__}.{fn.__qualname__}"
def _cleanup(fn):
	"""Remove a function's semaphore token pool and capacity marker from Redis."""
	pool_key = _key(fn)
	capacity_key = f"{pool_key}:capacity"
	frappe.cache.delete_value([pool_key, capacity_key], shared=True)
class TestConcurrentLimit(IntegrationTestCase):
	def test_bypassed_outside_request_context(self):
		"""Decorator is a no-op outside HTTP request context (background jobs, CLI, tests).

		Even limit=0 must not reject."""
		calls = []

		@concurrent_limit(limit=0)
		def fn():
			calls.append(True)

		# FIX: detect presence with `is not None`, not truthiness — sibling
		# tests set `frappe.local.request = frappe._dict()`, an *empty, falsy*
		# dict. With a truthy check, a leaked falsy request would neither be
		# removed (leaving the limiter active, so limit=0 rejects) nor restored.
		saved = getattr(frappe.local, "request", None)
		if saved is not None:
			del frappe.local.request
		try:
			fn()  # must not raise despite limit=0
		finally:
			if saved is not None:
				frappe.local.request = saved
		self.assertEqual(calls, [True])

	def test_pool_exhaustion_raises_503_with_retry_after_header(self):
		"""When all slots are occupied, the next request raises ServiceUnavailableError
		(HTTP 503) immediately with wait_timeout=0. The Retry-After response header must be set."""
		in_fn = threading.Event()
		proceed = threading.Event()

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			in_fn.set()
			proceed.wait()

		# Run the slot-holder in a copy of the current context so it shares
		# this thread's frappe.local state (site, cache connection).
		ctx = contextvars.copy_context()

		def hold_slot():
			frappe.local.request = frappe._dict()
			fn()

		t = threading.Thread(target=ctx.run, args=(hold_slot,))
		t.start()
		self.assertTrue(in_fn.wait(timeout=5), "Thread did not acquire the slot in time")

		mock_headers = MagicMock()
		saved_headers = getattr(frappe.local, "response_headers", None)
		try:
			frappe.local.request = frappe._dict()
			frappe.local.response_headers = mock_headers
			with self.assertRaises(ServiceUnavailableError) as exc_ctx:
				fn()
			self.assertEqual(exc_ctx.exception.http_status_code, 503)
			mock_headers.set.assert_called_once_with("Retry-After", "1")  # max(1, wait_timeout=0)
		finally:
			# Unblock the holder before joining, then restore local state.
			proceed.set()
			t.join(timeout=5)
			del frappe.local.request
			frappe.local.response_headers = saved_headers
			_cleanup(fn)

	def test_token_released_on_success(self):
		"""A token is returned to the pool after a successful call,
		so subsequent calls can acquire it without hitting a 503."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		try:
			frappe.local.request = frappe._dict()
			fn()
			fn()  # should not raise ServiceUnavailableError since the token was released after the first call
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_token_released_on_exception(self):
		"""A token is returned to the pool even when the wrapped function raises,
		so subsequent calls can proceed with their own application error, not a 503."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			raise ValueError("boom")

		try:
			frappe.local.request = frappe._dict()
			with self.assertRaises(ValueError):
				fn()
			# Second call must raise ValueError (application error), not
			# ServiceUnavailableError — which would indicate the token was leaked.
			with self.assertRaises(ValueError):
				fn()
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_self_heals_after_capacity_key_expiry(self):
		"""After the capacity key expires (simulating crashed workers + TTL),
		the pool re-initializes to full capacity so new requests succeed."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		key = _key(fn)
		try:
			frappe.local.request = frappe._dict()
			fn()  # initializes the pool via the decorator
			# Simulate all tokens being leaked (workers crashed mid-request)
			# by draining the pool without returning tokens.
			while frappe.cache.lpop(key, shared=True):
				pass
			# Simulate capacity key TTL expiry.
			frappe.cache.delete_value(f"{key}:capacity", shared=True)
			# Self-heal: next request must re-initialize the pool and succeed.
			fn()  # must not raise ServiceUnavailableError
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_fails_open_when_redis_unavailable(self):
		"""When Redis is unavailable during acquire, the request proceeds normally
		(fail-open) rather than raising ServiceUnavailableError."""
		calls = []

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			calls.append(True)

		try:
			frappe.local.request = frappe._dict()
			with patch.object(frappe.cache, "lpop", side_effect=Exception("Redis down")):
				fn()  # must not raise
		finally:
			del frappe.local.request
			_cleanup(fn)
		self.assertEqual(calls, [True])

View file

@ -180,7 +180,7 @@ def redis_cache(ttl: int | None = 3600, user: str | bool | None = None, shared:
func_key = f"{func.__module__}.{func.__qualname__}"
def clear_cache():
frappe.cache.delete_keys(func_key)
frappe.cache.delete_keys(func_key, user=user, shared=shared)
func.clear_cache = clear_cache
func.ttl = ttl if not callable(ttl) else 3600

View file

@ -0,0 +1,116 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
"""Distributed counting semaphore backed by a Redis LIST."""
import frappe
class RedisSemaphore:
	"""A distributed counting semaphore backed by a Redis LIST.

	Allows up to *limit* concurrent holders across all processes sharing the
	same Redis instance. The token pool is lazily initialized and self-heals
	after crashes thanks to a TTL on the capacity key.

	Usage as a context manager::

	        sem = RedisSemaphore("my-resource", limit=5, wait_timeout=10)
	        with sem:
	                ...  # at most 5 concurrent holders

	Or acquire/release manually::

	        token = sem.acquire()
	        if token is None:
	                raise Exception("Too busy")
	        try:
	                ...
	        finally:
	                sem.release(token)
	"""

	# Safety TTL (seconds) for the capacity key — allows the pool to self-heal
	# after a worker crash that leaked a token.
	CAPACITY_TTL = 3600  # 1 hour

	def __init__(self, key: str, limit: int, wait_timeout: float = 0, shared: bool = False):
		"""
		:param key: A unique Redis key name for this semaphore (will be
			prefixed by the cache layer).
		:param limit: Maximum number of concurrent holders.
		:param wait_timeout: Seconds to block waiting for a free slot.
			0 means non-blocking (immediate return if unavailable).
		:param shared: If True, the semaphore key is bench-wide (not
			prefixed with the site's db_name). Defaults to site-scoped.
		"""
		self.key = key
		self.limit = limit
		self.wait_timeout = wait_timeout
		self.shared = shared
		self._token: str | None = None

	def acquire(self) -> str | None:
		"""Try to acquire a token from the pool.

		Returns a token string on success, ``None`` if no slot was
		available within *wait_timeout*, or ``"fallback"`` if Redis is
		unreachable (fail-open).
		"""
		try:
			self._ensure_tokens()
			if self.wait_timeout <= 0:
				result = frappe.cache.lpop(self.key, shared=self.shared)
				return self._decode(result) if result is not None else None
			# FIX: BLPOP takes whole seconds and treats timeout=0 as "block
			# forever". Plain int() truncation turned any wait_timeout in
			# (0, 1) into an indefinite block — round fractional timeouts UP.
			timeout = int(self.wait_timeout)
			if timeout < self.wait_timeout:
				timeout += 1
			if result := frappe.cache.blpop(self.key, timeout=timeout, shared=self.shared):
				return self._decode(result[1])
			return None
		except Exception:
			# Fail open: a Redis outage must not take down the endpoints this guards.
			frappe.log_error(f"RedisSemaphore({self.key}): Redis unavailable, skipping limit")
			return "fallback"

	def release(self, token: str) -> None:
		"""Return *token* to the pool."""
		if token == "fallback":
			# Nothing was taken from the pool during a fail-open acquire.
			return
		try:
			frappe.cache.lpush(self.key, token, shared=self.shared)
		except Exception:
			frappe.log_error(f"RedisSemaphore({self.key}): Failed to release token {token}")

	# -- context-manager protocol ------------------------------------------

	def __enter__(self):
		self._token = self.acquire()
		return self._token

	def __exit__(self, *exc_info):
		if self._token is not None:
			self.release(self._token)
		self._token = None

	# -- internals ---------------------------------------------------------

	def _ensure_tokens(self) -> None:
		"""Lazily initialize the token pool.

		NOTE(review): the exists/set/refill sequence below is not atomic — two
		processes initializing concurrently could briefly over-fill the pool.
		Presumably acceptable for a soft limit; confirm if a strict cap is
		ever required.
		"""
		try:
			if frappe.cache.exists(f"{self.key}:capacity", shared=self.shared):
				return
			frappe.cache.set_value(
				f"{self.key}:capacity",
				self.limit,
				expires_in_sec=self.CAPACITY_TTL,
				shared=self.shared,
			)
			# Reset the LIST and refill it with `limit` distinct tokens.
			frappe.cache.delete_value(self.key, shared=self.shared)
			for i in range(1, self.limit + 1):
				frappe.cache.lpush(self.key, str(i), shared=self.shared)
		except Exception:
			frappe.log_error(f"RedisSemaphore({self.key}): Failed to initialize tokens")

	@staticmethod
	def _decode(result):
		# redis-py may hand back bytes; normalize to str for callers.
		return result.decode() if isinstance(result, bytes) else result

View file

@ -125,19 +125,19 @@ class RedisWrapper(redis.Redis):
return ret
def get_keys(self, key):
def get_keys(self, key, user=None, shared=False):
"""Return keys starting with `key`."""
try:
key = self.make_key(key + "*")
key = self.make_key(key + "*", user=user, shared=shared)
return self.keys(key)
except redis.exceptions.ConnectionError:
regex = re.compile(cstr(key).replace("|", r"\|").replace("*", r"[\w]*"))
return [k for k in list(frappe.local.cache) if regex.match(cstr(k))]
def delete_keys(self, key):
def delete_keys(self, key, user=None, shared=False):
"""Delete keys with wildcard `*`."""
self.delete_value(self.get_keys(key), make_keys=False)
self.delete_value(self.get_keys(key, user=user, shared=shared), make_keys=False)
def delete_key(self, *args, **kwargs):
self.delete_value(*args, **kwargs)
@ -162,18 +162,21 @@ class RedisWrapper(redis.Redis):
except redis.exceptions.ConnectionError:
pass
def lpush(self, key, value):
return super().lpush(self.make_key(key), value)
def lpush(self, key, value, user=None, shared=False):
return super().lpush(self.make_key(key, user=user, shared=shared), value)
def rpush(self, key, value):
return super().rpush(self.make_key(key), value)
def lpop(self, key):
return super().lpop(self.make_key(key))
def lpop(self, key, user=None, shared=False):
return super().lpop(self.make_key(key, user=user, shared=shared))
def rpop(self, key):
return super().rpop(self.make_key(key))
def blpop(self, key, timeout=0, user=None, shared=False):
	# Blocking left-pop on the namespaced key. NOTE(review): Redis BLPOP treats
	# timeout=0 as "block indefinitely" — confirm callers intend that default.
	return super().blpop(self.make_key(key, user=user, shared=shared), timeout=timeout)
def llen(self, key):
	# Length of the list stored at the namespaced key (Redis LLEN).
	return super().llen(self.make_key(key))