diff --git a/frappe/tests/test_caching.py b/frappe/tests/test_caching.py index d1de587d0d..4faade331c 100644 --- a/frappe/tests/test_caching.py +++ b/frappe/tests/test_caching.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import frappe from frappe.tests.test_api import FrappeAPITestCase from frappe.tests.utils import FrappeTestCase -from frappe.utils.caching import request_cache, site_cache +from frappe.utils.caching import redis_cache, request_cache, site_cache CACHE_TTL = 4 external_service = MagicMock(return_value=30) @@ -82,13 +82,84 @@ class TestSiteCache(FrappeAPITestCase): api_with_ttl = f"{module}.ping_with_ttl" api_without_ttl = f"{module}.ping" - start = time.monotonic() for _ in range(5): self.get(f"/api/method/{api_with_ttl}") self.get(f"/api/method/{api_without_ttl}") - end = time.monotonic() self.assertEqual(register_with_external_service.call_count, 2) - time.sleep(CACHE_TTL - (end - start)) + time.sleep(CACHE_TTL) self.get(f"/api/method/{api_with_ttl}") self.assertEqual(register_with_external_service.call_count, 3) + + +class TestRedisCache(FrappeAPITestCase): + def test_redis_cache(self): + function_call_count = 0 + + @redis_cache(ttl=CACHE_TTL) + def calculate_area(radius: float) -> float: + nonlocal function_call_count + function_call_count += 1 + return 3.14 * radius**2 + + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 1) + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 1) + + time.sleep(CACHE_TTL) + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 2) + + calculate_area.clear_cache() + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 3) + calculate_area.clear_cache() + + def test_redis_cache_without_params(self): + function_call_count = 0 + + @redis_cache + def calculate_area(radius: float) -> float: + nonlocal function_call_count + function_call_count += 1 + return 3.14 * radius**2 + + calculate_area.clear_cache() + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 1) + + calculate_area.clear_cache() + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 2) + + calculate_area.clear_cache() + + def test_redis_cache_diff_args(self): + function_call_count = 0 + + @redis_cache(ttl=CACHE_TTL) + def calculate_area(radius: float) -> float: + nonlocal function_call_count + function_call_count += 1 + return 3.14 * radius**2 + + self.assertEqual(calculate_area(10), 314) + self.assertEqual(function_call_count, 1) + self.assertEqual(calculate_area(100), 31400) + self.assertEqual(function_call_count, 2) + + self.assertEqual(calculate_area(5), 25 * 3.14) + self.assertEqual(function_call_count, 3) + + calculate_area(10) + # from cache now + self.assertEqual(function_call_count, 3) + + calculate_area(radius=10) + # args, kwargs are treated differently + self.assertEqual(function_call_count, 4) + + calculate_area(radius=10) + # kwargs should hit cache too + self.assertEqual(function_call_count, 4) diff --git a/frappe/utils/caching.py b/frappe/utils/caching.py index a2c9496098..007582f25f 100644 --- a/frappe/utils/caching.py +++ b/frappe/utils/caching.py @@ -128,3 +128,39 @@ def site_cache(ttl: int | None = None, maxsize: int | None = None) -> Callable: return time_cache_wrapper(ttl) return time_cache_wrapper + + +def redis_cache(ttl: int | None = 3600, user: str | bool | None = None) -> Callable: + """Decorator to cache method calls and its return values in Redis + + args: + ttl: time to expiry in seconds, defaults to 1 hour + user: `true` should cache be specific to session user. + """ + + def wrapper(func: Callable = None) -> Callable: + + func_key = f"{func.__module__}.{func.__qualname__}" + + def clear_cache(): + frappe.cache().delete_keys(func_key) + + func.clear_cache = clear_cache + func.ttl = ttl if not callable(ttl) else 3600 + + @wraps(func) + def redis_cache_wrapper(*args, **kwargs): + func_call_key = func_key + str(__generate_request_cache_key(args, kwargs)) + if frappe.cache().exists(func_call_key): + return frappe.cache().get_value(func_call_key, user=user) + else: + val = func(*args, **kwargs) + ttl = getattr(func, "ttl", 3600) + frappe.cache().set_value(func_call_key, val, expires_in_sec=ttl, user=user) + return val + + return redis_cache_wrapper + + if callable(ttl): + return wrapper(ttl) + return wrapper diff --git a/frappe/utils/redis_wrapper.py b/frappe/utils/redis_wrapper.py index ea91299cfc..3b335b2c1d 100644 --- a/frappe/utils/redis_wrapper.py +++ b/frappe/utils/redis_wrapper.py @@ -195,6 +195,10 @@ class RedisWrapper(redis.Redis): except redis.exceptions.ConnectionError: return False + def exists(self, *names: str, user=None, shared=None) -> int: + names = [self.make_key(n, user=user, shared=shared) for n in names] + return super().exists(*names) + def hgetall(self, name): value = super().hgetall(self.make_key(name)) return {key: pickle.loads(value) for key, value in value.items()}