Merge pull request #38576 from nextchamp-saqib/concurrent_limit
feat: add `@frappe.concurrent_limit()` decorator
This commit is contained in:
commit
a8c373a2f4
8 changed files with 422 additions and 9 deletions
|
|
@ -1595,6 +1595,7 @@ from frappe.utils.error import log_error
|
|||
from frappe.utils.formatters import format_value
|
||||
from frappe.utils.print_utils import get_print, attach_print
|
||||
from frappe.email import sendmail
|
||||
from frappe.concurrency_limiter import concurrent_limit
|
||||
|
||||
# for backwards compatibility
|
||||
format = format_value
|
||||
|
|
|
|||
125
frappe/concurrency_limiter.py
Normal file
125
frappe/concurrency_limiter.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
|
||||
# License: MIT. See LICENSE
|
||||
|
||||
"""
|
||||
Concurrency limiter for expensive whitelisted methods.
|
||||
|
||||
Provides a @frappe.concurrent_limit() decorator that limits the number of
|
||||
simultaneous in-flight executions of a function across all gunicorn workers
|
||||
using a Redis-backed semaphore (LIST + BLPOP).
|
||||
|
||||
Usage::
|
||||
|
||||
@frappe.whitelist(allow_guest=True)
|
||||
@frappe.concurrent_limit(limit=3)
|
||||
def download_pdf(...):
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
import frappe
|
||||
from frappe.exceptions import ServiceUnavailableError
|
||||
from frappe.utils import cint
|
||||
from frappe.utils.caching import redis_cache
|
||||
from frappe.utils.redis_semaphore import RedisSemaphore
|
||||
|
||||
# Default wait timeout (seconds) before returning 503 to the caller;
# used as the default for concurrent_limit()'s `wait_timeout` parameter.
_DEFAULT_WAIT_TIMEOUT = 10
|
||||
|
||||
|
||||
@redis_cache(shared=True)
def _default_limit() -> int:
	"""Derive a sensible default concurrency limit from gunicorn's max concurrency.

	Half of the server's total request concurrency, but never below one slot.
	Cached bench-wide via redis_cache so the /proc probe runs rarely.
	"""
	half_capacity = gunicorn_max_concurrency() // 2
	return max(1, half_capacity)
|
||||
|
||||
|
||||
def gunicorn_max_concurrency() -> int:
	"""Detect max concurrent requests from the running gunicorn master's cmdline.

	Reads ``/proc/<ppid>/cmdline`` of the parent process and, when it looks like
	a gunicorn invocation, returns ``workers * threads``. Falls back to a
	conservative constant when the parent is not gunicorn or when anything goes
	wrong while reading or parsing (non-Linux, permissions, odd cmdline).
	"""
	import os

	fallback = 4

	try:
		ppid = os.getppid()
		with open(f"/proc/{ppid}/cmdline", "rb") as f:
			args = f.read().rstrip(b"\0").decode().split("\0")

		if not any("gunicorn" in a for a in args):
			return fallback

		workers = _extract_cli_int(args, "-w", "--workers") or fallback
		threads = _extract_cli_int(args, "--threads") or 1
		return workers * threads
	except (OSError, ValueError):
		# OSError: /proc missing or unreadable (e.g. macOS, hardened containers).
		# ValueError: undecodable cmdline bytes or a non-integer flag value.
		return fallback
|
||||
|
||||
|
||||
def _extract_cli_int(args: list[str], *flags: str) -> int | None:
|
||||
"""Return the integer value for a CLI flag from a split argument list.
|
||||
|
||||
Handles both ``--flag value`` and ``--flag=value`` forms.
|
||||
"""
|
||||
for i, arg in enumerate(args):
|
||||
for flag in flags:
|
||||
if arg == flag and i + 1 < len(args):
|
||||
return int(args[i + 1])
|
||||
if arg.startswith(f"{flag}="):
|
||||
return int(arg.split("=", 1)[1])
|
||||
return None
|
||||
|
||||
|
||||
def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT_TIMEOUT):
	"""Decorator that limits simultaneous in-flight executions of the wrapped function.

	:param limit: Maximum number of concurrent executions. Defaults to half of
		``workers x threads`` as detected from the gunicorn master process.
	:param wait_timeout: Seconds to wait for a free slot before returning 503.
		Defaults to 10 s.

	The limiter is skipped entirely for background jobs, CLI commands, and
	tests that call functions directly (i.e. outside of an HTTP request).
	"""

	def decorator(fn: Callable) -> Callable:
		@wraps(fn)
		def wrapper(*args, **kwargs):
			# Only HTTP requests are throttled; background jobs, CLI commands
			# and direct calls from tests bypass the semaphore entirely.
			if getattr(frappe.local, "request", None) is None:
				return fn(*args, **kwargs)

			effective_limit = _default_limit() if limit is None else cint(limit)
			semaphore_key = f"concurrency:{fn.__module__}.{fn.__qualname__}"

			semaphore = RedisSemaphore(semaphore_key, effective_limit, wait_timeout, shared=True)
			token = semaphore.acquire()
			if not token:
				# No slot freed up within the timeout window: tell the client
				# how long to back off, then fail with HTTP 503.
				retry_after = max(1, int(wait_timeout))
				headers = getattr(frappe.local, "response_headers", None)
				if headers is not None:
					headers.set("Retry-After", str(retry_after))
				exc = ServiceUnavailableError(frappe._("Server is busy. Please try again in a few seconds."))
				exc.retry_after = retry_after
				raise exc

			try:
				return fn(*args, **kwargs)
			finally:
				# Always hand the slot back, even when fn raises.
				semaphore.release(token)

		return wrapper

	return decorator
|
||||
|
||||
|
||||
@frappe.whitelist()
def get_stats() -> dict:
	"""Return the cached default limit and the detected gunicorn concurrency.

	Restricted to System Manager users.
	"""
	frappe.only_for("System Manager")
	return {
		"cached_limit": _default_limit(),
		"gunicorn_limit": gunicorn_max_concurrency(),
	}
|
||||
|
|
@ -86,6 +86,10 @@ class TooManyRequestsError(Exception):
|
|||
http_status_code = 429
|
||||
|
||||
|
||||
class ServiceUnavailableError(Exception):
	"""Raised when the server cannot currently serve the request (HTTP 503).

	Used e.g. by the concurrency limiter when no execution slot becomes free
	within the wait timeout; the client is expected to retry later.
	"""

	http_status_code = 503
|
||||
|
||||
|
||||
class ImproperDBConfigurationError(Exception):
|
||||
"""
|
||||
Used when frappe detects that database or tables are not properly
|
||||
|
|
|
|||
|
|
@ -530,6 +530,7 @@ persistent_cache_keys = [
|
|||
"monitor-transactions",
|
||||
"rate-limit-counter-*",
|
||||
"rl:*",
|
||||
"concurrency:*",
|
||||
]
|
||||
|
||||
user_invitation = {
|
||||
|
|
|
|||
163
frappe/tests/test_concurrency_limiter.py
Normal file
163
frappe/tests/test_concurrency_limiter.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
|
||||
# License: MIT. See LICENSE
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import frappe
|
||||
from frappe.concurrency_limiter import concurrent_limit
|
||||
from frappe.exceptions import ServiceUnavailableError
|
||||
from frappe.tests import IntegrationTestCase
|
||||
|
||||
|
||||
def _key(fn):
|
||||
"""Reconstruct the Redis key that concurrent_limit uses for a decorated function."""
|
||||
return f"concurrency:{fn.__module__}.{fn.__qualname__}"
|
||||
|
||||
|
||||
def _cleanup(fn):
	"""Drop the semaphore's token list and its capacity key from the shared cache."""
	base = _key(fn)
	frappe.cache.delete_value([base, f"{base}:capacity"], shared=True)
|
||||
|
||||
|
||||
class TestConcurrentLimit(IntegrationTestCase):
	"""Integration tests for @concurrent_limit and its Redis-backed semaphore."""

	def test_bypassed_outside_request_context(self):
		"""Decorator is a no-op outside HTTP request context (background jobs, CLI, tests).
		Even limit=0 must not reject."""
		calls = []

		@concurrent_limit(limit=0)
		def fn():
			calls.append(True)

		# Temporarily remove any ambient request object so the decorator's
		# "outside HTTP request" branch is exercised.
		saved = getattr(frappe.local, "request", None)
		if saved:
			del frappe.local.request

		try:
			fn()  # must not raise despite limit=0
		finally:
			if saved:
				frappe.local.request = saved

		self.assertEqual(calls, [True])

	def test_pool_exhaustion_raises_503_with_retry_after_header(self):
		"""When all slots are occupied, the next request raises ServiceUnavailableError
		(HTTP 503) immediately with wait_timeout=0. The Retry-After response header must be set."""
		in_fn = threading.Event()
		proceed = threading.Event()

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			in_fn.set()
			proceed.wait()

		# Run the slot-holding call in a thread with a copy of this context so
		# it sees the same frappe site/local state.
		ctx = contextvars.copy_context()

		def hold_slot():
			frappe.local.request = frappe._dict()
			fn()

		t = threading.Thread(target=ctx.run, args=(hold_slot,))
		t.start()
		self.assertTrue(in_fn.wait(timeout=5), "Thread did not acquire the slot in time")

		mock_headers = MagicMock()
		saved_headers = getattr(frappe.local, "response_headers", None)
		try:
			frappe.local.request = frappe._dict()
			frappe.local.response_headers = mock_headers
			with self.assertRaises(ServiceUnavailableError) as exc_ctx:
				fn()
			self.assertEqual(exc_ctx.exception.http_status_code, 503)
			mock_headers.set.assert_called_once_with("Retry-After", "1")  # max(1, wait_timeout=0)
		finally:
			# Unblock the holder thread before cleaning up shared state.
			proceed.set()
			t.join(timeout=5)
			del frappe.local.request
			frappe.local.response_headers = saved_headers
			_cleanup(fn)

	def test_token_released_on_success(self):
		"""A token is returned to the pool after a successful call,
		so subsequent calls can acquire it without hitting a 503."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		try:
			frappe.local.request = frappe._dict()
			fn()
			fn()  # should not raise ServiceUnavailableError since the token was released after the first call
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_token_released_on_exception(self):
		"""A token is returned to the pool even when the wrapped function raises,
		so subsequent calls can proceed with their own application error, not a 503."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			raise ValueError("boom")

		try:
			frappe.local.request = frappe._dict()
			with self.assertRaises(ValueError):
				fn()
			# Second call must raise ValueError (application error), not
			# ServiceUnavailableError — which would indicate the token was leaked.
			with self.assertRaises(ValueError):
				fn()
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_self_heals_after_capacity_key_expiry(self):
		"""After the capacity key expires (simulating crashed workers + TTL),
		the pool re-initializes to full capacity so new requests succeed."""

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			pass

		key = _key(fn)
		try:
			frappe.local.request = frappe._dict()
			fn()  # initializes the pool via the decorator

			# Simulate all tokens being leaked (workers crashed mid-request)
			# by draining the pool without returning tokens.
			while frappe.cache.lpop(key, shared=True):
				pass

			# Simulate capacity key TTL expiry.
			frappe.cache.delete_value(f"{key}:capacity", shared=True)

			# Self-heal: next request must re-initialize the pool and succeed.
			fn()  # must not raise ServiceUnavailableError
		finally:
			del frappe.local.request
			_cleanup(fn)

	def test_fails_open_when_redis_unavailable(self):
		"""When Redis is unavailable during acquire, the request proceeds normally
		(fail-open) rather than raising ServiceUnavailableError."""
		calls = []

		@concurrent_limit(limit=1, wait_timeout=0)
		def fn():
			calls.append(True)

		try:
			frappe.local.request = frappe._dict()
			# Break the non-blocking acquire path (wait_timeout=0 uses lpop).
			with patch.object(frappe.cache, "lpop", side_effect=Exception("Redis down")):
				fn()  # must not raise
		finally:
			del frappe.local.request
			_cleanup(fn)

		self.assertEqual(calls, [True])
|
||||
|
|
@ -180,7 +180,7 @@ def redis_cache(ttl: int | None = 3600, user: str | bool | None = None, shared:
|
|||
func_key = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
def clear_cache():
|
||||
frappe.cache.delete_keys(func_key)
|
||||
frappe.cache.delete_keys(func_key, user=user, shared=shared)
|
||||
|
||||
func.clear_cache = clear_cache
|
||||
func.ttl = ttl if not callable(ttl) else 3600
|
||||
|
|
|
|||
116
frappe/utils/redis_semaphore.py
Normal file
116
frappe/utils/redis_semaphore.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
|
||||
# License: MIT. See LICENSE
|
||||
|
||||
"""Distributed counting semaphore backed by a Redis LIST."""
|
||||
|
||||
import frappe
|
||||
|
||||
|
||||
class RedisSemaphore:
	"""A distributed counting semaphore backed by a Redis LIST.

	Allows up to *limit* concurrent holders across all processes sharing the
	same Redis instance. The token pool is lazily initialized and self-heals
	after crashes thanks to a TTL on the capacity key.

	Usage as a context manager::

		sem = RedisSemaphore("my-resource", limit=5, wait_timeout=10)
		with sem:
			...  # at most 5 concurrent holders

	Or acquire/release manually::

		token = sem.acquire()
		if token is None:
			raise Exception("Too busy")
		try:
			...
		finally:
			sem.release(token)
	"""

	# Safety TTL (seconds) for the capacity key — allows the pool to self-heal
	# after a worker crash that leaked a token.
	CAPACITY_TTL = 3600  # 1 hour

	def __init__(self, key: str, limit: int, wait_timeout: float = 0, shared: bool = False):
		"""
		:param key: A unique Redis key name for this semaphore (will be
			prefixed by the cache layer).
		:param limit: Maximum number of concurrent holders.
		:param wait_timeout: Seconds to block waiting for a free slot.
			0 means non-blocking (immediate return if unavailable).
		:param shared: If True, the semaphore key is bench-wide (not
			prefixed with the site's db_name). Defaults to site-scoped.
		"""
		self.key = key
		self.limit = limit
		self.wait_timeout = wait_timeout
		self.shared = shared
		# Token held while this object is used as a context manager.
		self._token: str | None = None

	def acquire(self) -> str | None:
		"""Try to acquire a token from the pool.

		Returns a token string on success, ``None`` if no slot was
		available within *wait_timeout*, or ``"fallback"`` if Redis is
		unreachable (fail-open).
		"""
		try:
			self._ensure_tokens()

			if self.wait_timeout <= 0:
				# Non-blocking path: BLPOP with timeout=0 would block forever
				# in Redis, so use a plain LPOP instead.
				result = frappe.cache.lpop(self.key, shared=self.shared)
				return self._decode(result) if result is not None else None

			# BLPOP returns a (key, value) pair, or a falsy result on timeout.
			if result := frappe.cache.blpop(self.key, timeout=int(self.wait_timeout), shared=self.shared):
				return self._decode(result[1])
			return None

		except Exception:
			# Fail open: a Redis outage must not take down the endpoints this
			# semaphore protects.
			frappe.log_error(f"RedisSemaphore({self.key}): Redis unavailable, skipping limit")
			return "fallback"

	def release(self, token: str) -> None:
		"""Return *token* to the pool."""
		if token == "fallback":
			# Nothing was actually taken from the pool during a Redis outage.
			return
		try:
			frappe.cache.lpush(self.key, token, shared=self.shared)
		except Exception:
			# A leaked token is recovered when the capacity key's TTL expires.
			frappe.log_error(f"RedisSemaphore({self.key}): Failed to release token {token}")

	# -- context-manager protocol ------------------------------------------

	def __enter__(self):
		# NOTE(review): if acquire() times out this yields None and the `with`
		# body still runs unthrottled — callers must check the yielded token.
		self._token = self.acquire()
		return self._token

	def __exit__(self, *exc_info):
		if self._token is not None:
			self.release(self._token)
		self._token = None

	# -- internals ---------------------------------------------------------

	def _ensure_tokens(self) -> None:
		"""Lazily initialize the token pool."""
		try:
			if frappe.cache.exists(f"{self.key}:capacity", shared=self.shared):
				return
			# NOTE(review): this check-then-set is not atomic — two processes
			# may both initialize concurrently; worst case the pool is briefly
			# reset to full capacity. Confirm this is acceptable or use SETNX.
			frappe.cache.set_value(
				f"{self.key}:capacity",
				self.limit,
				expires_in_sec=self.CAPACITY_TTL,
				shared=self.shared,
			)
			# Rebuild the token list from scratch: one token per slot.
			frappe.cache.delete_value(self.key, shared=self.shared)
			for i in range(1, self.limit + 1):
				frappe.cache.lpush(self.key, str(i), shared=self.shared)
		except Exception:
			# An initialization failure surfaces later as an acquire failure
			# (handled fail-open by acquire()).
			frappe.log_error(f"RedisSemaphore({self.key}): Failed to initialize tokens")

	@staticmethod
	def _decode(result):
		# The Redis client may return bytes; normalize to str for callers.
		return result.decode() if isinstance(result, bytes) else result
|
||||
|
|
@ -125,19 +125,19 @@ class RedisWrapper(redis.Redis):
|
|||
|
||||
return ret
|
||||
|
||||
def get_keys(self, key, user=None, shared=False):
	"""Return keys starting with `key`."""
	try:
		key = self.make_key(key + "*", user=user, shared=shared)
		return self.keys(key)

	except redis.exceptions.ConnectionError:
		# Redis is down: fall back to scanning the in-process cache,
		# translating the glob pattern into a regex over cached key names.
		regex = re.compile(cstr(key).replace("|", r"\|").replace("*", r"[\w]*"))
		return [k for k in list(frappe.local.cache) if regex.match(cstr(k))]
|
||||
|
||||
def delete_keys(self, key, user=None, shared=False):
	"""Delete keys with wildcard `*`."""
	# get_keys already returns fully-prefixed keys, hence make_keys=False.
	self.delete_value(self.get_keys(key, user=user, shared=shared), make_keys=False)
|
||||
|
||||
def delete_key(self, *args, **kwargs):
	# Thin alias for delete_value, kept for API compatibility.
	self.delete_value(*args, **kwargs)
|
||||
|
|
@ -162,18 +162,21 @@ class RedisWrapper(redis.Redis):
|
|||
except redis.exceptions.ConnectionError:
|
||||
pass
|
||||
|
||||
def lpush(self, key, value, user=None, shared=False):
	# Prefix the key via make_key (site-, user- or bench-scoped) before delegating.
	return super().lpush(self.make_key(key, user=user, shared=shared), value)
|
||||
|
||||
def rpush(self, key, value, user=None, shared=False):
	"""RPUSH onto the prefixed key.

	`user`/`shared` control key scoping (kept consistent with lpush/lpop);
	defaults preserve the previous site-scoped behavior.
	"""
	return super().rpush(self.make_key(key, user=user, shared=shared), value)
|
||||
|
||||
def lpop(self, key, user=None, shared=False):
	# Non-blocking pop from the head of the prefixed list; returns None when empty.
	return super().lpop(self.make_key(key, user=user, shared=shared))
|
||||
|
||||
def rpop(self, key, user=None, shared=False):
	"""RPOP from the prefixed key.

	`user`/`shared` control key scoping (kept consistent with lpop);
	defaults preserve the previous site-scoped behavior.
	"""
	return super().rpop(self.make_key(key, user=user, shared=shared))
|
||||
|
||||
def blpop(self, key, timeout=0, user=None, shared=False):
	# Blocking pop; NOTE: timeout=0 blocks indefinitely (Redis BLPOP semantics).
	return super().blpop(self.make_key(key, user=user, shared=shared), timeout=timeout)
|
||||
|
||||
def llen(self, key, user=None, shared=False):
	"""LLEN of the prefixed key.

	`user`/`shared` control key scoping (kept consistent with lpush/lpop);
	defaults preserve the previous site-scoped behavior.
	"""
	return super().llen(self.make_key(key, user=user, shared=shared))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue