From 119af71ae30e9dfc8e9d761332f2e38c97495a6c Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Mon, 6 Jan 2025 18:18:06 +0530 Subject: [PATCH] refactor: variable names, force RESP2 --- frappe/tests/test_client_cache.py | 2 +- frappe/utils/redis_wrapper.py | 57 +++++++++++++++---------------- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/frappe/tests/test_client_cache.py b/frappe/tests/test_client_cache.py index ade4ad2b8b..7646806f6a 100644 --- a/frappe/tests/test_client_cache.py +++ b/frappe/tests/test_client_cache.py @@ -55,4 +55,4 @@ class TestClientCache(IntegrationTestCase): c.set_value(frappe.generate_hash(), 42) c.set_value(frappe.generate_hash(), 42) - self.assertEqual(len(c.local_cache), 2) + self.assertEqual(len(c.cache), 2) diff --git a/frappe/utils/redis_wrapper.py b/frappe/utils/redis_wrapper.py index 34d3fe597d..56538771b9 100644 --- a/frappe/utils/redis_wrapper.py +++ b/frappe/utils/redis_wrapper.py @@ -397,14 +397,14 @@ def get_sentinel_connection( class _TrackedConnection(redis.Connection): def __init__(self, *args, **kwargs): - self.monitor_id = kwargs.pop("_monitor_id") + self._invalidator_id = kwargs.pop("_invalidator_id") super().__init__(*args, **kwargs) # Every redis connection needs to enable client tracking to get notified about invalidated # keys. self.register_connect_callback(self._enable_client_tracking) def _enable_client_tracking(self, conn): - conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self.monitor_id, "NOLOOP") + conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self._invalidator_id, "NOLOOP") conn.read_response() @@ -413,30 +413,29 @@ _ClientCacheValue = tuple[typing.Any, int] class _ClientCache: def __init__(self, maxsize: int = 1024, ttl=10 * 60, monitor: RedisWrapper | None = None) -> None: - self.monitor = frappe.cache - self.monitor_id = self.monitor.client_id() self.maxsize = maxsize or 1024 # Expect 1024 * 4kb objects ~ 4MB self.local_ttl = ttl - self.expiration = time.monotonic() + ttl - self._lock = threading.RLock() + # This guards writes to self.cache, reads are done without a lock. + self.lock = threading.RLock() + self.cache: dict[bytes, _ClientCacheValue] = {} + + self.invalidator = frappe.cache + self.invalidator_id = self.invalidator.client_id() self.redis: RedisWrapper = RedisWrapper.from_url( frappe.conf.get("redis_cache"), connection_class=_TrackedConnection, - _monitor_id=self.monitor_id, + _invalidator_id=self.invalidator_id, + protocol=2, ) - protocol = self.redis.get_connection_kwargs().get("protocol") - if cint(protocol) == 3: - frappe.throw("RESP3 is not supported while connecting to Redis.") # nosemgrep self.invalidator_thread = self.run_invalidator_thread() - self.local_cache: dict[bytes, _ClientCacheValue] = {} self.cache_healthy = True - self._conn_retries = 0 + self.connection_retries = 0 def get_value(self, key): key = self.redis.make_key(key) try: - val = self.local_cache[key] + val = self.cache[key] if time.monotonic() < val[1] and self.cache_healthy: return val[0] except KeyError: @@ -451,8 +450,8 @@ class _ClientCache: return None self.ensure_max_size() - with self._lock: - self.local_cache[key] = (val, time.monotonic() + self.local_ttl) + with self.lock: + self.cache[key] = (val, time.monotonic() + self.local_ttl) return val @@ -460,8 +459,8 @@ class _ClientCache: key = self.redis.make_key(key) self.ensure_max_size() self.redis.set_value(key, val, shared=True) - with self._lock: - self.local_cache[key] = (val, time.monotonic() + self.local_ttl) + with self.lock: + self.cache[key] = (val, time.monotonic() + self.local_ttl) # XXX: We need to tell redis that we indeed read this key we just wrote # This is an edge case: # - Client A writes a key and reads it again from local cache @@ -470,18 +469,18 @@ class _ClientCache: _ = self.redis.get_value(key, shared=True, use_local_cache=False) def ensure_max_size(self): - if len(self.local_cache) >= self.maxsize: - with self._lock, suppress(RuntimeError): - self.local_cache.pop(next(iter(self.local_cache)), None) + if len(self.cache) >= self.maxsize: + with self.lock, suppress(RuntimeError): + self.cache.pop(next(iter(self.cache)), None) def delete_value(self, key): key = self.redis.make_key(key) self.redis.delete_value(key, shared=True) - with self._lock: - self.local_cache.pop(key, None) + with self.lock: + self.cache.pop(key, None) def run_invalidator_thread(self): - self._watcher = self.monitor.pubsub() + self._watcher = self.invalidator.pubsub() self._watcher.subscribe(**{"__redis__:invalidate": self._handle_invalidation}) return self._watcher.run_in_thread( sleep_time=None, @@ -493,15 +492,15 @@ class _ClientCache: if message["data"] is None: # Flushall self.clear_cache() - with self._lock: + with self.lock: for key in message["data"]: - self.local_cache.pop(key, None) + self.cache.pop(key, None) def _exception_handler(self, exc, pubsub, pubsub_thread): if isinstance(exc, (redis.exceptions.ConnectionError)): self.clear_cache() - self._conn_retries += 1 - if self._conn_retries > 10: + self.connection_retries += 1 + if self.connection_retries > 10: self.cache_healthy = False raise time.sleep(1) @@ -510,5 +509,5 @@ class _ClientCache: raise def clear_cache(self): - with self._lock: - self.local_cache.clear() + with self.lock: + self.cache.clear()