refactor: variable names, force RESP2
This commit is contained in:
parent
98b1df7dac
commit
119af71ae3
2 changed files with 29 additions and 30 deletions
|
|
@ -55,4 +55,4 @@ class TestClientCache(IntegrationTestCase):
|
|||
c.set_value(frappe.generate_hash(), 42)
|
||||
c.set_value(frappe.generate_hash(), 42)
|
||||
|
||||
self.assertEqual(len(c.local_cache), 2)
|
||||
self.assertEqual(len(c.cache), 2)
|
||||
|
|
|
|||
|
|
@ -397,14 +397,14 @@ def get_sentinel_connection(
|
|||
|
||||
class _TrackedConnection(redis.Connection):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.monitor_id = kwargs.pop("_monitor_id")
|
||||
self._invalidator_id = kwargs.pop("_invalidator_id")
|
||||
super().__init__(*args, **kwargs)
|
||||
# Every redis connection needs to enable client tracking to get notified about invalidated
|
||||
# keys.
|
||||
self.register_connect_callback(self._enable_client_tracking)
|
||||
|
||||
def _enable_client_tracking(self, conn):
|
||||
conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self.monitor_id, "NOLOOP")
|
||||
conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self._invalidator_id, "NOLOOP")
|
||||
conn.read_response()
|
||||
|
||||
|
||||
|
|
@ -413,30 +413,29 @@ _ClientCacheValue = tuple[typing.Any, int]
|
|||
|
||||
class _ClientCache:
|
||||
def __init__(self, maxsize: int = 1024, ttl=10 * 60, monitor: RedisWrapper | None = None) -> None:
|
||||
self.monitor = frappe.cache
|
||||
self.monitor_id = self.monitor.client_id()
|
||||
self.maxsize = maxsize or 1024 # Expect 1024 * 4kb objects ~ 4MB
|
||||
self.local_ttl = ttl
|
||||
self.expiration = time.monotonic() + ttl
|
||||
self._lock = threading.RLock()
|
||||
# This guards writes to self.cache, reads are done without a lock.
|
||||
self.lock = threading.RLock()
|
||||
self.cache: dict[bytes, _ClientCacheValue] = {}
|
||||
|
||||
self.invalidator = frappe.cache
|
||||
self.invalidator_id = self.invalidator.client_id()
|
||||
|
||||
self.redis: RedisWrapper = RedisWrapper.from_url(
|
||||
frappe.conf.get("redis_cache"),
|
||||
connection_class=_TrackedConnection,
|
||||
_monitor_id=self.monitor_id,
|
||||
_invalidator_id=self.invalidator_id,
|
||||
protocol=2,
|
||||
)
|
||||
protocol = self.redis.get_connection_kwargs().get("protocol")
|
||||
if cint(protocol) == 3:
|
||||
frappe.throw("RESP3 is not supported while connecting to Redis.") # nosemgrep
|
||||
self.invalidator_thread = self.run_invalidator_thread()
|
||||
self.local_cache: dict[bytes, _ClientCacheValue] = {}
|
||||
self.cache_healthy = True
|
||||
self._conn_retries = 0
|
||||
self.connection_retries = 0
|
||||
|
||||
def get_value(self, key):
|
||||
key = self.redis.make_key(key)
|
||||
try:
|
||||
val = self.local_cache[key]
|
||||
val = self.cache[key]
|
||||
if time.monotonic() < val[1] and self.cache_healthy:
|
||||
return val[0]
|
||||
except KeyError:
|
||||
|
|
@ -451,8 +450,8 @@ class _ClientCache:
|
|||
return None
|
||||
|
||||
self.ensure_max_size()
|
||||
with self._lock:
|
||||
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
|
||||
with self.lock:
|
||||
self.cache[key] = (val, time.monotonic() + self.local_ttl)
|
||||
|
||||
return val
|
||||
|
||||
|
|
@ -460,8 +459,8 @@ class _ClientCache:
|
|||
key = self.redis.make_key(key)
|
||||
self.ensure_max_size()
|
||||
self.redis.set_value(key, val, shared=True)
|
||||
with self._lock:
|
||||
self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
|
||||
with self.lock:
|
||||
self.cache[key] = (val, time.monotonic() + self.local_ttl)
|
||||
# XXX: We need to tell redis that we indeed read this key we just wrote
|
||||
# This is an edge case:
|
||||
# - Client A writes a key and reads it again from local cache
|
||||
|
|
@ -470,18 +469,18 @@ class _ClientCache:
|
|||
_ = self.redis.get_value(key, shared=True, use_local_cache=False)
|
||||
|
||||
def ensure_max_size(self):
|
||||
if len(self.local_cache) >= self.maxsize:
|
||||
with self._lock, suppress(RuntimeError):
|
||||
self.local_cache.pop(next(iter(self.local_cache)), None)
|
||||
if len(self.cache) >= self.maxsize:
|
||||
with self.lock, suppress(RuntimeError):
|
||||
self.cache.pop(next(iter(self.cache)), None)
|
||||
|
||||
def delete_value(self, key):
|
||||
key = self.redis.make_key(key)
|
||||
self.redis.delete_value(key, shared=True)
|
||||
with self._lock:
|
||||
self.local_cache.pop(key, None)
|
||||
with self.lock:
|
||||
self.cache.pop(key, None)
|
||||
|
||||
def run_invalidator_thread(self):
|
||||
self._watcher = self.monitor.pubsub()
|
||||
self._watcher = self.invalidator.pubsub()
|
||||
self._watcher.subscribe(**{"__redis__:invalidate": self._handle_invalidation})
|
||||
return self._watcher.run_in_thread(
|
||||
sleep_time=None,
|
||||
|
|
@ -493,15 +492,15 @@ class _ClientCache:
|
|||
if message["data"] is None:
|
||||
# Flushall
|
||||
self.clear_cache()
|
||||
with self._lock:
|
||||
with self.lock:
|
||||
for key in message["data"]:
|
||||
self.local_cache.pop(key, None)
|
||||
self.cache.pop(key, None)
|
||||
|
||||
def _exception_handler(self, exc, pubsub, pubsub_thread):
|
||||
if isinstance(exc, (redis.exceptions.ConnectionError)):
|
||||
self.clear_cache()
|
||||
self._conn_retries += 1
|
||||
if self._conn_retries > 10:
|
||||
self.connection_retries += 1
|
||||
if self.connection_retries > 10:
|
||||
self.cache_healthy = False
|
||||
raise
|
||||
time.sleep(1)
|
||||
|
|
@ -510,5 +509,5 @@ class _ClientCache:
|
|||
raise
|
||||
|
||||
def clear_cache(self):
|
||||
with self._lock:
|
||||
self.local_cache.clear()
|
||||
with self.lock:
|
||||
self.cache.clear()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue