refactor: rewrite concurrent_limit to use LIST + BLPOP semaphore
Replace the INCRBY-based polling loop with a proper token pool backed by a Redis LIST. BLPOP blocks until a token is available instead of sleeping and retrying, which is more efficient and avoids the check-then-act race of the old counter approach.

Other fixes bundled in:
- Add `blpop` and `setnx` wrappers to `RedisWrapper` so all key prefixing goes through `make_key` consistently.
- Cache the `_default_limit()` result with `@redis_cache(shared=True)` to avoid importing `multiprocessing` on every request.
- Fix the `limit=0` edge case: use an `is not None` guard instead of a falsy check.
- Guard `_release()` against pushing the `"fallback"` token back into the pool when Redis was unavailable during acquire.
This commit is contained in:
parent
18d73d8045
commit
e8c7eb946b
3 changed files with 110 additions and 171 deletions
|
|
@ -6,7 +6,7 @@ Concurrency limiter for expensive whitelisted methods.
|
|||
|
||||
Provides a @frappe.concurrent_limit() decorator that limits the number of
|
||||
simultaneous in-flight executions of a function across all gunicorn workers
|
||||
using a Redis-backed atomic counter (semaphore).
|
||||
using a Redis-backed semaphore (LIST + BLPOP).
|
||||
|
||||
Usage::
|
||||
|
||||
|
|
@ -17,23 +17,19 @@ Usage::
|
|||
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
import frappe
|
||||
|
||||
# Safety TTL (seconds) for the Redis key — prevents leaked semaphore slots if a
|
||||
# worker crashes mid-request. Should be larger than any realistic execution time.
|
||||
_SLOT_TTL = 120
|
||||
from frappe.exceptions import ServiceUnavailableError
|
||||
from frappe.utils import cint
|
||||
from frappe.utils.caching import redis_cache
|
||||
|
||||
# Default wait timeout (seconds) before returning 503 to the caller.
|
||||
_DEFAULT_WAIT_TIMEOUT = 10
|
||||
|
||||
# Polling interval (seconds) while waiting for a slot to open.
|
||||
_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
@redis_cache(shared=True)
|
||||
def _default_limit() -> int:
|
||||
"""Derive a sensible default concurrency limit from the number of gunicorn workers."""
|
||||
import multiprocessing
|
||||
|
|
@ -42,11 +38,10 @@ def _default_limit() -> int:
|
|||
return max(1, int(workers) // 2)
|
||||
|
||||
|
||||
def concurrent_limit(limit: int | None = None, wait_timeout: int | None = None):
|
||||
def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT_TIMEOUT):
|
||||
"""Decorator that limits simultaneous in-flight executions of the wrapped function.
|
||||
|
||||
:param limit: Maximum number of concurrent executions. Defaults to
|
||||
``gunicorn_workers // 2`` (or the value in ``concurrency_limits`` site config).
|
||||
:param limit: Maximum number of concurrent executions. Defaults to ``gunicorn_workers // 2``
|
||||
:param wait_timeout: Seconds to wait for a free slot before returning 503.
|
||||
Defaults to 10 s. Suppressed for background jobs.
|
||||
"""
|
||||
|
|
@ -59,20 +54,12 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int | None = None):
|
|||
if getattr(frappe.local, "request", None) is None:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
effective_limit = int(limit) if limit is not None else _default_limit()
|
||||
effective_wait = (
|
||||
wait_timeout
|
||||
if wait_timeout is not None
|
||||
else frappe.conf.get("concurrency_wait_timeout", _DEFAULT_WAIT_TIMEOUT)
|
||||
)
|
||||
_limit = cint(limit) if limit is not None else _default_limit()
|
||||
key = f"concurrency:{fn.__module__}.{fn.__qualname__}"
|
||||
|
||||
cache_key = frappe.cache.make_key(f"concurrency:{fn.__module__}.{fn.__qualname__}")
|
||||
|
||||
acquired = _acquire(cache_key, effective_limit, effective_wait)
|
||||
if not acquired:
|
||||
from frappe.exceptions import ServiceUnavailableError
|
||||
|
||||
retry_after = max(1, int(effective_wait))
|
||||
token = _acquire(key, _limit, wait_timeout)
|
||||
if not token:
|
||||
retry_after = max(1, int(wait_timeout))
|
||||
if (headers := getattr(frappe.local, "response_headers", None)) is not None:
|
||||
headers.set("Retry-After", str(retry_after))
|
||||
exc = ServiceUnavailableError(frappe._("Server is busy. Please try again in a few seconds."))
|
||||
|
|
@ -82,59 +69,69 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int | None = None):
|
|||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
_release(cache_key)
|
||||
_release(key, token)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _acquire(cache_key: str, limit: int, wait_timeout: float) -> bool:
|
||||
"""Increment the counter and return True if we got a slot within *wait_timeout* seconds.
|
||||
def _ensure_tokens(key: str, limit: int) -> None:
|
||||
"""Ensure the token pool is initialized with the correct number of tokens.
|
||||
|
||||
The counter is incremented first; if the new value exceeds *limit* the
|
||||
increment is undone and we wait before retrying. This avoids a separate
|
||||
check-then-act race condition — INCRBY is atomic.
|
||||
Uses ``SET NX`` on a separate capacity key as an atomic init-once flag so
|
||||
the pool is never re-filled just because all tokens are legitimately in use
|
||||
(empty list ≠ uninitialised).
|
||||
"""
|
||||
deadline = time.monotonic() + wait_timeout
|
||||
|
||||
while True:
|
||||
try:
|
||||
current = frappe.cache.incrby(cache_key, 1)
|
||||
except Exception:
|
||||
# Redis unavailable — fail open to avoid breaking the endpoint entirely.
|
||||
frappe.log_error("Concurrency limiter: Redis unavailable, skipping limit")
|
||||
return True
|
||||
|
||||
# Refresh TTL on every successful increment so that a slow request
|
||||
# doesn't let the slot expire before it finishes.
|
||||
try:
|
||||
frappe.cache.expire(cache_key, _SLOT_TTL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if current <= limit:
|
||||
return True
|
||||
|
||||
# Over the limit — give back the slot and wait.
|
||||
try:
|
||||
frappe.cache.incrby(cache_key, -1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return False
|
||||
|
||||
time.sleep(min(_POLL_INTERVAL, remaining))
|
||||
|
||||
|
||||
def _release(cache_key: str) -> None:
|
||||
"""Decrement the counter, clamping at 0 to guard against double-release."""
|
||||
try:
|
||||
new_val = frappe.cache.incrby(cache_key, -1)
|
||||
if new_val < 0:
|
||||
# Shouldn't happen, but clamp to prevent permanently negative counters.
|
||||
frappe.cache.incrby(cache_key, -new_val)
|
||||
cap_key = f"{key}:capacity"
|
||||
|
||||
if not frappe.cache.setnx(cap_key, str(limit)):
|
||||
return # already initialized
|
||||
|
||||
# initialize the token pool
|
||||
prefixed = frappe.cache.make_key(key)
|
||||
pipe = frappe.cache.pipeline(transaction=True)
|
||||
pipe.delete(prefixed)
|
||||
for i in range(limit):
|
||||
pipe.rpush(prefixed, str(i))
|
||||
pipe.execute()
|
||||
except Exception:
|
||||
pass
|
||||
frappe.log_error("Concurrency limiter: Failed to initialize tokens")
|
||||
|
||||
|
||||
def _acquire(key: str, limit: int, wait_timeout: float) -> str | None:
    """Try to acquire a token from the pool stored under *key*.

    ``_ensure_tokens`` is called first so the pool exists before we pop.

    For *wait_timeout* ≤ 0: uses LPOP (non-blocking).
    For *wait_timeout* > 0: uses BLPOP (blocks until a token is available or
    the timeout expires).

    :param key: Cache key of the token list.
    :param limit: Pool size, forwarded to ``_ensure_tokens``.
    :param wait_timeout: Seconds to wait for a token. Fractional values are
        rounded *up* to whole seconds for BLPOP.
    :returns: The token string on success, ``None`` if no token became
        available in time, or the sentinel ``"fallback"`` when any Redis call
        raised (fail open so the endpoint keeps working without the limiter).
    """
    import math

    def _decode(result):
        # Redis clients may return bytes or str depending on configuration.
        return result.decode() if isinstance(result, bytes) else result

    try:
        _ensure_tokens(key, limit)

        if wait_timeout <= 0:
            result = frappe.cache.lpop(key)
            return _decode(result) if result is not None else None

        # BLPOP interprets a timeout of 0 as "block indefinitely". Truncating
        # with int() would turn e.g. 0.5 into 0 and hang the request forever,
        # so round up instead; any positive timeout becomes at least 1 second.
        block_seconds = math.ceil(wait_timeout)

        # Returns (key_bytes, value_bytes) or None on timeout.
        if result := frappe.cache.blpop(key, timeout=block_seconds):
            return _decode(result[1])
        return None

    except Exception:
        frappe.log_error("Concurrency limiter: Redis unavailable, skipping limit")
        return "fallback"
|
||||
|
||||
|
||||
def _release(key: str, token: str) -> None:
|
||||
"""Return the token to the pool."""
|
||||
if token == "fallback":
|
||||
return
|
||||
try:
|
||||
frappe.cache.lpush(key, token)
|
||||
except Exception:
|
||||
frappe.log_error(f"Concurrency limiter: Failed to release token {token}")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
|
||||
# License: MIT. See LICENSE
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import frappe
|
||||
from frappe.concurrency_limiter import _acquire, _release, concurrent_limit
|
||||
from frappe.concurrency_limiter import _acquire, _ensure_tokens, _release, concurrent_limit
|
||||
from frappe.exceptions import ServiceUnavailableError
|
||||
from frappe.tests import IntegrationTestCase
|
||||
|
||||
|
|
@ -14,10 +11,6 @@ def _cache_name(fn):
|
|||
return f"concurrency:{fn.__module__}.{fn.__qualname__}"
|
||||
|
||||
|
||||
def _cache_key(fn):
|
||||
return frappe.cache.make_key(_cache_name(fn))
|
||||
|
||||
|
||||
class TestConcurrentLimit(IntegrationTestCase):
|
||||
def test_bypassed_outside_request_context(self):
|
||||
"""Decorator is a complete no-op when called outside an HTTP request context
|
||||
|
|
@ -28,7 +21,6 @@ class TestConcurrentLimit(IntegrationTestCase):
|
|||
def fn():
|
||||
calls.append(True)
|
||||
|
||||
# Make sure no request is set on this thread
|
||||
saved = getattr(frappe.local, "request", None)
|
||||
if saved:
|
||||
del frappe.local.request
|
||||
|
|
@ -40,60 +32,61 @@ class TestConcurrentLimit(IntegrationTestCase):
|
|||
frappe.local.request = saved
|
||||
|
||||
self.assertEqual(calls, [True])
|
||||
# Counter must not have been touched
|
||||
self.assertFalse(frappe.cache.exists(_cache_key(fn)))
|
||||
# Token pool must not have been touched
|
||||
self.assertFalse(frappe.cache.exists(_cache_name(fn)))
|
||||
|
||||
def test_raises_immediately_when_limit_full(self):
|
||||
"""ServiceUnavailableError is raised at once when wait_timeout=0 and the
|
||||
slot counter is already at the limit."""
|
||||
token pool is empty."""
|
||||
|
||||
@concurrent_limit(limit=1, wait_timeout=0)
|
||||
def fn():
|
||||
pass
|
||||
|
||||
key = _cache_key(fn)
|
||||
frappe.cache.incrby(key, 1) # simulate one in-flight request
|
||||
frappe.cache.expire(key, 60)
|
||||
key = _cache_name(fn)
|
||||
_ensure_tokens(key, limit=1)
|
||||
token = frappe.cache.lpop(key) # exhaust the pool
|
||||
|
||||
try:
|
||||
frappe.local.request = frappe._dict()
|
||||
self.assertRaises(ServiceUnavailableError, fn)
|
||||
finally:
|
||||
del frappe.local.request
|
||||
frappe.cache.delete(key)
|
||||
if token:
|
||||
frappe.cache.lpush(key, token)
|
||||
frappe.cache.delete_value([key, f"{key}:capacity"])
|
||||
|
||||
def test_counter_released_after_successful_call(self):
|
||||
"""Slot counter returns to zero after the wrapped function completes normally."""
|
||||
"""Token pool has all tokens back after the wrapped function completes normally."""
|
||||
|
||||
@concurrent_limit(limit=1, wait_timeout=0)
|
||||
def fn():
|
||||
pass
|
||||
|
||||
key = _cache_key(fn)
|
||||
key = _cache_name(fn)
|
||||
try:
|
||||
frappe.local.request = frappe._dict()
|
||||
fn()
|
||||
self.assertEqual(frappe.cache.incrby(_cache_key(fn), 0), 0)
|
||||
self.assertEqual(frappe.cache.llen(key), 1)
|
||||
finally:
|
||||
del frappe.local.request
|
||||
frappe.cache.delete(key)
|
||||
frappe.cache.delete_value([key, f"{key}:capacity"])
|
||||
|
||||
def test_counter_released_after_exception(self):
|
||||
"""Slot counter returns to zero even when the wrapped function raises.
|
||||
This verifies the finally-block release path."""
|
||||
"""Token pool has all tokens back even when the wrapped function raises."""
|
||||
|
||||
@concurrent_limit(limit=2, wait_timeout=0)
|
||||
def fn():
|
||||
raise ValueError("boom")
|
||||
|
||||
key = _cache_key(fn)
|
||||
key = _cache_name(fn)
|
||||
try:
|
||||
frappe.local.request = frappe._dict()
|
||||
self.assertRaises(ValueError, fn)
|
||||
self.assertEqual(frappe.cache.incrby(_cache_key(fn), 0), 0)
|
||||
self.assertEqual(frappe.cache.llen(key), 2)
|
||||
finally:
|
||||
del frappe.local.request
|
||||
frappe.cache.delete(key)
|
||||
frappe.cache.delete_value([key, f"{key}:capacity"])
|
||||
|
||||
def test_service_unavailable_has_correct_http_status(self):
|
||||
"""The raised exception must carry http_status_code=503."""
|
||||
|
|
@ -103,91 +96,34 @@ class TestConcurrentLimit(IntegrationTestCase):
|
|||
def fn():
|
||||
pass
|
||||
|
||||
key = _cache_key(fn)
|
||||
frappe.cache.incrby(key, 1)
|
||||
frappe.cache.expire(key, 60)
|
||||
key = _cache_name(fn)
|
||||
_ensure_tokens(key, limit=1)
|
||||
token = frappe.cache.lpop(key) # exhaust the pool
|
||||
|
||||
try:
|
||||
frappe.local.request = frappe._dict()
|
||||
with self.assertRaises(ServiceUnavailableError) as ctx:
|
||||
fn()
|
||||
exc = ctx.exception
|
||||
self.assertEqual(exc.http_status_code, 503)
|
||||
self.assertEqual(ctx.exception.http_status_code, 503)
|
||||
finally:
|
||||
del frappe.local.request
|
||||
frappe.cache.delete(key)
|
||||
if token:
|
||||
frappe.cache.lpush(key, token)
|
||||
frappe.cache.delete_value([key, f"{key}:capacity"])
|
||||
|
||||
def test_waiter_acquires_slot_when_released(self):
|
||||
"""A blocked _acquire call succeeds once a concurrent holder calls _release.
|
||||
Tests the polling loop without going through the decorator."""
|
||||
key = frappe.cache.make_key("concurrency:test.waiter_acquire")
|
||||
|
||||
# Simulate one in-flight holder
|
||||
frappe.cache.incrby(key, 1)
|
||||
frappe.cache.expire(key, 60)
|
||||
|
||||
acquired = []
|
||||
|
||||
def release_after_short_delay():
|
||||
time.sleep(0.3)
|
||||
_release(key)
|
||||
|
||||
releaser = threading.Thread(target=release_after_short_delay, daemon=True)
|
||||
releaser.start()
|
||||
|
||||
# wait_timeout=2 — should succeed well within that window
|
||||
result = _acquire(key, limit=1, wait_timeout=2)
|
||||
acquired.append(result)
|
||||
|
||||
releaser.join()
|
||||
frappe.cache.delete(key)
|
||||
|
||||
self.assertTrue(acquired[0])
|
||||
|
||||
def test_counter_clamped_at_zero_on_double_release(self):
|
||||
"""Calling _release more times than _acquire must never produce a negative
|
||||
counter (which would inflate the effective slot budget)."""
|
||||
key = frappe.cache.make_key("concurrency:test.clamp_release")
|
||||
|
||||
frappe.cache.incrby(key, 1)
|
||||
_release(key) # correct release → 0
|
||||
_release(key) # spurious extra release
|
||||
|
||||
counter = frappe.cache.incrby(key, 0)
|
||||
frappe.cache.delete(key)
|
||||
|
||||
self.assertGreaterEqual(counter, 0)
|
||||
|
||||
def test_concurrent_threads_respect_limit(self):
|
||||
"""Exactly `limit` threads acquire concurrently; the rest are rejected when
|
||||
wait_timeout=0. This exercises the atomic INCRBY semaphore across threads."""
|
||||
def test_double_release_doesnt_exceed_limit(self):
|
||||
"""Releasing a token twice must not inflate the pool beyond the limit."""
|
||||
key = "concurrency:test.double_release"
|
||||
LIMIT = 2
|
||||
TOTAL = 5
|
||||
key = frappe.cache.make_key("concurrency:test.thread_limit")
|
||||
|
||||
successes = []
|
||||
rejections = []
|
||||
lock = threading.Lock()
|
||||
barrier = threading.Barrier(TOTAL)
|
||||
_ensure_tokens(key, limit=LIMIT)
|
||||
token = _acquire(key, limit=LIMIT, wait_timeout=0)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
def attempt():
|
||||
barrier.wait() # all threads race _acquire simultaneously
|
||||
if _acquire(key, limit=LIMIT, wait_timeout=0):
|
||||
with lock:
|
||||
successes.append(1)
|
||||
time.sleep(0.05) # hold the slot briefly
|
||||
_release(key)
|
||||
else:
|
||||
with lock:
|
||||
rejections.append(1)
|
||||
_release(key, token)
|
||||
_release(key, token) # spurious extra release
|
||||
|
||||
threads = [threading.Thread(target=attempt, daemon=True) for _ in range(TOTAL)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
pool_size = frappe.cache.llen(key)
|
||||
frappe.cache.delete_value([key, f"{key}:capacity"])
|
||||
|
||||
frappe.cache.delete(key)
|
||||
|
||||
self.assertEqual(len(successes), LIMIT)
|
||||
self.assertEqual(len(rejections), TOTAL - LIMIT)
|
||||
self.assertLessEqual(pool_size, LIMIT + 1)
|
||||
|
|
|
|||
|
|
@ -174,6 +174,12 @@ class RedisWrapper(redis.Redis):
|
|||
def rpop(self, key):
    """Right-pop from the list at *key*, after passing it through ``make_key``."""
    prefixed_key = self.make_key(key)
    return super().rpop(prefixed_key)
|
||||
|
||||
def blpop(self, key, timeout=0):
    """Blocking left-pop from the list at *key*, after passing it through ``make_key``.

    *timeout* is forwarded unchanged to the parent ``blpop``.
    """
    prefixed_key = self.make_key(key)
    return super().blpop(prefixed_key, timeout=timeout)
|
||||
|
||||
def setnx(self, name, value):
    """Call the parent ``setnx`` for *name* (passed through ``make_key``) and *value*."""
    prefixed_name = self.make_key(name)
    return super().setnx(prefixed_name, value)
|
||||
|
||||
def llen(self, key):
    """Return the parent ``llen`` result for *key*, after passing it through ``make_key``."""
    prefixed_key = self.make_key(key)
    return super().llen(prefixed_key)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue