From 961585f1d96215fc79fd7e6be6951243be98efc0 Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Mon, 28 Aug 2023 12:31:25 +0530 Subject: [PATCH] fix: misc dx improvemnts (#22188) * fix(dx): simplify adding callbacks frappe.db.after_commit(func) == frappe.db.after_commit.add(func) * fix: trace id missing DB gets initted before request --- frappe/database/database.py | 8 +++----- frappe/recorder.py | 4 +++- frappe/tests/test_db.py | 3 ++- frappe/tests/test_monitor.py | 16 ++++++++++++---- frappe/utils/__init__.py | 3 +++ 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index ba5e6c891e..f790aa6c9b 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -97,10 +97,6 @@ class Database: self.before_rollback = CallbackManager() self.after_rollback = CallbackManager() - self._trace_comment = "" - if trace_id := get_trace_id(): - self._trace_comment = f" /* FRAPPE_TRACE_ID: {trace_id} */" - # self.db_type: str # self.last_query (lazy) attribute of last sql query executed @@ -213,7 +209,9 @@ class Database: values = (values,) query, values = self._transform_query(query, values) - query += self._trace_comment + + if trace_id := get_trace_id(): + query += f" /* FRAPPE_TRACE_ID: {trace_id} */" try: self._cursor.execute(query, values) diff --git a/frappe/recorder.py b/frappe/recorder.py index 2bc14e9f2f..68356c732e 100644 --- a/frappe/recorder.py +++ b/frappe/recorder.py @@ -71,7 +71,9 @@ def post_process(): for request in result: for call in request["calls"]: - formatted_query = sqlparse.format(call["query"].strip(), keyword_case="upper", reindent=True) + formatted_query = sqlparse.format( + call["query"].strip(), keyword_case="upper", reindent=True, strip_comments=True + ) call["query"] = formatted_query # Collect EXPLAIN for executed query diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 76722fccc7..33927b2002 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -598,10 +598,11 @@ class TestDB(FrappeTestCase): frappe.db.before_rollback.add(lambda: f(5)) frappe.db.after_rollback.add(lambda: f(6)) frappe.db.after_rollback.add(lambda: f(7)) + frappe.db.after_rollback(lambda: f(8)) frappe.db.rollback() - self.assertEqual(order_of_execution, list(range(0, 8))) + self.assertEqual(order_of_execution, list(range(0, 9))) @run_only_if(db_type_is.MARIADB) diff --git a/frappe/tests/test_monitor.py b/frappe/tests/test_monitor.py index 74c8c07b9f..ef2854515e 100644 --- a/frappe/tests/test_monitor.py +++ b/frappe/tests/test_monitor.py @@ -3,7 +3,7 @@ import frappe import frappe.monitor -from frappe.monitor import MONITOR_REDIS_KEY +from frappe.monitor import MONITOR_REDIS_KEY, get_trace_id from frappe.tests.utils import FrappeTestCase from frappe.utils import set_request from frappe.utils.response import build_response @@ -14,6 +14,10 @@ class TestMonitor(FrappeTestCase): frappe.conf.monitor = 1 frappe.cache.delete_value(MONITOR_REDIS_KEY) + def tearDown(self): + frappe.conf.monitor = 0 + frappe.cache.delete_value(MONITOR_REDIS_KEY) + def test_enable_monitor(self): set_request(method="GET", path="/api/method/frappe.ping") response = build_response("json") @@ -77,6 +81,10 @@ class TestMonitor(FrappeTestCase): log = frappe.parse_json(logs[0]) self.assertEqual(log.transaction_type, "request") - def tearDown(self): - frappe.conf.monitor = 0 - frappe.cache.delete_value(MONITOR_REDIS_KEY) + def test_trace_ids(self): + set_request(method="GET", path="/api/method/frappe.ping") + response = build_response("json") + frappe.monitor.start() + frappe.db.sql("select 1") + self.assertIn(get_trace_id(), str(frappe.db.last_query)) + frappe.monitor.stop(response) diff --git a/frappe/utils/__init__.py b/frappe/utils/__init__.py index 00f2a8726c..7ed399e759 100644 --- a/frappe/utils/__init__.py +++ b/frappe/utils/__init__.py @@ -1131,6 +1131,9 @@ class CallbackManager: """Add a function to queue, functions are executed in order of addition.""" self._functions.append(func) + def __call__(self, func: Callable) -> None: + self.add(func) + def run(self): """Run all functions in queue""" while self._functions: