diff --git a/frappe/database/__init__.py b/frappe/database/__init__.py index b0e3183d4f..7b26ac31b3 100644 --- a/frappe/database/__init__.py +++ b/frappe/database/__init__.py @@ -4,6 +4,8 @@ # Database Module # -------------------- +from frappe.database.database import savepoint + def setup_database(force, source_sql=None, verbose=None, no_mariadb_socket=False): import frappe if frappe.conf.db_type == 'postgres': diff --git a/frappe/database/database.py b/frappe/database/database.py index 7c147cd1d0..2cb3098f79 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -6,11 +6,14 @@ import re import time -from typing import Dict, List, Union +import string +import random +from typing import Dict, List, Union, Tuple, Optional import frappe import datetime import frappe.defaults import frappe.model.meta +from contextlib import contextmanager from frappe import _ from time import time @@ -811,6 +814,9 @@ class Database(object): Avoid using savepoints when writing to filesystem.""" self.sql(f"savepoint {save_point}") + def release_savepoint(self, save_point): + self.sql(f"release savepoint {save_point}") + def rollback(self, *, save_point=None): """`ROLLBACK` current transaction. Optionally rollback to a known save_point.""" if save_point: @@ -1097,3 +1103,28 @@ def enqueue_jobs_after_commit(): q.enqueue_call(execute_job, timeout=job.get("timeout"), kwargs=job.get("queue_args")) frappe.flags.enqueue_after_commit = [] + +@contextmanager +def savepoint(catch: Union[type, Tuple[type, ...]] = Exception): + """ Wrapper for wrapping blocks of DB operations in a savepoint. + + as contextmanager: + + for doc in docs: + with savepoint(catch=DuplicateError): + doc.insert() + + as decorator (wraps FULL function call): + + @savepoint(catch=DuplicateError) + def process_doc(doc): + doc.insert() + """ + try: + savepoint = ''.join(random.sample(string.ascii_lowercase, 10)) + frappe.db.savepoint(savepoint) + yield # control back to calling function + except catch: + frappe.db.rollback(save_point=savepoint) + else: + frappe.db.release_savepoint(savepoint) diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index dec55b4714..cdef4354ed 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -12,6 +12,7 @@ from frappe.custom.doctype.custom_field.custom_field import create_custom_field from frappe.utils import random_string from frappe.utils.testutils import clear_custom_fields from frappe.query_builder import Field +from frappe.database import savepoint from .test_query_builder import run_only_if, db_type_is from frappe.query_builder.functions import Concat_ws @@ -267,6 +268,32 @@ class TestDB(unittest.TestCase): for d in created_docs: self.assertTrue(frappe.db.exists("ToDo", d)) + def test_savepoints_wrapper(self): + frappe.db.rollback() + + class SpecificExc(Exception): + pass + + created_docs = [] + failed_docs = [] + + for _ in range(5): + with savepoint(catch=SpecificExc): + doc_kept = frappe.get_doc(doctype="ToDo", description="nope").save() + created_docs.append(doc_kept.name) + + with savepoint(catch=SpecificExc): + doc_gone = frappe.get_doc(doctype="ToDo", description="nope").save() + failed_docs.append(doc_gone.name) + raise SpecificExc + + frappe.db.commit() + + for d in failed_docs: + self.assertFalse(frappe.db.exists("ToDo", d)) + for d in created_docs: + self.assertTrue(frappe.db.exists("ToDo", d)) + @run_only_if(db_type_is.MARIADB) class TestDDLCommandsMaria(unittest.TestCase):