refactor: variable names, force RESP2

This commit is contained in:
Ankush Menat 2025-01-06 18:18:06 +05:30
parent 98b1df7dac
commit 119af71ae3
2 changed files with 29 additions and 30 deletions

View file

@ -55,4 +55,4 @@ class TestClientCache(IntegrationTestCase):
c.set_value(frappe.generate_hash(), 42)
c.set_value(frappe.generate_hash(), 42)
self.assertEqual(len(c.local_cache), 2)
self.assertEqual(len(c.cache), 2)

View file

@ -397,14 +397,14 @@ def get_sentinel_connection(
class _TrackedConnection(redis.Connection):
def __init__(self, *args, **kwargs):
self.monitor_id = kwargs.pop("_monitor_id")
self._invalidator_id = kwargs.pop("_invalidator_id")
super().__init__(*args, **kwargs)
# Every redis connection needs to enable client tracking to get notified about invalidated
# keys.
self.register_connect_callback(self._enable_client_tracking)
def _enable_client_tracking(self, conn):
conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self.monitor_id, "NOLOOP")
conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self._invalidator_id, "NOLOOP")
conn.read_response()
@ -413,30 +413,29 @@ _ClientCacheValue = tuple[typing.Any, int]
class _ClientCache:
def __init__(self, maxsize: int = 1024, ttl=10 * 60, monitor: RedisWrapper | None = None) -> None:
self.monitor = frappe.cache
self.monitor_id = self.monitor.client_id()
self.maxsize = maxsize or 1024 # Expect 1024 * 4kb objects ~ 4MB
self.local_ttl = ttl
self.expiration = time.monotonic() + ttl
self._lock = threading.RLock()
# This guards writes to self.cache, reads are done without a lock.
self.lock = threading.RLock()
self.cache: dict[bytes, _ClientCacheValue] = {}
self.invalidator = frappe.cache
self.invalidator_id = self.invalidator.client_id()
self.redis: RedisWrapper = RedisWrapper.from_url(
frappe.conf.get("redis_cache"),
connection_class=_TrackedConnection,
_monitor_id=self.monitor_id,
_invalidator_id=self.invalidator_id,
protocol=2,
)
protocol = self.redis.get_connection_kwargs().get("protocol")
if cint(protocol) == 3:
frappe.throw("RESP3 is not supported while connecting to Redis.") # nosemgrep
self.invalidator_thread = self.run_invalidator_thread()
self.local_cache: dict[bytes, _ClientCacheValue] = {}
self.cache_healthy = True
self._conn_retries = 0
self.connection_retries = 0
def get_value(self, key):
key = self.redis.make_key(key)
try:
val = self.local_cache[key]
val = self.cache[key]
if time.monotonic() < val[1] and self.cache_healthy:
return val[0]
except KeyError:
@ -451,8 +450,8 @@ class _ClientCache:
return None
self.ensure_max_size()
with self._lock:
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
with self.lock:
self.cache[key] = (val, time.monotonic() + self.local_ttl)
return val
@ -460,8 +459,8 @@ class _ClientCache:
key = self.redis.make_key(key)
self.ensure_max_size()
self.redis.set_value(key, val, shared=True)
with self._lock:
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
with self.lock:
self.cache[key] = (val, time.monotonic() + self.local_ttl)
# XXX: We need to tell redis that we indeed read this key we just wrote
# This is an edge case:
# - Client A writes a key and reads it again from local cache
@ -470,18 +469,18 @@ class _ClientCache:
_ = self.redis.get_value(key, shared=True, use_local_cache=False)
def ensure_max_size(self):
if len(self.local_cache) >= self.maxsize:
with self._lock, suppress(RuntimeError):
self.local_cache.pop(next(iter(self.local_cache)), None)
if len(self.cache) >= self.maxsize:
with self.lock, suppress(RuntimeError):
self.cache.pop(next(iter(self.cache)), None)
def delete_value(self, key):
key = self.redis.make_key(key)
self.redis.delete_value(key, shared=True)
with self._lock:
self.local_cache.pop(key, None)
with self.lock:
self.cache.pop(key, None)
def run_invalidator_thread(self):
self._watcher = self.monitor.pubsub()
self._watcher = self.invalidator.pubsub()
self._watcher.subscribe(**{"__redis__:invalidate": self._handle_invalidation})
return self._watcher.run_in_thread(
sleep_time=None,
@ -493,15 +492,15 @@ class _ClientCache:
if message["data"] is None:
# Flushall
self.clear_cache()
with self._lock:
with self.lock:
for key in message["data"]:
self.local_cache.pop(key, None)
self.cache.pop(key, None)
def _exception_handler(self, exc, pubsub, pubsub_thread):
if isinstance(exc, (redis.exceptions.ConnectionError)):
self.clear_cache()
self._conn_retries += 1
if self._conn_retries > 10:
self.connection_retries += 1
if self.connection_retries > 10:
self.cache_healthy = False
raise
time.sleep(1)
@ -510,5 +509,5 @@ class _ClientCache:
raise
def clear_cache(self):
with self._lock:
self.local_cache.clear()
with self.lock:
self.cache.clear()