fix: support shared RedisSemaphores for concurrency limits

This commit is contained in:
Saqib Ansari 2026-04-20 13:15:41 +05:30
parent 7f78cd25f9
commit 0064eb80b4
4 changed files with 19 additions and 16 deletions

View file

@ -94,7 +94,7 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT
_limit = cint(limit) if limit is not None else _default_limit()
key = f"concurrency:{fn.__module__}.{fn.__qualname__}"
sem = RedisSemaphore(key, _limit, wait_timeout)
sem = RedisSemaphore(key, _limit, wait_timeout, shared=True)
token = sem.acquire()
if not token:
retry_after = max(1, int(wait_timeout))

View file

@ -18,7 +18,7 @@ def _key(fn):
def _cleanup(fn):
key = _key(fn)
frappe.cache.delete_value([key, f"{key}:capacity"])
frappe.cache.delete_value([key, f"{key}:capacity"], shared=True)
class TestConcurrentLimit(IntegrationTestCase):
@ -131,11 +131,11 @@ class TestConcurrentLimit(IntegrationTestCase):
# Simulate all tokens being leaked (workers crashed mid-request)
# by draining the pool without returning tokens.
while frappe.cache.lpop(key):
while frappe.cache.lpop(key, shared=True):
pass
# Simulate capacity key TTL expiry.
frappe.cache.delete_value(f"{key}:capacity")
frappe.cache.delete_value(f"{key}:capacity", shared=True)
# Self-heal: next request must re-initialize the pool and succeed.
fn() # must not raise ServiceUnavailableError

View file

@ -48,17 +48,20 @@ if redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then
end
"""
def __init__(self, key: str, limit: int, wait_timeout: float = 0):
def __init__(self, key: str, limit: int, wait_timeout: float = 0, shared: bool = False):
"""
:param key: A unique Redis key name for this semaphore (will be
prefixed by the cache layer).
:param limit: Maximum number of concurrent holders.
:param wait_timeout: Seconds to block waiting for a free slot.
0 means non-blocking (immediate return if unavailable).
:param shared: If True, the semaphore key is bench-wide (not
prefixed with the site's db_name). Defaults to site-scoped.
"""
self.key = key
self.limit = limit
self.wait_timeout = wait_timeout
self.shared = shared
self._token: str | None = None
def acquire(self) -> str | None:
@ -72,10 +75,10 @@ end
self._ensure_tokens()
if self.wait_timeout <= 0:
result = frappe.cache.lpop(self.key)
result = frappe.cache.lpop(self.key, shared=self.shared)
return self._decode(result) if result is not None else None
if result := frappe.cache.blpop(self.key, timeout=int(self.wait_timeout)):
if result := frappe.cache.blpop(self.key, timeout=int(self.wait_timeout), shared=self.shared):
return self._decode(result[1])
return None
@ -88,7 +91,7 @@ end
if token == "fallback":
return
try:
frappe.cache.lpush(self.key, token)
frappe.cache.lpush(self.key, token, shared=self.shared)
except Exception:
frappe.log_error(f"RedisSemaphore({self.key}): Failed to release token {token}")
@ -108,8 +111,8 @@ end
def _ensure_tokens(self) -> None:
"""Lazily initialize the token pool via an atomic Lua script."""
try:
prefixed_cap_key = frappe.cache.make_key(f"{self.key}:capacity")
prefixed_key = frappe.cache.make_key(self.key)
prefixed_cap_key = frappe.cache.make_key(f"{self.key}:capacity", shared=self.shared)
prefixed_key = frappe.cache.make_key(self.key, shared=self.shared)
frappe.cache.eval(
self._INIT_SCRIPT,
2,

View file

@ -162,20 +162,20 @@ class RedisWrapper(redis.Redis):
except redis.exceptions.ConnectionError:
pass
def lpush(self, key, value):
return super().lpush(self.make_key(key), value)
def lpush(self, key, value, user=None, shared=False):
return super().lpush(self.make_key(key, user=user, shared=shared), value)
def rpush(self, key, value):
return super().rpush(self.make_key(key), value)
def lpop(self, key):
return super().lpop(self.make_key(key))
def lpop(self, key, user=None, shared=False):
return super().lpop(self.make_key(key, user=user, shared=shared))
def rpop(self, key):
return super().rpop(self.make_key(key))
def blpop(self, key, timeout=0):
return super().blpop(self.make_key(key), timeout=timeout)
def blpop(self, key, timeout=0, user=None, shared=False):
return super().blpop(self.make_key(key, user=user, shared=shared), timeout=timeout)
def setnx(self, name, value):
return super().setnx(self.make_key(name), value)