From 98b1df7dac7806a40daf4dad2d3a93484e33b028 Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Mon, 6 Jan 2025 18:14:11 +0530 Subject: [PATCH] fix: guard all writes with an RLock --- frappe/utils/redis_wrapper.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/frappe/utils/redis_wrapper.py b/frappe/utils/redis_wrapper.py index bcfa4831bc..34d3fe597d 100644 --- a/frappe/utils/redis_wrapper.py +++ b/frappe/utils/redis_wrapper.py @@ -2,6 +2,7 @@ # License: MIT. See LICENSE import pickle import re +import threading import time import typing from contextlib import suppress @@ -417,6 +418,7 @@ class _ClientCache: self.maxsize = maxsize or 1024 # Expect 1024 * 4kb objects ~ 4MB self.local_ttl = ttl self.expiration = time.monotonic() + ttl + self._lock = threading.RLock() self.redis: RedisWrapper = RedisWrapper.from_url( frappe.conf.get("redis_cache"), @@ -437,8 +439,6 @@ class _ClientCache: val = self.local_cache[key] if time.monotonic() < val[1] and self.cache_healthy: return val[0] - else: - self.local_cache.pop(key, None) # expired except KeyError: pass # cache miss @@ -451,7 +451,8 @@ class _ClientCache: return None self.ensure_max_size() - self.local_cache[key] = (val, time.monotonic() + self.local_ttl) + with self._lock: + self.local_cache[key] = (val, time.monotonic() + self.local_ttl) return val @@ -459,7 +460,8 @@ class _ClientCache: key = self.redis.make_key(key) self.ensure_max_size() self.redis.set_value(key, val, shared=True) - self.local_cache[key] = (val, time.monotonic() + self.local_ttl) + with self._lock: + self.local_cache[key] = (val, time.monotonic() + self.local_ttl) # XXX: We need to tell redis that we indeed read this key we just wrote # This is an edge case: # - Client A writes a key and reads it again from local cache @@ -469,13 +471,14 @@ class _ClientCache: def ensure_max_size(self): if len(self.local_cache) >= self.maxsize: - with suppress(RuntimeError): + with self._lock, suppress(RuntimeError): self.local_cache.pop(next(iter(self.local_cache)), None) def delete_value(self, key): key = self.redis.make_key(key) self.redis.delete_value(key, shared=True) - self.local_cache.pop(key, None) + with self._lock: + self.local_cache.pop(key, None) def run_invalidator_thread(self): self._watcher = self.monitor.pubsub() @@ -490,12 +493,13 @@ class _ClientCache: if message["data"] is None: # Flushall self.clear_cache() - for key in message["data"]: - self.local_cache.pop(key, None) + with self._lock: + for key in message["data"]: + self.local_cache.pop(key, None) def _exception_handler(self, exc, pubsub, pubsub_thread): if isinstance(exc, (redis.exceptions.ConnectionError)): - self.local_cache.clear() + self.clear_cache() self._conn_retries += 1 if self._conn_retries > 10: self.cache_healthy = False @@ -506,4 +510,5 @@ class _ClientCache: raise def clear_cache(self): - self.local_cache.clear() + with self._lock: + self.local_cache.clear()