diff --git a/frappe/__init__.py b/frappe/__init__.py
index 2a9a508921..8c9f8126c9 100644
--- a/frappe/__init__.py
+++ b/frappe/__init__.py
@@ -1595,6 +1595,7 @@ from frappe.utils.error import log_error
 from frappe.utils.formatters import format_value
 from frappe.utils.print_utils import get_print, attach_print
 from frappe.email import sendmail
+from frappe.concurrency_limiter import concurrent_limit
 
 # for backwards compatibility
 format = format_value
diff --git a/frappe/concurrency_limiter.py b/frappe/concurrency_limiter.py
new file mode 100644
index 0000000000..ba056faec1
--- /dev/null
+++ b/frappe/concurrency_limiter.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
+# License: MIT. See LICENSE
+
+"""
+Concurrency limiter for expensive whitelisted methods.
+
+Provides a @frappe.concurrent_limit() decorator that limits the number of
+simultaneous in-flight executions of a function across all gunicorn workers
+using a Redis-backed semaphore (LIST + BLPOP).
+
+Usage::
+
+	@frappe.whitelist(allow_guest=True)
+	@frappe.concurrent_limit(limit=3)
+	def download_pdf(...):
+		...
+
+"""
+
+from collections.abc import Callable
+from functools import wraps
+
+import frappe
+from frappe.exceptions import ServiceUnavailableError
+from frappe.utils import cint
+from frappe.utils.caching import redis_cache
+from frappe.utils.redis_semaphore import RedisSemaphore
+
+# Default wait timeout (seconds) before returning 503 to the caller.
+_DEFAULT_WAIT_TIMEOUT = 10
+
+
+@redis_cache(shared=True)
+def _default_limit() -> int:
+	"""Derive a sensible default concurrency limit from gunicorn's max concurrency."""
+	return max(1, gunicorn_max_concurrency() // 2)
+
+
+def gunicorn_max_concurrency() -> int:
+	"""Detect max concurrent requests from the running gunicorn master's cmdline."""
+	import os
+
+	fallback = 4
+
+	try:
+		ppid = os.getppid()
+		with open(f"/proc/{ppid}/cmdline", "rb") as f:
+			args = f.read().rstrip(b"\0").decode().split("\0")
+
+		if not any("gunicorn" in a for a in args):
+			return fallback
+
+		workers = _extract_cli_int(args, "-w", "--workers") or fallback
+		threads = _extract_cli_int(args, "--threads") or 1
+		return workers * threads
+	except (OSError, ValueError):
+		return fallback
+
+
+def _extract_cli_int(args: list[str], *flags: str) -> int | None:
+	"""Return the integer value for a CLI flag from a split argument list.
+
+	Handles both ``--flag value`` and ``--flag=value`` forms.
+	"""
+	for i, arg in enumerate(args):
+		for flag in flags:
+			if arg == flag and i + 1 < len(args):
+				return int(args[i + 1])
+			if arg.startswith(f"{flag}="):
+				return int(arg.split("=", 1)[1])
+	return None
+
+
+def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT_TIMEOUT):
+	"""Decorator that limits simultaneous in-flight executions of the wrapped function.
+
+	:param limit: Maximum number of concurrent executions. Defaults to half of ``workers x threads``
+		as detected from the gunicorn master process.
+	:param wait_timeout: Seconds to wait for a free slot before returning 503.
+		Defaults to 10 s.
+
+	The limiter is skipped entirely for background jobs, CLI commands, and
+	tests that call functions directly (i.e. outside of an HTTP request).
+	"""
+
+	def decorator(fn: Callable) -> Callable:
+		@wraps(fn)
+		def wrapper(*args, **kwargs):
+			# Skip concurrency limiting outside of HTTP requests (background jobs,
+			# CLI commands, tests that call functions directly, etc.).
+			if getattr(frappe.local, "request", None) is None:
+				return fn(*args, **kwargs)
+
+			_limit = cint(limit) if limit is not None else _default_limit()
+			key = f"concurrency:{fn.__module__}.{fn.__qualname__}"
+
+			sem = RedisSemaphore(key, _limit, wait_timeout, shared=True)
+			token = sem.acquire()
+			if not token:
+				retry_after = max(1, int(wait_timeout))
+				if (headers := getattr(frappe.local, "response_headers", None)) is not None:
+					headers.set("Retry-After", str(retry_after))
+				exc = ServiceUnavailableError(frappe._("Server is busy. Please try again in a few seconds."))
+				exc.retry_after = retry_after
+				raise exc
+
+			try:
+				return fn(*args, **kwargs)
+			finally:
+				sem.release(token)
+
+		return wrapper
+
+	return decorator
+
+
+@frappe.whitelist()
+def get_stats() -> dict:
+	frappe.only_for("System Manager")
+	cached_limit = _default_limit()
+	gunicorn_limit = gunicorn_max_concurrency()
+	return {
+		"cached_limit": cached_limit,
+		"gunicorn_limit": gunicorn_limit,
+	}
diff --git a/frappe/exceptions.py b/frappe/exceptions.py
index 210408422a..4b73db8b9c 100644
--- a/frappe/exceptions.py
+++ b/frappe/exceptions.py
@@ -86,6 +86,10 @@ class TooManyRequestsError(Exception):
 	http_status_code = 429
 
 
