refactor: generic callback manager

This commit is contained in:
Ankush Menat 2023-06-02 22:29:05 +05:30 committed by Ankush Menat
parent 680cf73cba
commit 6ce7444669
2 changed files with 47 additions and 43 deletions

View file

@ -8,10 +8,9 @@ import random
import re
import string
import traceback
from collections import deque
from contextlib import contextmanager, suppress
from time import time
from typing import Any, Callable, Iterable, Sequence
from typing import Any, Iterable, Sequence
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.terms import Criterion, NullValue
@ -32,6 +31,7 @@ from frappe.database.utils import (
from frappe.exceptions import DoesNotExistError, ImplicitCommitError
from frappe.model.utils.link_count import flush_local_link_count
from frappe.query_builder.functions import Count
from frappe.utils import CallbackManager
from frappe.utils import cast as cast_fieldtype
from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool
from frappe.utils.deprecations import deprecated, deprecation_warning
@ -109,10 +109,10 @@ class Database:
self.logger = frappe.logger("database")
self.logger.setLevel("WARNING")
self.before_commit = DBHooks()
self.after_commit = DBHooks()
self.before_rollback = DBHooks()
self.after_rollback = DBHooks()
self.before_commit = CallbackManager()
self.after_commit = CallbackManager()
self.before_rollback = CallbackManager()
self.after_rollback = CallbackManager()
# self.db_type: str
# self.last_query (lazy) attribute of last sql query executed
@ -1310,42 +1310,6 @@ class Database:
raise NotImplementedError
class DBHooks:
"""Hooks for database events.
Primarily used for doing things before/after commit/rollback.
hook_manager = DBHooks()
# Put a function call in queue
hook_manager.add(func)
# Run all pending functions in queue
hook_manager.run()
# Reset quue
hook_manager.reset()
"""
__slots__ = ("_functions",)
def __init__(self) -> None:
self._functions = deque()
def add(self, func: Callable) -> None:
"""Add a function to queue, functions are executed in order of addition."""
self._functions.append(func)
def run(self):
"""Run all functions in queue"""
while self._functions:
_func = self._functions.popleft()
_func()
def reset(self):
self._functions = deque()
@contextmanager
def savepoint(catch: type | tuple[type, ...] = Exception):
"""Wrapper for wrapping blocks of DB operations in a savepoint.

View file

@ -9,6 +9,7 @@ import os
import re
import sys
import traceback
from collections import deque
from collections.abc import (
Container,
Generator,
@ -20,7 +21,7 @@ from collections.abc import (
from email.header import decode_header, make_header
from email.utils import formataddr, parseaddr
from gzip import GzipFile
from typing import Any, Literal
from typing import Any, Callable, Literal
from urllib.parse import quote, urlparse
from redis.exceptions import ConnectionError
@ -1092,3 +1093,42 @@ def is_git_url(url: str) -> bool:
# modified to allow without the tailing .git from https://github.com/jonschlinkert/is-git-url.git
pattern = r"(?:git|ssh|https?|\w*@[-\w.]+):(\/\/)?(.*?)(\.git)?(\/?|\#[-\d\w._]+?)$"
return bool(re.match(pattern, url))
class CallbackManager:
"""Manage callbacks.
```
# Capture callacks
callbacks = CallbackManager()
# Put a function call in queue
callbacks.add(func)
# Run all pending functions in queue
callbacks.run()
# Reset queue
callbacks.reset()
```
Example usage: frappe.db.after_commit
"""
__slots__ = ("_functions",)
def __init__(self) -> None:
self._functions = deque()
def add(self, func: Callable) -> None:
"""Add a function to queue, functions are executed in order of addition."""
self._functions.append(func)
def run(self):
"""Run all functions in queue"""
while self._functions:
_func = self._functions.popleft()
_func()
def reset(self):
self._functions.clear()