feat: savepoint contextmanager
syntactic sugar around frappe.db.savepoint and rollback
This commit is contained in:
parent
a817a4ac5e
commit
e08b41964c
3 changed files with 61 additions and 1 deletions
|
|
@ -4,6 +4,8 @@
|
|||
# Database Module
|
||||
# --------------------
|
||||
|
||||
from frappe.database.database import savepoint
|
||||
|
||||
def setup_database(force, source_sql=None, verbose=None, no_mariadb_socket=False):
|
||||
import frappe
|
||||
if frappe.conf.db_type == 'postgres':
|
||||
|
|
|
|||
|
|
@ -6,11 +6,14 @@
|
|||
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List, Union
|
||||
import string
|
||||
import random
|
||||
from typing import Dict, List, Union, Tuple, Optional
|
||||
import frappe
|
||||
import datetime
|
||||
import frappe.defaults
|
||||
import frappe.model.meta
|
||||
from contextlib import contextmanager
|
||||
|
||||
from frappe import _
|
||||
from time import time
|
||||
|
|
@ -811,6 +814,9 @@ class Database(object):
|
|||
Avoid using savepoints when writing to filesystem."""
|
||||
self.sql(f"savepoint {save_point}")
|
||||
|
||||
def release_savepoint(self, save_point):
|
||||
self.sql(f"release savepoint {save_point}")
|
||||
|
||||
def rollback(self, *, save_point=None):
|
||||
"""`ROLLBACK` current transaction. Optionally rollback to a known save_point."""
|
||||
if save_point:
|
||||
|
|
@ -1097,3 +1103,28 @@ def enqueue_jobs_after_commit():
|
|||
q.enqueue_call(execute_job, timeout=job.get("timeout"),
|
||||
kwargs=job.get("queue_args"))
|
||||
frappe.flags.enqueue_after_commit = []
|
||||
|
||||
@contextmanager
|
||||
def savepoint(catch: Union[type, Tuple[type, ...]] = Exception):
|
||||
""" Wrapper for wrapping blocks of DB operations in a savepoint.
|
||||
|
||||
as contextmanager:
|
||||
|
||||
for doc in docs:
|
||||
with savepoint(catch=DuplicateError):
|
||||
doc.insert()
|
||||
|
||||
as decorator (wraps FULL function call):
|
||||
|
||||
@savepoint(catch=DuplicateError)
|
||||
def process_doc(doc):
|
||||
doc.insert()
|
||||
"""
|
||||
try:
|
||||
savepoint = ''.join(random.sample(string.ascii_lowercase, 10))
|
||||
frappe.db.savepoint(savepoint)
|
||||
yield # control back to calling function
|
||||
except catch:
|
||||
frappe.db.rollback(save_point=savepoint)
|
||||
else:
|
||||
frappe.db.release_savepoint(savepoint)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from frappe.custom.doctype.custom_field.custom_field import create_custom_field
|
|||
from frappe.utils import random_string
|
||||
from frappe.utils.testutils import clear_custom_fields
|
||||
from frappe.query_builder import Field
|
||||
from frappe.database import savepoint
|
||||
|
||||
from .test_query_builder import run_only_if, db_type_is
|
||||
from frappe.query_builder.functions import Concat_ws
|
||||
|
|
@ -267,6 +268,32 @@ class TestDB(unittest.TestCase):
|
|||
for d in created_docs:
|
||||
self.assertTrue(frappe.db.exists("ToDo", d))
|
||||
|
||||
def test_savepoints_wrapper(self):
|
||||
frappe.db.rollback()
|
||||
|
||||
class SpecificExc(Exception):
|
||||
pass
|
||||
|
||||
created_docs = []
|
||||
failed_docs = []
|
||||
|
||||
for _ in range(5):
|
||||
with savepoint(catch=SpecificExc):
|
||||
doc_kept = frappe.get_doc(doctype="ToDo", description="nope").save()
|
||||
created_docs.append(doc_kept.name)
|
||||
|
||||
with savepoint(catch=SpecificExc):
|
||||
doc_gone = frappe.get_doc(doctype="ToDo", description="nope").save()
|
||||
failed_docs.append(doc_gone.name)
|
||||
raise SpecificExc
|
||||
|
||||
frappe.db.commit()
|
||||
|
||||
for d in failed_docs:
|
||||
self.assertFalse(frappe.db.exists("ToDo", d))
|
||||
for d in created_docs:
|
||||
self.assertTrue(frappe.db.exists("ToDo", d))
|
||||
|
||||
|
||||
@run_only_if(db_type_is.MARIADB)
|
||||
class TestDDLCommandsMaria(unittest.TestCase):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue