feat: savepoint contextmanager

syntactic sugar around frappe.db.savepoint and rollback
This commit is contained in:
Ankush Menat 2022-01-03 21:38:34 +05:30
parent a817a4ac5e
commit e08b41964c
3 changed files with 61 additions and 1 deletions

View file

@ -4,6 +4,8 @@
# Database Module
# --------------------
from frappe.database.database import savepoint
def setup_database(force, source_sql=None, verbose=None, no_mariadb_socket=False):
import frappe
if frappe.conf.db_type == 'postgres':

View file

@ -6,11 +6,14 @@
import re
import time
from typing import Dict, List, Union
import string
import random
from typing import Dict, List, Union, Tuple, Optional
import frappe
import datetime
import frappe.defaults
import frappe.model.meta
from contextlib import contextmanager
from frappe import _
from time import time
@ -811,6 +814,9 @@ class Database(object):
Avoid using savepoints when writing to filesystem."""
self.sql(f"savepoint {save_point}")
def release_savepoint(self, save_point):
self.sql(f"release savepoint {save_point}")
def rollback(self, *, save_point=None):
"""`ROLLBACK` current transaction. Optionally rollback to a known save_point."""
if save_point:
@ -1097,3 +1103,28 @@ def enqueue_jobs_after_commit():
q.enqueue_call(execute_job, timeout=job.get("timeout"),
kwargs=job.get("queue_args"))
frappe.flags.enqueue_after_commit = []
@contextmanager
def savepoint(catch: Union[type, Tuple[type, ...]] = Exception):
""" Wrapper for wrapping blocks of DB operations in a savepoint.
as contextmanager:
for doc in docs:
with savepoint(catch=DuplicateError):
doc.insert()
as decorator (wraps FULL function call):
@savepoint(catch=DuplicateError)
def process_doc(doc):
doc.insert()
"""
try:
savepoint = ''.join(random.sample(string.ascii_lowercase, 10))
frappe.db.savepoint(savepoint)
yield # control back to calling function
except catch:
frappe.db.rollback(save_point=savepoint)
else:
frappe.db.release_savepoint(savepoint)

View file

@ -12,6 +12,7 @@ from frappe.custom.doctype.custom_field.custom_field import create_custom_field
from frappe.utils import random_string
from frappe.utils.testutils import clear_custom_fields
from frappe.query_builder import Field
from frappe.database import savepoint
from .test_query_builder import run_only_if, db_type_is
from frappe.query_builder.functions import Concat_ws
@ -267,6 +268,32 @@ class TestDB(unittest.TestCase):
for d in created_docs:
self.assertTrue(frappe.db.exists("ToDo", d))
def test_savepoints_wrapper(self):
frappe.db.rollback()
class SpecificExc(Exception):
pass
created_docs = []
failed_docs = []
for _ in range(5):
with savepoint(catch=SpecificExc):
doc_kept = frappe.get_doc(doctype="ToDo", description="nope").save()
created_docs.append(doc_kept.name)
with savepoint(catch=SpecificExc):
doc_gone = frappe.get_doc(doctype="ToDo", description="nope").save()
failed_docs.append(doc_gone.name)
raise SpecificExc
frappe.db.commit()
for d in failed_docs:
self.assertFalse(frappe.db.exists("ToDo", d))
for d in created_docs:
self.assertTrue(frappe.db.exists("ToDo", d))
@run_only_if(db_type_is.MARIADB)
class TestDDLCommandsMaria(unittest.TestCase):