diff --git a/frappe/database/database.py b/frappe/database/database.py index 70444acee2..b502e556f8 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -7,6 +7,7 @@ import random import re import string import traceback +import warnings from collections.abc import Iterable, Sequence from contextlib import contextmanager, suppress from time import time @@ -47,6 +48,11 @@ MULTI_WORD_PATTERN = re.compile(r'([`"])(tab([A-Z]\w+)( [A-Z]\w+)+)\1') SQL_ITERATOR_BATCH_SIZE = 100 +TRANSACTION_DISABLED_MSG = """Commit/rollback are disabled during certain events. This command will +be ignored. Commit/Rollback from here WILL CAUSE very hard to debug problems with atomicity and +concurrent data update bugs.""" + + class Database: """ Open a database connection with the given parmeters, if use_default is True, use the @@ -95,8 +101,12 @@ class Database: self.before_rollback = CallbackManager() self.after_rollback = CallbackManager() - # self.db_type: str - # self.last_query (lazy) attribute of last sql query executed + # Setting this to true will disable full rollback and commit + # You can still use savepoint with partial rollback. + self._disable_transaction_control = 0 + + # self.db_type: str + # self.last_query (lazy) attribute of last sql query executed def setup_type_map(self): pass @@ -1028,6 +1038,10 @@ class Database: def commit(self): """Commit current transaction. Calls SQL `COMMIT`.""" + if self._disable_transaction_control: + warnings.warn(message=TRANSACTION_DISABLED_MSG, stacklevel=2) + return + self.before_rollback.reset() self.after_rollback.reset() @@ -1042,7 +1056,7 @@ class Database: """`ROLLBACK` current transaction. Optionally rollback to a known save_point.""" if save_point: self.sql(f"rollback to savepoint {save_point}") - else: + elif not self._disable_transaction_control: self.before_commit.reset() self.after_commit.reset() @@ -1052,6 +1066,8 @@ class Database: self.begin() self.after_rollback.run() + else: + warnings.warn(message=TRANSACTION_DISABLED_MSG, stacklevel=2) def savepoint(self, save_point): """Savepoints work as a nested transaction. diff --git a/frappe/model/document.py b/frappe/model/document.py index e3b96c6745..3a1b375ce6 100644 --- a/frappe/model/document.py +++ b/frappe/model/document.py @@ -1305,7 +1305,11 @@ class Document(BaseDocument): def runner(self, method, *args, **kwargs): add_to_return_value(self, fn(self, *args, **kwargs)) for f in hooks: - add_to_return_value(self, f(self, method, *args, **kwargs)) + try: + frappe.db._disable_transaction_control += 1 + add_to_return_value(self, f(self, method, *args, **kwargs)) + finally: + frappe.db._disable_transaction_control -= 1 return self.__dict__.pop("_return_value", None) diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index f5000e3f58..6beb96860f 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -15,7 +15,7 @@ from frappe.database.utils import FallBackDateTimeStr from frappe.query_builder import Field from frappe.query_builder.functions import Concat_ws from frappe.tests.test_query_builder import db_type_is, run_only_if -from frappe.tests.utils import FrappeTestCase, timeout +from frappe.tests.utils import FrappeTestCase, patch_hooks, timeout from frappe.utils import add_days, now, random_string, set_request from frappe.utils.testutils import clear_custom_fields @@ -459,6 +459,19 @@ class TestDB(FrappeTestCase): ) self.assertEqual(1, frappe.db.transaction_writes - writes) + def test_transactions_disabled_during_writes(self): + hook_name = f"{bad_hook.__module__}.{bad_hook.__name__}" + nested_hook_name = f"{bad_nested_hook.__module__}.{bad_nested_hook.__name__}" + + with patch_hooks( + {"doc_events": {"*": {"before_validate": hook_name, "on_update": nested_hook_name}}} + ): + note = frappe.new_doc("Note", title=frappe.generate_hash()) + note.insert() + self.assertGreater(frappe.db.transaction_writes, 0) # This would've reset for commit/rollback + + self.assertFalse(frappe.db._disable_transaction_control) + def test_pk_collision_ignoring(self): # note has `name` generated from title for _ in range(3): @@ -1007,6 +1020,17 @@ class TestConcurrency(FrappeTestCase): self.assertRaises(frappe.QueryTimeoutError, frappe.delete_doc, note.doctype, note.name) +def bad_hook(*args, **kwargs): + frappe.db.commit() + frappe.db.rollback() + + +def bad_nested_hook(doc, *args, **kwargs): + doc.run_method("before_validate") + frappe.db.commit() + frappe.db.rollback() + + class TestSqlIterator(FrappeTestCase): def test_db_sql_iterator(self): test_queries = [