seitime-frappe/frappe/utils/redis_wrapper.py
2025-01-06 18:57:57 +05:30

514 lines
14 KiB
Python

# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
import pickle
import re
import threading
import time
import typing
from contextlib import suppress
import redis
from redis.commands.search import Search
import frappe
from frappe.utils import cstr
from frappe.utils.data import cint
# Pickle protocol for every value this module stores in Redis.
# Protocol 5 is faster than the default of 4, which Python keeps only for backward
# compatibility; nothing older than Python 3.10 is supported, so 5 is always available.
DEFAULT_PICKLE_PROTOCOL = 5
class RedisearchWrapper(Search):
    """RediSearch handle whose suggestion (``sug*``) commands use site-prefixed keys.

    Every suggestion-dictionary key is passed through the owning client's ``make_key``
    before the call is forwarded to :class:`redis.commands.search.Search`.
    """

    def _prefixed(self, key):
        """Return *key* namespaced by ``self.client.make_key``."""
        return self.client.make_key(key)

    def sugadd(self, key, *suggestions, **kwargs):
        return super().sugadd(self._prefixed(key), *suggestions, **kwargs)

    def suglen(self, key):
        return super().suglen(self._prefixed(key))

    def sugdel(self, key, string):
        return super().sugdel(self._prefixed(key), string)

    def sugget(self, key, *args, **kwargs):
        return super().sugget(self._prefixed(key), *args, **kwargs)