+class ServiceUnavailableError(Exception):
+	http_status_code = 503
+
+
 class ImproperDBConfigurationError(Exception):
 	"""
 	Used when frappe detects that database or tables are not properly
diff --git a/frappe/hooks.py b/frappe/hooks.py
index e1c46826f6..ff9f455178 100644
--- a/frappe/hooks.py
+++ b/frappe/hooks.py
@@ -530,6 +530,7 @@ persistent_cache_keys = [
 	"monitor-transactions",
 	"rate-limit-counter-*",
 	"rl:*",
+	"concurrency:*",
 ]
 
 user_invitation = {
diff --git a/frappe/tests/test_concurrency_limiter.py b/frappe/tests/test_concurrency_limiter.py
new file mode 100644
index 0000000000..b203e8822e
--- /dev/null
+++ b/frappe/tests/test_concurrency_limiter.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
+# License: MIT. See LICENSE
+
+import contextvars
+import threading
+from unittest.mock import MagicMock, patch
+
+import frappe
+from frappe.concurrency_limiter import concurrent_limit
+from frappe.exceptions import ServiceUnavailableError
+from frappe.tests import IntegrationTestCase
+
+
+def _key(fn):
+	"""Reconstruct the Redis key that concurrent_limit uses for a decorated function."""
+	return f"concurrency:{fn.__module__}.{fn.__qualname__}"
+
+
+def _cleanup(fn):
+	key = _key(fn)
+	frappe.cache.delete_value([key, f"{key}:capacity"], shared=True)
+
+
+class TestConcurrentLimit(IntegrationTestCase):
+	def test_bypassed_outside_request_context(self):
+		"""Decorator is a no-op outside HTTP request context (background jobs, CLI, tests).
+		Even limit=0 must not reject."""
+		calls = []
+
+		@concurrent_limit(limit=0)
+		def fn():
+			calls.append(True)
+
+		saved = getattr(frappe.local, "request", None)
+		if saved:
+			del frappe.local.request
+
+		try:
+			fn()  # must not raise despite limit=0
+		finally:
+			if saved:
+				frappe.local.request = saved
+
+		self.assertEqual(calls, [True])
+
+	def test_pool_exhaustion_raises_503_with_retry_after_header(self):
+		"""When all slots are occupied, the next request raises ServiceUnavailableError
+		(HTTP 503) immediately with wait_timeout=0. The Retry-After response header must be set."""
+		in_fn = threading.Event()
+		proceed = threading.Event()
+
+		@concurrent_limit(limit=1, wait_timeout=0)
+		def fn():
+			in_fn.set()
+			proceed.wait()
+
+		ctx = contextvars.copy_context()
+
+		def hold_slot():
+			frappe.local.request = frappe._dict()
+			fn()
+
+		t = threading.Thread(target=ctx.run, args=(hold_slot,))
+		t.start()
+		self.assertTrue(in_fn.wait(timeout=5), "Thread did not acquire the slot in time")
+
+		mock_headers = MagicMock()
+		saved_headers = getattr(frappe.local, "response_headers", None)
+		try:
+			frappe.local.request = frappe._dict()
+			frappe.local.response_headers = mock_headers
+			with self.assertRaises(ServiceUnavailableError) as exc_ctx:
+				fn()
+			self.assertEqual(exc_ctx.exception.http_status_code, 503)
+			mock_headers.set.assert_called_once_with("Retry-After", "1")  # max(1, wait_timeout=0)
+		finally:
+			proceed.set()
+			t.join(timeout=5)
+			del frappe.local.request
+			frappe.local.response_headers = saved_headers
+			_cleanup(fn)
+
+	def test_token_released_on_success(self):
+		"""A token is returned to the pool after a successful call,
+		so subsequent calls can acquire it without hitting a 503."""
+
+		@concurrent_limit(limit=1, wait_timeout=0)
+		def fn():
+			pass
+
+		try:
+			frappe.local.request = frappe._dict()
+			fn()
+			fn()  # should not raise ServiceUnavailableError since the token was released after the first call
+		finally:
+			del frappe.local.request
+			_cleanup(fn)
+
+	def test_token_released_on_exception(self):
+		"""A token is returned to the pool even when the wrapped function raises,
+		so subsequent calls can proceed with their own application error, not a 503."""
+
+		@concurrent_limit(limit=1, wait_timeout=0)
+		def fn():
+			raise ValueError("boom")
+
+		try:
+			frappe.local.request = frappe._dict()
+			with self.assertRaises(ValueError):
+				fn()
+			# Second call must raise ValueError (application error), not
+			# ServiceUnavailableError — which would indicate the token was leaked.
+			with self.assertRaises(ValueError):
+				fn()
+		finally:
+			del frappe.local.request
+			_cleanup(fn)
+
+	def test_self_heals_after_capacity_key_expiry(self):
+		"""After the capacity key expires (simulating crashed workers + TTL),
+		the pool re-initializes to full capacity so new requests succeed."""
+
+		@concurrent_limit(limit=1, wait_timeout=0)
+		def fn():
+			pass
+
+		key = _key(fn)
+		try:
+			frappe.local.request = frappe._dict()
+			fn()  # initializes the pool via the decorator
+
+			# Simulate all tokens being leaked (workers crashed mid-request)
+			# by draining the pool without returning tokens.
+			while frappe.cache.lpop(key, shared=True):
+				pass
+
+			# Simulate capacity key TTL expiry.
+			frappe.cache.delete_value(f"{key}:capacity", shared=True)
+
+			# Self-heal: next request must re-initialize the pool and succeed.
+			fn()  # must not raise ServiceUnavailableError
+		finally:
+			del frappe.local.request
+			_cleanup(fn)
+
+	def test_fails_open_when_redis_unavailable(self):
+		"""When Redis is unavailable during acquire, the request proceeds normally
+		(fail-open) rather than raising ServiceUnavailableError."""
+		calls = []
+
+		@concurrent_limit(limit=1, wait_timeout=0)
+		def fn():
+			calls.append(True)
+
+		try:
+			frappe.local.request = frappe._dict()
+			with patch.object(frappe.cache, "lpop", side_effect=Exception("Redis down")):
+				fn()  # must not raise
+		finally:
+			del frappe.local.request
+			_cleanup(fn)
+
+		self.assertEqual(calls, [True])
diff --git a/frappe/utils/caching.py b/frappe/utils/caching.py
index f2f1f9aecd..f1e0c781ed 100644
--- a/frappe/utils/caching.py
+++ b/frappe/utils/caching.py
@@ -180,7 +180,7 @@ def redis_cache(ttl: int | None = 3600, user: str | bool | None = None, shared:
 		func_key = f"{func.__module__}.{func.__qualname__}"
 
 		def clear_cache():
-			frappe.cache.delete_keys(func_key)
+			frappe.cache.delete_keys(func_key, user=user, shared=shared)
 
 		func.clear_cache = clear_cache
 		func.ttl = ttl if not callable(ttl) else 3600
diff --git a/frappe/utils/redis_semaphore.py b/frappe/utils/redis_semaphore.py
new file mode 100644
index 0000000000..7228e6ba36
--- /dev/null
+++ b/frappe/utils/redis_semaphore.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
+# License: MIT. See LICENSE
+
+"""Distributed counting semaphore backed by a Redis LIST."""
+
+import frappe
+
+
+class RedisSemaphore:
+	"""A distributed counting semaphore backed by a Redis LIST.
+
+	Allows up to *limit* concurrent holders across all processes sharing the
+	same Redis instance. The token pool is lazily initialized and self-heals
+	after crashes thanks to a TTL on the capacity key.
+
+	Usage as a context manager::
+
+		sem = RedisSemaphore("my-resource", limit=5, wait_timeout=10)
+		with sem:
+			...  # at most 5 concurrent holders
+
+	Or acquire/release manually::
+
+		token = sem.acquire()
+		if token is None:
+			raise Exception("Too busy")
+		try:
+			...
+		finally:
+			sem.release(token)
+	"""
+
+	# Safety TTL (seconds) for the capacity key — allows the pool to self-heal
+	# after a worker crash that leaked a token.
+	CAPACITY_TTL = 3600  # 1 hour
+
+	def __init__(self, key: str, limit: int, wait_timeout: float = 0, shared: bool = False):
+		"""
+		:param key: A unique Redis key name for this semaphore (will be
+			prefixed by the cache layer).
+		:param limit: Maximum number of concurrent holders.
+		:param wait_timeout: Seconds to block waiting for a free slot.
+			0 means non-blocking (immediate return if unavailable).
+		:param shared: If True, the semaphore key is bench-wide (not
+			prefixed with the site's db_name). Defaults to site-scoped.
+		"""
+		self.key = key
+		self.limit = limit
+		self.wait_timeout = wait_timeout
+		self.shared = shared
+		self._token: str | None = None
+
+	def acquire(self) -> str | None:
+		"""Try to acquire a token from the pool.
+
+		Returns a token string on success, ``None`` if no slot was
+		available within *wait_timeout*, or ``"fallback"`` if Redis is
+		unreachable (fail-open).
+		"""
+		try:
+			self._ensure_tokens()
+
+			if self.wait_timeout <= 0:
+				result = frappe.cache.lpop(self.key, shared=self.shared)
+				return self._decode(result) if result is not None else None
+
+			if result := frappe.cache.blpop(self.key, timeout=int(self.wait_timeout), shared=self.shared):
+				return self._decode(result[1])
+			return None
+
+		except Exception:
+			frappe.log_error(f"RedisSemaphore({self.key}): Redis unavailable, skipping limit")
+			return "fallback"
+
+	def release(self, token: str) -> None:
+		"""Return *token* to the pool."""
+		if token == "fallback":
+			return
+		try:
+			frappe.cache.lpush(self.key, token, shared=self.shared)
+		except Exception:
+			frappe.log_error(f"RedisSemaphore({self.key}): Failed to release token {token}")
+
+	# -- context-manager protocol ------------------------------------------
+
+	def __enter__(self):
+		self._token = self.acquire()
+		return self._token
+
+	def __exit__(self, *exc_info):
+		if self._token is not None:
+			self.release(self._token)
+			self._token = None
+
+	# -- internals ---------------------------------------------------------
+
+	def _ensure_tokens(self) -> None:
+		"""Lazily initialize the token pool."""
+		try:
+			if frappe.cache.exists(f"{self.key}:capacity", shared=self.shared):
+				return
+			frappe.cache.set_value(
+				f"{self.key}:capacity",
+				self.limit,
+				expires_in_sec=self.CAPACITY_TTL,
+				shared=self.shared,
+			)
+			frappe.cache.delete_value(self.key, shared=self.shared)
+			for i in range(1, self.limit + 1):
+				frappe.cache.lpush(self.key, str(i), shared=self.shared)
+		except Exception:
+			frappe.log_error(f"RedisSemaphore({self.key}): Failed to initialize tokens")
+
+	@staticmethod
+	def _decode(result):
+		return result.decode() if isinstance(result, bytes) else result
diff --git a/frappe/utils/redis_wrapper.py b/frappe/utils/redis_wrapper.py
index 844852b4d5..1f02ebbbfd 100644
--- a/frappe/utils/redis_wrapper.py
+++ b/frappe/utils/redis_wrapper.py
@@ -125,19 +125,19 @@
 		return ret
 
-	def get_keys(self, key):
+	def get_keys(self, key, user=None, shared=False):
 		"""Return keys starting with `key`."""
 		try:
-			key = self.make_key(key + "*")
+			key = self.make_key(key + "*", user=user, shared=shared)
 			return self.keys(key)
 		except redis.exceptions.ConnectionError:
 			regex = re.compile(cstr(key).replace("|", r"\|").replace("*", r"[\w]*"))
 			return [k for k in list(frappe.local.cache) if regex.match(cstr(k))]
 
-	def delete_keys(self, key):
+	def delete_keys(self, key, user=None, shared=False):
 		"""Delete keys with wildcard `*`."""
-		self.delete_value(self.get_keys(key), make_keys=False)
+		self.delete_value(self.get_keys(key, user=user, shared=shared), make_keys=False)
 
 	def delete_key(self, *args, **kwargs):
 		self.delete_value(*args, **kwargs)
@@ -162,18 +162,21 @@
 		except redis.exceptions.ConnectionError:
 			pass
 
-	def lpush(self, key, value):
-		return super().lpush(self.make_key(key), value)
+	def lpush(self, key, value, user=None, shared=False):
+		return super().lpush(self.make_key(key, user=user, shared=shared), value)
 
 	def rpush(self, key, value):
 		return super().rpush(self.make_key(key), value)
 
-	def lpop(self, key):
-		return super().lpop(self.make_key(key))
+	def lpop(self, key, user=None, shared=False):
+		return super().lpop(self.make_key(key, user=user, shared=shared))
 
 	def rpop(self, key):
 		return super().rpop(self.make_key(key))
 
+	def blpop(self, key, timeout=0, user=None, shared=False):
+		return super().blpop(self.make_key(key, user=user, shared=shared), timeout=timeout)
+
 	def llen(self, key):
 		return super().llen(self.make_key(key))