From ce360b6fcea2b10b0b2fa297746fc23595bb631f Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Sat, 5 Nov 2022 16:58:01 +0530 Subject: [PATCH] feat: Set default SQL statement timeouts --- frappe/database/database.py | 36 ++++++++++++++++++++++++++-- frappe/database/mariadb/database.py | 7 ++++++ frappe/database/postgres/database.py | 8 +++++++ frappe/tests/test_db.py | 21 ++++++++++++++++ frappe/utils/background_jobs.py | 1 - 5 files changed, 70 insertions(+), 3 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index 3cb47e853a..64ef994a50 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -7,7 +7,7 @@ import random import re import string import traceback -from contextlib import contextmanager +from contextlib import contextmanager, suppress from time import time from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder @@ -29,7 +29,7 @@ from frappe.exceptions import DoesNotExistError, ImplicitCommitError from frappe.model.utils.link_count import flush_local_link_count from frappe.query_builder.functions import Count from frappe.utils import cast as cast_fieldtype -from frappe.utils import get_datetime, get_table_name, getdate, now, sbool +from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool IFNULL_PATTERN = re.compile(r"ifnull\(", flags=re.IGNORECASE) INDEX_PATTERN = re.compile(r"\s*\([^)]+\)\s*") @@ -114,6 +114,17 @@ class Database: self._cursor = self._conn.cursor() frappe.local.rollback_observers = [] + try: + if execution_timeout := get_ideal_query_execution_timeout(): + self.set_execution_timeout(execution_timeout) + except Exception as e: + frappe.logger("database").warning(f"Couldn't set execution timeout {e}") + + def set_execution_timeout(self, seconds: int): + """Set session speicifc timeout on exeuction of statements. + If any statement takes more time it will be killed along with entire transaction.""" + raise NotImplementedError + def use(self, db_name): """`USE` db_name.""" self._conn.select_db(db_name) @@ -1340,3 +1351,24 @@ def savepoint(catch: type | tuple[type, ...] = Exception): frappe.db.rollback(save_point=savepoint) else: frappe.db.release_savepoint(savepoint) + + +def get_ideal_query_execution_timeout() -> int: + """Get execution timeout based on current timeout in contexts. + + HTTP requests: HTTP timeout or a default (300) + Background jobs: Job timeout + + Note: Timeout adds 1.5x as "safety factor" + """ + from rq import get_current_job + + # Zero means no timeout, which is the default value in db. + timeout = 0 + with suppress(Exception): + if hasattr(frappe.local, "request"): + timeout = frappe.conf.http_timeout or 300 + elif job := get_current_job(): + timeout = job.timeout + + return int(cint(timeout) * 1.5) diff --git a/frappe/database/mariadb/database.py b/frappe/database/mariadb/database.py index 1df9877eb1..322c355357 100644 --- a/frappe/database/mariadb/database.py +++ b/frappe/database/mariadb/database.py @@ -68,6 +68,10 @@ class MariaDBExceptionUtil: def is_syntax_error(e: pymysql.Error) -> bool: return e.args[0] == ER.PARSE_ERROR + @staticmethod + def is_statement_timeout(e: pymysql.Error) -> bool: + return e.args[0] == 1969 + @staticmethod def is_data_too_long(e: pymysql.Error) -> bool: return e.args[0] == ER.DATA_TOO_LONG @@ -102,6 +106,9 @@ class MariaDBConnectionUtil: def create_connection(self): return pymysql.connect(**self.get_connection_settings()) + def set_execution_timeout(self, seconds: int): + self.sql("set session max_statement_time = %s", int(seconds)) + def get_connection_settings(self) -> dict: conn_settings = { "host": self.host, diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index 3b3612c0e4..d082afceaf 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -99,6 +99,10 @@ class PostgresExceptionUtil: def is_duplicate_fieldname(e): return getattr(e, "pgcode", None) == DUPLICATE_COLUMN + @staticmethod + def is_statement_timeout(e): + return PostgresDatabase.is_timedout(e) or isinstance(e, frappe.QueryTimeoutError) + @staticmethod def is_data_too_long(e): return getattr(e, "pgcode", None) == STRING_DATA_RIGHT_TRUNCATION @@ -161,6 +165,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database): return conn + def set_execution_timeout(self, seconds: int): + # Postgres expects milliseconds as input + self.sql("set local statement_timeout = %s", int(seconds) * 1000) + def escape(self, s, percent=True): """Escape quotes and percent in given string.""" if isinstance(s, bytes): diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 08fef66bd0..9a7d086252 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -36,6 +36,27 @@ class TestDB(FrappeTestCase): def test_get_database_size(self): self.assertIsInstance(frappe.db.get_database_size(), (float, int)) + def test_db_statement_execution_timeout(self): + frappe.db.set_execution_timeout(2) + # Setting 0 means no timeout. + self.addCleanup(frappe.db.set_execution_timeout, 0) + + try: + savepoint = "statement_timeout" + frappe.db.savepoint(savepoint) + frappe.db.multisql( + { + "mariadb": "select sleep(10)", + "postgres": "select pg_sleep(10)", + } + ) + except Exception as e: + self.assertTrue(frappe.db.is_statement_timeout(e), f"exepcted {e} to be timeout error") + frappe.db.rollback(save_point=savepoint) + else: + frappe.db.rollback(save_point=savepoint) + self.fail("Long running queries not timing out") + def test_get_value(self): self.assertEqual(frappe.db.get_value("User", {"name": ["=", "Administrator"]}), "Administrator") self.assertEqual(frappe.db.get_value("User", {"name": ["like", "Admin%"]}), "Administrator") diff --git a/frappe/utils/background_jobs.py b/frappe/utils/background_jobs.py index d416857588..12c2105df8 100755 --- a/frappe/utils/background_jobs.py +++ b/frappe/utils/background_jobs.py @@ -9,7 +9,6 @@ from uuid import uuid4 import redis from redis.exceptions import BusyLoadingError, ConnectionError from rq import Connection, Queue, Worker -from rq.command import send_stop_job_command from rq.logutils import setup_loghandlers from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed