refactor: Remember patched connections

This way if `frappe.db` changes we don't end up trying to unpatch the
wrong thing.
This commit is contained in:
Ankush Menat 2025-04-24 10:50:25 +05:30
parent 5d380e6f93
commit 26b1360c50

View file

@ -8,6 +8,7 @@ import json
import pstats
import re
import time
import typing
from collections import Counter
from collections.abc import Callable
from dataclasses import dataclass
@ -27,6 +28,10 @@ TRACEBACK_PATH_PATTERN = re.compile(".*/apps/")
RECORDER_AUTO_DISABLE = 5 * 60
if typing.TYPE_CHECKING:
from frappe.database.database import Database
@dataclass
class RecorderConfig:
record_requests: bool = True # Record web request
@ -192,7 +197,6 @@ class Recorder:
def __init__(self, force=False):
self.config = RecorderConfig.retrieve()
self.calls = []
self._patched_sql = False
self.profiler = None
self._recording = True
self.force = force
@ -200,6 +204,7 @@ class Recorder:
self.method = None
self.headers = None
self.form_dict = None
self.patched_databases = []
if (
self.config.record_requests
@ -229,8 +234,7 @@ class Recorder:
self.time = now_datetime()
if self.config.record_sql:
self._patch_sql()
self._patched_sql = True
self._patch_sql(frappe.db)
if self.config.profile:
self.profiler = cProfile.Profile()
@ -242,8 +246,7 @@ class Recorder:
def cleanup(self):
if self.profiler:
self.profiler.disable()
if self._patched_sql:
self._unpatch_sql()
self._unpatch_sql()
def process_profiler(self):
if self.config.profile or self.profiler:
@ -283,14 +286,14 @@ class Recorder:
if self.config.record_sql:
self._unpatch_sql()
@staticmethod
def _patch_sql():
def _patch_sql(self, db: "Database"):
frappe.db._sql = frappe.db.sql
frappe.db.sql = record_sql
self.patched_databases.append(db)
@staticmethod
def _unpatch_sql():
frappe.db.sql = frappe.db._sql
def _unpatch_sql(self):
for db in self.patched_databases:
db.sql = db._sql
def do_not_record(function):
@ -389,10 +392,11 @@ def record_queries(func: Callable):
@functools.wraps(func)
def wrapped(*args, **kwargs):
record(force=True)
frappe.local._recorder.path = f"Function call: {func.__module__}.{func.__qualname__}"
recorder = frappe.local._recorder
recorder.path = f"Function call: {func.__module__}.{func.__qualname__}"
ret = func(*args, **kwargs)
dump()
Recorder._unpatch_sql()
recorder._unpatch_sql()
post_process()
print("Recorded queries, open recorder to view them.")
return ret