class RedisWrapper(redis.Redis):
    """Redis client that will automatically prefix conf.db_name

    Keys are namespaced by :meth:`make_key` unless ``shared=True`` is passed. Values written
    through ``set_value``/``hset`` are pickled and mirrored into ``frappe.local.cache``.
    Most methods swallow ``redis.exceptions.ConnectionError``, so an unreachable Redis falls
    back to that request-local copy instead of raising.
    """

    def connected(self):
        """Return True if Redis answers PING, False if the connection fails."""
        try:
            self.ping()
            return True
        except redis.exceptions.ConnectionError:
            return False

    def __call__(self):
        """WARNING: Added for backward compatibility to support frappe.cache().method(...)"""
        return self

    def make_key(self, key, user=None, shared=False):
        """Return the namespaced Redis key for *key*.

        :param key: Logical cache key.
        :param user: If truthy, scope the key to a user; ``True`` means ``frappe.session.user``.
        :param shared: If True, return *key* unchanged (no prefix, no encoding).
        :return: ``b"<db_name>|[user:<user>:]<key>"`` as bytes, or *key* itself when shared.
        """
        if shared:
            return key
        if user:
            if user is True:
                user = frappe.session.user
            key = f"user:{user}:{key}"
        return f"{frappe.conf.db_name}|{key}".encode()

    def set_value(self, key, val, user=None, expires_in_sec=None, shared=False):
        """Sets cache value.

        :param key: Cache key
        :param val: Value to be cached
        :param user: Prepends key with User
        :param expires_in_sec: Expire value of this key in X seconds
        :param shared: Use *key* as-is, without the site prefix
        """
        key = self.make_key(key, user, shared)
        # The request-local copy is updated even when Redis is unreachable.
        frappe.local.cache[key] = val
        with suppress(redis.exceptions.ConnectionError):
            self.set(name=key, value=pickle.dumps(val, protocol=DEFAULT_PICKLE_PROTOCOL), ex=expires_in_sec)

    def get_value(self, key, generator=None, user=None, expires=False, shared=False, *, use_local_cache=True):
        """Return cache value. If not found and generator function is
        given, call the generator.

        :param key: Cache key.
        :param generator: Function to be called to generate a value if `None` is returned.
        :param user: Scope the key to a user (see :meth:`make_key`).
        :param expires: If the key is supposed to be with an expiry, don't store it in frappe.local
        :param shared: Use *key* as-is, without the site prefix.
        :param use_local_cache: If False, skip the ``frappe.local.cache`` lookup and read Redis.
        """
        original_key = key
        key = self.make_key(key, user, shared)
        local_cache = frappe.local.cache
        if key in local_cache and use_local_cache:
            val = local_cache[key]
        else:
            val = None
            try:
                val = self.get(key)
            except redis.exceptions.ConnectionError:
                pass
            if val is not None:
                # NOTE(review): unpickling assumes only trusted writers can reach this Redis.
                val = pickle.loads(val)

        if not expires:
            if val is None and generator:
                val = generator()
                # Forward `shared` so the generated value lands under the key that was just
                # looked up; otherwise a shared key would be re-prefixed and never hit again.
                self.set_value(original_key, val, user=user, shared=shared)
            else:
                local_cache[key] = val

        return val

    def get_all(self, key):
        """Return ``{full_key: value}`` for every key starting with *key*.

        The dict keys are the fully-qualified keys returned by :meth:`get_keys`. They are read
        back with ``shared=True`` so that the already-prefixed keys are not prefixed again.
        """
        return {k: self.get_value(k, shared=True) for k in self.get_keys(key)}

    def get_keys(self, key):
        """Return keys starting with `key`.

        Falls back to matching the keys of ``frappe.local.cache`` when Redis is unreachable.
        """
        try:
            key = self.make_key(key + "*")
            return self.keys(key)
        except redis.exceptions.ConnectionError:
            # `key` is already the prefixed glob here. Escape it literally, then turn the
            # trailing `*` glob back into a regex wildcard.
            pattern = re.escape(cstr(key)).replace(r"\*", r"[\w]*")
            regex = re.compile(pattern)
            return [k for k in list(frappe.local.cache) if regex.match(cstr(k))]

    def delete_keys(self, key):
        """Delete keys with wildcard `*`."""
        self.delete_value(self.get_keys(key), make_keys=False)

    def delete_key(self, *args, **kwargs):
        """Alias of :meth:`delete_value`."""
        self.delete_value(*args, **kwargs)

    def delete_value(self, keys, user=None, make_keys=True, shared=False):
        """Delete value, list of values.

        :param keys: A single key, or a list/tuple of keys.
        :param make_keys: If False, *keys* are already fully-qualified and are used as-is.
        """
        if not keys:
            return
        if not isinstance(keys, list | tuple):
            keys = (keys,)
        if make_keys:
            keys = [self.make_key(k, shared=shared, user=user) for k in keys]

        local_cache = frappe.local.cache
        for key in keys:
            local_cache.pop(key, None)

        try:
            self.unlink(*keys)
        except redis.exceptions.ConnectionError:
            pass

    def lpush(self, key, value):
        return super().lpush(self.make_key(key), value)

    def rpush(self, key, value):
        return super().rpush(self.make_key(key), value)

    def lpop(self, key):
        return super().lpop(self.make_key(key))

    def rpop(self, key):
        return super().rpop(self.make_key(key))

    def llen(self, key):
        return super().llen(self.make_key(key))

    def lrange(self, key, start, stop):
        return super().lrange(self.make_key(key), start, stop)

    def ltrim(self, key, start, stop):
        return super().ltrim(self.make_key(key), start, stop)

    def hset(
        self,
        name: str,
        key: str,
        value,
        shared: bool = False,
        *args,
        **kwargs,
    ):
        """Set field *key* of hash *name* to the pickled *value* (no-op when *key* is None).

        The value is also stored in ``frappe.local.cache`` under the hash's namespaced name.
        """
        if key is None:
            return

        _name = self.make_key(name, shared=shared)

        # set in local
        frappe.local.cache.setdefault(_name, {})[key] = value

        # set in redis
        try:
            super().hset(_name, key, pickle.dumps(value, protocol=DEFAULT_PICKLE_PROTOCOL), *args, **kwargs)
        except redis.exceptions.ConnectionError:
            pass

    def hexists(self, name: str, key: str, shared: bool = False) -> bool:
        """Return whether field *key* exists in hash *name* (False if Redis is unreachable)."""
        if key is None:
            return False
        _name = self.make_key(name, shared=shared)
        try:
            return super().hexists(_name, key)
        except redis.exceptions.ConnectionError:
            return False

    def exists(self, *names: str, user=None, shared=None) -> int:
        """Return how many of *names* exist in Redis (0 if Redis is unreachable)."""
        names = [self.make_key(n, user=user, shared=shared) for n in names]

        try:
            return super().exists(*names)
        except redis.exceptions.ConnectionError:
            return 0

    def hgetall(self, name):
        """Return every field of hash *name* with its value unpickled ({} if Redis is unreachable)."""
        try:
            raw = super().hgetall(self.make_key(name))
        except redis.exceptions.ConnectionError:
            return {}
        return {field: pickle.loads(data) for field, data in raw.items()}

    def hget(self, name, key, generator=None, shared=False):
        """Return field *key* of hash *name*, preferring ``frappe.local.cache``.

        On a miss, *generator* (if given) is called and its result stored via :meth:`hset`.
        """
        _name = self.make_key(name, shared=shared)
        local_cache = frappe.local.cache

        if _name not in local_cache:
            local_cache[_name] = {}

        if not key:
            return None

        if key in local_cache[_name]:
            return local_cache[_name][key]

        value = None
        try:
            value = super().hget(_name, key)
        except redis.exceptions.ConnectionError:
            pass

        if value is not None:
            value = pickle.loads(value)
            local_cache[_name][key] = value
        elif generator:
            value = generator()
            self.hset(name, key, value, shared=shared)
        return value

    def hdel(
        self,
        name: str,
        keys: str | list | tuple,
        shared=False,
        pipeline: redis.client.Pipeline | None = None,
    ):
        """
        A wrapper around redis' HDEL command

        :param name: The hash name
        :param keys: the keys to delete
        :param shared: shared frappe key or not
        :param pipeline: A redis.client.Pipeline object, if this transaction is to be run in a pipeline
        """
        _name = self.make_key(name, shared=shared)

        name_in_local_cache = _name in frappe.local.cache

        if not isinstance(keys, list | tuple):
            if name_in_local_cache and keys in frappe.local.cache[_name]:
                del frappe.local.cache[_name][keys]
            if pipeline:
                pipeline.hdel(_name, keys)
            else:
                try:
                    super().hdel(_name, keys)
                except redis.exceptions.ConnectionError:
                    pass
            return

        # Only execute the pipeline here if the caller did not supply one.
        local_pipeline = False

        if pipeline is None:
            pipeline = self.pipeline()
            local_pipeline = True

        for key in keys:
            if name_in_local_cache:
                if key in frappe.local.cache[_name]:
                    del frappe.local.cache[_name][key]
            pipeline.hdel(_name, key)

        if local_pipeline:
            try:
                pipeline.execute()
            except redis.exceptions.ConnectionError:
                pass

    def hdel_names(self, names: list | tuple, key: str):
        """
        A function to call HDEL on multiple hash names with a common key, run in a single pipeline

        :param names: The hash names
        :param key: The common key
        """
        pipeline = self.pipeline()
        for name in names:
            self.hdel(name, key, pipeline=pipeline)
        try:
            pipeline.execute()
        except redis.exceptions.ConnectionError:
            pass

    def hdel_keys(self, name_starts_with, key):
        """Delete hash names with wildcard `*` and key"""
        pipeline = self.pipeline()
        for full_name in self.get_keys(name_starts_with):
            # get_keys yields prefixed keys, which are bytes when they come from Redis. Decode them
            # and strip the "<db_name>|" prefix, because hdel() applies make_key() again.
            name = cstr(full_name).split("|", 1)[1]
            self.hdel(name, key, pipeline=pipeline)
        try:
            pipeline.execute()
        except redis.exceptions.ConnectionError:
            pass

    def hkeys(self, name):
        """Return the field names of hash *name* ([] if Redis is unreachable)."""
        try:
            return super().hkeys(self.make_key(name))
        except redis.exceptions.ConnectionError:
            return []

    def sadd(self, name, *values):
        """Add a member/members to a given set"""
        return super().sadd(self.make_key(name), *values)

    def srem(self, name, *values):
        """Remove a specific member/list of members from the set."""
        return super().srem(self.make_key(name), *values)

    def sismember(self, name, value):
        """Return True or False based on if a given value is present in the set."""
        return super().sismember(self.make_key(name), value)

    def spop(self, name):
        """Remove and returns a random member from the set."""
        return super().spop(self.make_key(name))

    def srandmember(self, name, count=None):
        """Return a random member from the set, or *count* random members when given."""
        return super().srandmember(self.make_key(name), count)

    def smembers(self, name):
        """Return all members of the set."""
        return super().smembers(self.make_key(name))

    def ft(self, index_name="idx"):
        """Return a RediSearch handle for the site-prefixed *index_name*."""
        return RedisearchWrapper(client=self, index_name=self.make_key(index_name))
