From 6ce7444669c7462e26875bb6772838dab0df3cb5 Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Fri, 2 Jun 2023 22:29:05 +0530 Subject: [PATCH] refactor: generic callback manager --- frappe/database/database.py | 48 +++++-------------------------------- frappe/utils/__init__.py | 42 +++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index e11e2fccbd..bd99f790f4 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -8,10 +8,9 @@ import random import re import string import traceback -from collections import deque from contextlib import contextmanager, suppress from time import time -from typing import Any, Callable, Iterable, Sequence +from typing import Any, Iterable, Sequence from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.terms import Criterion, NullValue @@ -32,6 +31,7 @@ from frappe.database.utils import ( from frappe.exceptions import DoesNotExistError, ImplicitCommitError from frappe.model.utils.link_count import flush_local_link_count from frappe.query_builder.functions import Count +from frappe.utils import CallbackManager from frappe.utils import cast as cast_fieldtype from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool from frappe.utils.deprecations import deprecated, deprecation_warning @@ -109,10 +109,10 @@ class Database: self.logger = frappe.logger("database") self.logger.setLevel("WARNING") - self.before_commit = DBHooks() - self.after_commit = DBHooks() - self.before_rollback = DBHooks() - self.after_rollback = DBHooks() + self.before_commit = CallbackManager() + self.after_commit = CallbackManager() + self.before_rollback = CallbackManager() + self.after_rollback = CallbackManager() # self.db_type: str # self.last_query (lazy) attribute of last sql query executed @@ -1310,42 +1310,6 @@ class Database: raise NotImplementedError -class DBHooks: - """Hooks for database events. - - Primarily used for doing things before/after commit/rollback. - - hook_manager = DBHooks() - - # Put a function call in queue - hook_manager.add(func) - - # Run all pending functions in queue - hook_manager.run() - - # Reset quue - hook_manager.reset() - """ - - __slots__ = ("_functions",) - - def __init__(self) -> None: - self._functions = deque() - - def add(self, func: Callable) -> None: - """Add a function to queue, functions are executed in order of addition.""" - self._functions.append(func) - - def run(self): - """Run all functions in queue""" - while self._functions: - _func = self._functions.popleft() - _func() - - def reset(self): - self._functions = deque() - - @contextmanager def savepoint(catch: type | tuple[type, ...] = Exception): """Wrapper for wrapping blocks of DB operations in a savepoint. diff --git a/frappe/utils/__init__.py b/frappe/utils/__init__.py index ef32ff5653..b7dc565555 100644 --- a/frappe/utils/__init__.py +++ b/frappe/utils/__init__.py @@ -9,6 +9,7 @@ import os import re import sys import traceback +from collections import deque from collections.abc import ( Container, Generator, @@ -20,7 +21,7 @@ from collections.abc import ( from email.header import decode_header, make_header from email.utils import formataddr, parseaddr from gzip import GzipFile -from typing import Any, Literal +from typing import Any, Callable, Literal from urllib.parse import quote, urlparse from redis.exceptions import ConnectionError @@ -1092,3 +1093,42 @@ def is_git_url(url: str) -> bool: # modified to allow without the tailing .git from https://github.com/jonschlinkert/is-git-url.git pattern = r"(?:git|ssh|https?|\w*@[-\w.]+):(\/\/)?(.*?)(\.git)?(\/?|\#[-\d\w._]+?)$" return bool(re.match(pattern, url)) + + +class CallbackManager: + """Manage callbacks. + + ``` + # Capture callacks + callbacks = CallbackManager() + + # Put a function call in queue + callbacks.add(func) + + # Run all pending functions in queue + callbacks.run() + + # Reset queue + callbacks.reset() + ``` + + Example usage: frappe.db.after_commit + """ + + __slots__ = ("_functions",) + + def __init__(self) -> None: + self._functions = deque() + + def add(self, func: Callable) -> None: + """Add a function to queue, functions are executed in order of addition.""" + self._functions.append(func) + + def run(self): + """Run all functions in queue""" + while self._functions: + _func = self._functions.popleft() + _func() + + def reset(self): + self._functions.clear()