refactor: extract RedisSemaphore into redis_semaphore.py

This commit is contained in:
Saqib Ansari 2026-04-19 19:28:45 +05:30
parent 65965b9c44
commit 7f78cd25f9
2 changed files with 136 additions and 84 deletions

View file

@ -24,6 +24,7 @@ import frappe
from frappe.exceptions import ServiceUnavailableError
from frappe.utils import cint
from frappe.utils.caching import site_cache
from frappe.utils.redis_semaphore import RedisSemaphore
# Default wait timeout (seconds) before returning 503 to the caller.
_DEFAULT_WAIT_TIMEOUT = 10
@ -36,12 +37,7 @@ def _default_limit() -> int:
def gunicorn_max_concurrency() -> int:
"""Detect max concurrent requests from the running gunicorn master's cmdline.
Reads /proc/<ppid>/cmdline to extract --workers and --threads without
shelling out. Falls back to a CPU-based heuristic on non-Linux platforms
or when not running under gunicorn (dev server, CLI, tests).
"""
"""Detect max concurrent requests from the running gunicorn master's cmdline."""
import os
fallback = 4
@ -79,9 +75,12 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT
"""Decorator that limits simultaneous in-flight executions of the wrapped function.
:param limit: Maximum number of concurrent executions. Defaults to half of ``workers x threads``
as detected from the gunicorn master process (or a CPU-based heuristic as fallback).
as detected from the gunicorn master process.
:param wait_timeout: Seconds to wait for a free slot before returning 503.
Defaults to 10 s. Suppressed for background jobs.
Defaults to 10 s.
The limiter is skipped entirely for background jobs, CLI commands, and
tests that call functions directly (i.e. outside of an HTTP request).
"""
def decorator(fn: Callable) -> Callable:
@ -95,7 +94,8 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT
_limit = cint(limit) if limit is not None else _default_limit()
key = f"concurrency:{fn.__module__}.{fn.__qualname__}"
token = _acquire(key, _limit, wait_timeout)
sem = RedisSemaphore(key, _limit, wait_timeout)
token = sem.acquire()
if not token:
retry_after = max(1, int(wait_timeout))
if (headers := getattr(frappe.local, "response_headers", None)) is not None:
@ -107,82 +107,8 @@ def concurrent_limit(limit: int | None = None, wait_timeout: int = _DEFAULT_WAIT
try:
return fn(*args, **kwargs)
finally:
_release(key, token)
sem.release(token)
return wrapper
return decorator
# Safety TTL (seconds) for the capacity key — allows the pool to self-heal
# after a worker crash that leaked a token. The cap key expiring causes the
# next request to re-initialize the pool to full capacity. Must be longer
# than any realistic request, but short enough to recover from crashes.
_CAPACITY_KEY_TTL = 3600 # 1 hour
# Lua script that atomically initializes the token pool.
# Combines the SET NX check and the DEL + RPUSH population into a single
# atomic operation, closing the race window between the init-flag check
# and the list population that existed with the prior setnx + pipeline approach.
# KEYS[1] = capacity key, KEYS[2] = token list key, ARGV[1] = limit, ARGV[2] = TTL
_INIT_SCRIPT = """\
if redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then
redis.call('DEL', KEYS[2])
local n = tonumber(ARGV[1])
for i = 1, n do
redis.call('RPUSH', KEYS[2], tostring(i))
end
end
"""
def _ensure_tokens(key: str, limit: int) -> None:
    """Atomically initialize the token pool if it does not exist yet.

    Delegates to a Lua script so the ``SET NX`` capacity check and the
    ``RPUSH`` list population execute as one indivisible Redis operation,
    leaving no race window between the init-flag check and the population.
    """
    try:
        cap_key = frappe.cache.make_key(f"{key}:capacity")
        list_key = frappe.cache.make_key(key)
        frappe.cache.eval(
            _INIT_SCRIPT,
            2,
            cap_key,
            list_key,
            str(limit),
            str(_CAPACITY_KEY_TTL),
        )
    except Exception:
        frappe.log_error("Concurrency limiter: Failed to initialize tokens")
def _acquire(key: str, limit: int, wait_timeout: float) -> str | None:
    """Pop a token from the pool, waiting up to *wait_timeout* seconds.

    A *wait_timeout* of 0 (or less) does a non-blocking LPOP; a positive
    value uses BLPOP, which blocks until a token frees up or the timeout
    elapses. Returns the token string, ``None`` on timeout/empty pool, or
    ``"fallback"`` when Redis is unreachable (fail-open).
    """
    try:
        _ensure_tokens(key, limit)

        def _as_str(value):
            if isinstance(value, bytes):
                return value.decode()
            return value

        if wait_timeout > 0:
            # blpop returns (key_bytes, value_bytes), or None on timeout.
            popped = frappe.cache.blpop(key, timeout=int(wait_timeout))
            return _as_str(popped[1]) if popped else None

        token = frappe.cache.lpop(key)
        if token is None:
            return None
        return _as_str(token)
    except Exception:
        frappe.log_error("Concurrency limiter: Redis unavailable, skipping limit")
        return "fallback"
def _release(key: str, token: str) -> None:
"""Return the token to the pool."""
if token == "fallback":
return
try:
frappe.cache.lpush(key, token)
except Exception:
frappe.log_error(f"Concurrency limiter: Failed to release token {token}")

View file

@ -0,0 +1,126 @@
# Copyright (c) 2024, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
"""Distributed counting semaphore backed by a Redis LIST."""
import frappe
class RedisSemaphore:
"""A distributed counting semaphore backed by a Redis LIST.
Allows up to *limit* concurrent holders across all processes sharing the
same Redis instance. The token pool is lazily initialized via an atomic
Lua script and self-heals after crashes thanks to a TTL on the capacity
key.
Usage as a context manager::
sem = RedisSemaphore("my-resource", limit=5, wait_timeout=10)
with sem:
... # at most 5 concurrent holders
Or acquire/release manually::
token = sem.acquire()
if token is None:
raise Exception("Too busy")
try:
...
finally:
sem.release(token)
"""
# Safety TTL (seconds) for the capacity key — allows the pool to self-heal
# after a worker crash that leaked a token.
CAPACITY_TTL = 3600 # 1 hour
# Lua script that atomically initializes the token pool.
# KEYS[1] = capacity key, KEYS[2] = token list key
# ARGV[1] = limit, ARGV[2] = TTL
_INIT_SCRIPT = """\
if redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then
redis.call('DEL', KEYS[2])
local n = tonumber(ARGV[1])
for i = 1, n do
redis.call('RPUSH', KEYS[2], tostring(i))
end
end
"""
def __init__(self, key: str, limit: int, wait_timeout: float = 0):
"""
:param key: A unique Redis key name for this semaphore (will be
prefixed by the cache layer).
:param limit: Maximum number of concurrent holders.
:param wait_timeout: Seconds to block waiting for a free slot.
0 means non-blocking (immediate return if unavailable).
"""
self.key = key
self.limit = limit
self.wait_timeout = wait_timeout
self._token: str | None = None
def acquire(self) -> str | None:
"""Try to acquire a token from the pool.
Returns a token string on success, ``None`` if no slot was
available within *wait_timeout*, or ``"fallback"`` if Redis is
unreachable (fail-open).
"""
try:
self._ensure_tokens()
if self.wait_timeout <= 0:
result = frappe.cache.lpop(self.key)
return self._decode(result) if result is not None else None
if result := frappe.cache.blpop(self.key, timeout=int(self.wait_timeout)):
return self._decode(result[1])
return None
except Exception:
frappe.log_error(f"RedisSemaphore({self.key}): Redis unavailable, skipping limit")
return "fallback"
def release(self, token: str) -> None:
"""Return *token* to the pool."""
if token == "fallback":
return
try:
frappe.cache.lpush(self.key, token)
except Exception:
frappe.log_error(f"RedisSemaphore({self.key}): Failed to release token {token}")
# -- context-manager protocol ------------------------------------------
def __enter__(self):
self._token = self.acquire()
return self._token
def __exit__(self, *exc_info):
if self._token is not None:
self.release(self._token)
self._token = None
# -- internals ---------------------------------------------------------
def _ensure_tokens(self) -> None:
"""Lazily initialize the token pool via an atomic Lua script."""
try:
prefixed_cap_key = frappe.cache.make_key(f"{self.key}:capacity")
prefixed_key = frappe.cache.make_key(self.key)
frappe.cache.eval(
self._INIT_SCRIPT,
2,
prefixed_cap_key,
prefixed_key,
str(self.limit),
str(self.CAPACITY_TTL),
)
except Exception:
frappe.log_error(f"RedisSemaphore({self.key}): Failed to initialize tokens")
@staticmethod
def _decode(result):
return result.decode() if isinstance(result, bytes) else result