def setup_cache() -> RedisWrapper:
    """Create the site's cache client.

    If ``redis_cache_sentinel_enabled`` is set, the master named by
    ``redis_cache_master_service`` is resolved through Redis Sentinel, using the
    ``redis_cache_sentinel*`` and ``redis_cache_master*`` settings. Otherwise a plain
    client is built from the ``redis_cache`` URL.
    """
    conf = frappe.conf
    if not conf.redis_cache_sentinel_enabled:
        return RedisWrapper.from_url(conf.get("redis_cache"))

    # Each node is "host:port"; NOTE(review): the port stays a string here — confirm redis-py accepts it.
    sentinel_nodes = [tuple(node.split(":")) for node in conf.get("redis_cache_sentinels", [])]
    sentinel = get_sentinel_connection(
        sentinels=sentinel_nodes,
        sentinel_username=conf.get("redis_cache_sentinel_username"),
        sentinel_password=conf.get("redis_cache_sentinel_password"),
        master_username=conf.get("redis_cache_master_username"),
        master_password=conf.get("redis_cache_master_password"),
    )
    service_name = conf.get("redis_cache_master_service")
    return sentinel.master_for(service_name, redis_class=RedisWrapper)
def get_sentinel_connection(
    sentinels: list[tuple[str, int]],
    sentinel_username=None,
    sentinel_password=None,
    master_username=None,
    master_password=None,
):
    """Build a :class:`redis.sentinel.Sentinel` for the given sentinel nodes.

    :param sentinels: ``(host, port)`` pairs of the sentinel processes.
    :param sentinel_username: Username used against the sentinels; omitted when falsy.
    :param sentinel_password: Password used against the sentinels; omitted when falsy.
    :param master_username: Passed to ``Sentinel`` as ``username``.
    :param master_password: Passed to ``Sentinel`` as ``password``.
    """
    from redis.sentinel import Sentinel

    sentinel_credentials = {"username": sentinel_username, "password": sentinel_password}
    # Only forward credentials that were actually provided.
    sentinel_kwargs = {option: value for option, value in sentinel_credentials.items() if value}

    return Sentinel(
        sentinels=sentinels,
        sentinel_kwargs=sentinel_kwargs,
        username=master_username,
        password=master_password,
    )
