refactor: generic callback manager
This commit is contained in:
parent
680cf73cba
commit
6ce7444669
2 changed files with 47 additions and 43 deletions
|
|
@ -8,10 +8,9 @@ import random
|
|||
import re
|
||||
import string
|
||||
import traceback
|
||||
from collections import deque
|
||||
from contextlib import contextmanager, suppress
|
||||
from time import time
|
||||
from typing import Any, Callable, Iterable, Sequence
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
|
||||
from pypika.terms import Criterion, NullValue
|
||||
|
|
@ -32,6 +31,7 @@ from frappe.database.utils import (
|
|||
from frappe.exceptions import DoesNotExistError, ImplicitCommitError
|
||||
from frappe.model.utils.link_count import flush_local_link_count
|
||||
from frappe.query_builder.functions import Count
|
||||
from frappe.utils import CallbackManager
|
||||
from frappe.utils import cast as cast_fieldtype
|
||||
from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool
|
||||
from frappe.utils.deprecations import deprecated, deprecation_warning
|
||||
|
|
@ -109,10 +109,10 @@ class Database:
|
|||
self.logger = frappe.logger("database")
|
||||
self.logger.setLevel("WARNING")
|
||||
|
||||
self.before_commit = DBHooks()
|
||||
self.after_commit = DBHooks()
|
||||
self.before_rollback = DBHooks()
|
||||
self.after_rollback = DBHooks()
|
||||
self.before_commit = CallbackManager()
|
||||
self.after_commit = CallbackManager()
|
||||
self.before_rollback = CallbackManager()
|
||||
self.after_rollback = CallbackManager()
|
||||
|
||||
# self.db_type: str
|
||||
# self.last_query (lazy) attribute of last sql query executed
|
||||
|
|
@ -1310,42 +1310,6 @@ class Database:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class DBHooks:
|
||||
"""Hooks for database events.
|
||||
|
||||
Primarily used for doing things before/after commit/rollback.
|
||||
|
||||
hook_manager = DBHooks()
|
||||
|
||||
# Put a function call in queue
|
||||
hook_manager.add(func)
|
||||
|
||||
# Run all pending functions in queue
|
||||
hook_manager.run()
|
||||
|
||||
# Reset quue
|
||||
hook_manager.reset()
|
||||
"""
|
||||
|
||||
__slots__ = ("_functions",)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._functions = deque()
|
||||
|
||||
def add(self, func: Callable) -> None:
|
||||
"""Add a function to queue, functions are executed in order of addition."""
|
||||
self._functions.append(func)
|
||||
|
||||
def run(self):
|
||||
"""Run all functions in queue"""
|
||||
while self._functions:
|
||||
_func = self._functions.popleft()
|
||||
_func()
|
||||
|
||||
def reset(self):
|
||||
self._functions = deque()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def savepoint(catch: type | tuple[type, ...] = Exception):
|
||||
"""Wrapper for wrapping blocks of DB operations in a savepoint.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import os
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from collections import deque
|
||||
from collections.abc import (
|
||||
Container,
|
||||
Generator,
|
||||
|
|
@ -20,7 +21,7 @@ from collections.abc import (
|
|||
from email.header import decode_header, make_header
|
||||
from email.utils import formataddr, parseaddr
|
||||
from gzip import GzipFile
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Callable, Literal
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
from redis.exceptions import ConnectionError
|
||||
|
|
@ -1092,3 +1093,42 @@ def is_git_url(url: str) -> bool:
|
|||
# modified to allow without the tailing .git from https://github.com/jonschlinkert/is-git-url.git
|
||||
pattern = r"(?:git|ssh|https?|\w*@[-\w.]+):(\/\/)?(.*?)(\.git)?(\/?|\#[-\d\w._]+?)$"
|
||||
return bool(re.match(pattern, url))
|
||||
|
||||
|
||||
class CallbackManager:
|
||||
"""Manage callbacks.
|
||||
|
||||
```
|
||||
# Capture callacks
|
||||
callbacks = CallbackManager()
|
||||
|
||||
# Put a function call in queue
|
||||
callbacks.add(func)
|
||||
|
||||
# Run all pending functions in queue
|
||||
callbacks.run()
|
||||
|
||||
# Reset queue
|
||||
callbacks.reset()
|
||||
```
|
||||
|
||||
Example usage: frappe.db.after_commit
|
||||
"""
|
||||
|
||||
__slots__ = ("_functions",)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._functions = deque()
|
||||
|
||||
def add(self, func: Callable) -> None:
|
||||
"""Add a function to queue, functions are executed in order of addition."""
|
||||
self._functions.append(func)
|
||||
|
||||
def run(self):
|
||||
"""Run all functions in queue"""
|
||||
while self._functions:
|
||||
_func = self._functions.popleft()
|
||||
_func()
|
||||
|
||||
def reset(self):
|
||||
self._functions.clear()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue