fix: guard all writes with an RLock

This commit is contained in:
Ankush Menat 2025-01-06 18:14:11 +05:30
parent 53f085e0f4
commit 98b1df7dac

View file

@ -2,6 +2,7 @@
# License: MIT. See LICENSE
import pickle
import re
import threading
import time
import typing
from contextlib import suppress
@ -417,6 +418,7 @@ class _ClientCache:
self.maxsize = maxsize or 1024 # Expect 1024 * 4kb objects ~ 4MB
self.local_ttl = ttl
self.expiration = time.monotonic() + ttl
self._lock = threading.RLock()
self.redis: RedisWrapper = RedisWrapper.from_url(
frappe.conf.get("redis_cache"),
@ -437,8 +439,6 @@ class _ClientCache:
val = self.local_cache[key]
if time.monotonic() < val[1] and self.cache_healthy:
return val[0]
else:
self.local_cache.pop(key, None) # expired
except KeyError:
pass # cache miss
@ -451,7 +451,8 @@ class _ClientCache:
return None
self.ensure_max_size()
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
with self._lock:
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
return val
@ -459,7 +460,8 @@ class _ClientCache:
key = self.redis.make_key(key)
self.ensure_max_size()
self.redis.set_value(key, val, shared=True)
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
with self._lock:
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
# XXX: We need to tell redis that we indeed read this key we just wrote
# This is an edge case:
# - Client A writes a key and reads it again from local cache
@ -469,13 +471,14 @@ class _ClientCache:
def ensure_max_size(self):
if len(self.local_cache) >= self.maxsize:
with suppress(RuntimeError):
with self._lock, suppress(RuntimeError):
self.local_cache.pop(next(iter(self.local_cache)), None)
def delete_value(self, key):
key = self.redis.make_key(key)
self.redis.delete_value(key, shared=True)
self.local_cache.pop(key, None)
with self._lock:
self.local_cache.pop(key, None)
def run_invalidator_thread(self):
self._watcher = self.monitor.pubsub()
@ -490,12 +493,13 @@ class _ClientCache:
if message["data"] is None:
# Flushall
self.clear_cache()
for key in message["data"]:
self.local_cache.pop(key, None)
with self._lock:
for key in message["data"]:
self.local_cache.pop(key, None)
def _exception_handler(self, exc, pubsub, pubsub_thread):
if isinstance(exc, (redis.exceptions.ConnectionError)):
self.local_cache.clear()
self.clear_cache()
self._conn_retries += 1
if self._conn_retries > 10:
self.cache_healthy = False
@ -506,4 +510,5 @@ class _ClientCache:
raise
def clear_cache(self):
self.local_cache.clear()
with self._lock:
self.local_cache.clear()