class _TrackedConnection(redis.Connection):
    """Redis connection that enables ``CLIENT TRACKING`` each time it connects.

    Requires a ``_monitor_id`` keyword argument. This is the client ID that Redis redirects
    key-invalidation messages to.
    """

    def __init__(self, *args, **kwargs):
        # Remove the private keyword before the base class sees the remaining kwargs.
        monitor_id = kwargs.pop("_monitor_id")
        self.monitor_id = monitor_id
        super().__init__(*args, **kwargs)
        # Tracking must be switched on for every connection, otherwise reads made on it
        # would never produce invalidation notifications.
        self.register_connect_callback(self._enable_client_tracking)

    def _enable_client_tracking(self, conn):
        """Send ``CLIENT TRACKING ON redirect <monitor_id> NOLOOP`` on *conn* and read the reply."""
        conn.send_command("CLIENT", "TRACKING", "ON", "redirect", self.monitor_id, "NOLOOP")
        conn.read_response()
_ClientCacheValue = tuple[typing.Any, int]
class _ClientCache:
    """Process-local cache in front of Redis, kept fresh through Redis key invalidation.

    Values are held in ``local_cache`` for up to ``ttl`` seconds and at most ``maxsize``
    entries. Reads and writes go through a dedicated client whose connections enable
    ``CLIENT TRACKING`` with invalidations redirected to the ``monitor`` client. A
    background pub/sub thread listens on ``__redis__:invalidate`` and evicts changed keys.
    """

    def __init__(self, maxsize: int = 1024, ttl=10 * 60, monitor: RedisWrapper | None = None) -> None:
        # State used by the invalidator thread must exist before that thread starts, because an
        # invalidation message or connection error can arrive as soon as it is running.
        self._lock = threading.RLock()
        self.local_cache: dict[bytes, _ClientCacheValue] = {}
        self.cache_healthy = True
        self._conn_retries = 0

        self.monitor = monitor if monitor is not None else frappe.cache
        # NOTE(review): invalidations are redirected to this client ID. This assumes the
        # connection later taken by `self.monitor.pubsub()` is the same one — confirm.
        self.monitor_id = self.monitor.client_id()
        self.maxsize = maxsize or 1024  # Expect 1024 * 4kb objects ~ 4MB
        self.local_ttl = ttl
        self.expiration = time.monotonic() + ttl

        self.redis: RedisWrapper = RedisWrapper.from_url(
            frappe.conf.get("redis_cache"),
            connection_class=_TrackedConnection,
            _monitor_id=self.monitor_id,
        )
        protocol = self.redis.get_connection_kwargs().get("protocol")
        if cint(protocol) == 3:
            frappe.throw("RESP3 is not supported while connecting to Redis.")  # nosemgrep

        self.invalidator_thread = self.run_invalidator_thread()

    def get_value(self, key):
        """Return the value for *key*, serving from the local cache while the entry is fresh and healthy."""
        key = self.redis.make_key(key)
        try:
            val = self.local_cache[key]
            if time.monotonic() < val[1] and self.cache_healthy:
                return val[0]
        except KeyError:
            pass  # cache miss

        val = self.redis.get_value(key, shared=True, use_local_cache=False)

        # Note: We should not "cache" the cache-misses in client cache.
        # This cache is long lived and "misses" are not tracked by redis so they'll never get
        # invalidated.
        if val is None:
            return None

        self.ensure_max_size()
        with self._lock:
            self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
        return val

    def set_value(self, key, val):
        """Write *val* to Redis and to the local cache."""
        key = self.redis.make_key(key)
        self.ensure_max_size()
        self.redis.set_value(key, val, shared=True)
        with self._lock:
            self.local_cache[key] = (val, time.monotonic() + self.local_ttl)
        # XXX: We need to tell redis that we indeed read this key we just wrote
        # This is an edge case:
        # - Client A writes a key and reads it again from local cache
        # - Client B overwrites this key, but since client A never "read" it from Redis, Redis
        #   doesn't send invalidation.
        _ = self.redis.get_value(key, shared=True, use_local_cache=False)

    def ensure_max_size(self):
        """Evict one entry if the local cache has reached ``maxsize``."""
        # Check and evict under the lock. Otherwise another thread could empty the dict in between,
        # and `next(iter(...))` would raise StopIteration.
        with self._lock, suppress(RuntimeError):
            if len(self.local_cache) >= self.maxsize:
                # Dicts keep insertion order, so this drops the earliest-inserted key still present.
                self.local_cache.pop(next(iter(self.local_cache)), None)

    def delete_value(self, key):
        """Delete *key* from Redis and from the local cache."""
        key = self.redis.make_key(key)
        self.redis.delete_value(key, shared=True)
        with self._lock:
            self.local_cache.pop(key, None)

    def run_invalidator_thread(self):
        """Subscribe the monitor to ``__redis__:invalidate`` and return the daemon pub/sub thread."""
        self._watcher = self.monitor.pubsub()
        self._watcher.subscribe(**{"__redis__:invalidate": self._handle_invalidation})
        return self._watcher.run_in_thread(
            sleep_time=None,
            daemon=True,
            exception_handler=self._exception_handler,
        )

    def _handle_invalidation(self, message):
        """Evict the keys listed in an invalidation message; a ``None`` payload clears everything."""
        if message["data"] is None:
            # Flushall
            self.clear_cache()
            return

        with self._lock:
            for key in message["data"]:
                self.local_cache.pop(key, None)

    def _exception_handler(self, exc, pubsub, pubsub_thread):
        """Handle errors raised in the pub/sub thread.

        On connection errors, clear the cache and sleep 1s before the thread retries. After
        more than 10 of them, or on any other error, mark the cache unhealthy and re-raise.
        """
        if isinstance(exc, redis.exceptions.ConnectionError):
            # Invalidations may have been missed while disconnected, so local entries are stale.
            self.clear_cache()
            self._conn_retries += 1
            if self._conn_retries > 10:
                self.cache_healthy = False
                raise exc
            time.sleep(1)
        else:
            self.cache_healthy = False
            raise exc

    def clear_cache(self):
        """Drop every locally cached entry."""
        with self._lock:
            self.local_cache.clear()