feat: Set default SQL statement timeouts
This commit is contained in:
parent
b2860e6f9e
commit
ce360b6fce
5 changed files with 70 additions and 3 deletions
|
|
@ -7,7 +7,7 @@ import random
|
|||
import re
|
||||
import string
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, suppress
|
||||
from time import time
|
||||
|
||||
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
|
||||
|
|
@ -29,7 +29,7 @@ from frappe.exceptions import DoesNotExistError, ImplicitCommitError
|
|||
from frappe.model.utils.link_count import flush_local_link_count
|
||||
from frappe.query_builder.functions import Count
|
||||
from frappe.utils import cast as cast_fieldtype
|
||||
from frappe.utils import get_datetime, get_table_name, getdate, now, sbool
|
||||
from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool
|
||||
|
||||
IFNULL_PATTERN = re.compile(r"ifnull\(", flags=re.IGNORECASE)
|
||||
INDEX_PATTERN = re.compile(r"\s*\([^)]+\)\s*")
|
||||
|
|
@ -114,6 +114,17 @@ class Database:
|
|||
self._cursor = self._conn.cursor()
|
||||
frappe.local.rollback_observers = []
|
||||
|
||||
try:
|
||||
if execution_timeout := get_ideal_query_execution_timeout():
|
||||
self.set_execution_timeout(execution_timeout)
|
||||
except Exception as e:
|
||||
frappe.logger("database").warning(f"Couldn't set execution timeout {e}")
|
||||
|
||||
def set_execution_timeout(self, seconds: int):
|
||||
"""Set session speicifc timeout on exeuction of statements.
|
||||
If any statement takes more time it will be killed along with entire transaction."""
|
||||
raise NotImplementedError
|
||||
|
||||
def use(self, db_name):
|
||||
"""`USE` db_name."""
|
||||
self._conn.select_db(db_name)
|
||||
|
|
@ -1340,3 +1351,24 @@ def savepoint(catch: type | tuple[type, ...] = Exception):
|
|||
frappe.db.rollback(save_point=savepoint)
|
||||
else:
|
||||
frappe.db.release_savepoint(savepoint)
|
||||
|
||||
|
||||
def get_ideal_query_execution_timeout() -> int:
|
||||
"""Get execution timeout based on current timeout in contexts.
|
||||
|
||||
HTTP requests: HTTP timeout or a default (300)
|
||||
Background jobs: Job timeout
|
||||
|
||||
Note: Timeout adds 1.5x as "safety factor"
|
||||
"""
|
||||
from rq import get_current_job
|
||||
|
||||
# Zero means no timeout, which is the default value in db.
|
||||
timeout = 0
|
||||
with suppress(Exception):
|
||||
if hasattr(frappe.local, "request"):
|
||||
timeout = frappe.conf.http_timeout or 300
|
||||
elif job := get_current_job():
|
||||
timeout = job.timeout
|
||||
|
||||
return int(cint(timeout) * 1.5)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,10 @@ class MariaDBExceptionUtil:
|
|||
def is_syntax_error(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.PARSE_ERROR
|
||||
|
||||
@staticmethod
|
||||
def is_statement_timeout(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == 1969
|
||||
|
||||
@staticmethod
|
||||
def is_data_too_long(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.DATA_TOO_LONG
|
||||
|
|
@ -102,6 +106,9 @@ class MariaDBConnectionUtil:
|
|||
def create_connection(self):
|
||||
return pymysql.connect(**self.get_connection_settings())
|
||||
|
||||
def set_execution_timeout(self, seconds: int):
|
||||
self.sql("set session max_statement_time = %s", int(seconds))
|
||||
|
||||
def get_connection_settings(self) -> dict:
|
||||
conn_settings = {
|
||||
"host": self.host,
|
||||
|
|
|
|||
|
|
@ -99,6 +99,10 @@ class PostgresExceptionUtil:
|
|||
def is_duplicate_fieldname(e):
|
||||
return getattr(e, "pgcode", None) == DUPLICATE_COLUMN
|
||||
|
||||
@staticmethod
|
||||
def is_statement_timeout(e):
|
||||
return PostgresDatabase.is_timedout(e) or isinstance(e, frappe.QueryTimeoutError)
|
||||
|
||||
@staticmethod
|
||||
def is_data_too_long(e):
|
||||
return getattr(e, "pgcode", None) == STRING_DATA_RIGHT_TRUNCATION
|
||||
|
|
@ -161,6 +165,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
|
||||
return conn
|
||||
|
||||
def set_execution_timeout(self, seconds: int):
|
||||
# Postgres expects milliseconds as input
|
||||
self.sql("set local statement_timeout = %s", int(seconds) * 1000)
|
||||
|
||||
def escape(self, s, percent=True):
|
||||
"""Escape quotes and percent in given string."""
|
||||
if isinstance(s, bytes):
|
||||
|
|
|
|||
|
|
@ -36,6 +36,27 @@ class TestDB(FrappeTestCase):
|
|||
def test_get_database_size(self):
|
||||
self.assertIsInstance(frappe.db.get_database_size(), (float, int))
|
||||
|
||||
def test_db_statement_execution_timeout(self):
|
||||
frappe.db.set_execution_timeout(2)
|
||||
# Setting 0 means no timeout.
|
||||
self.addCleanup(frappe.db.set_execution_timeout, 0)
|
||||
|
||||
try:
|
||||
savepoint = "statement_timeout"
|
||||
frappe.db.savepoint(savepoint)
|
||||
frappe.db.multisql(
|
||||
{
|
||||
"mariadb": "select sleep(10)",
|
||||
"postgres": "select pg_sleep(10)",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
self.assertTrue(frappe.db.is_statement_timeout(e), f"exepcted {e} to be timeout error")
|
||||
frappe.db.rollback(save_point=savepoint)
|
||||
else:
|
||||
frappe.db.rollback(save_point=savepoint)
|
||||
self.fail("Long running queries not timing out")
|
||||
|
||||
def test_get_value(self):
|
||||
self.assertEqual(frappe.db.get_value("User", {"name": ["=", "Administrator"]}), "Administrator")
|
||||
self.assertEqual(frappe.db.get_value("User", {"name": ["like", "Admin%"]}), "Administrator")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from uuid import uuid4
|
|||
import redis
|
||||
from redis.exceptions import BusyLoadingError, ConnectionError
|
||||
from rq import Connection, Queue, Worker
|
||||
from rq.command import send_stop_job_command
|
||||
from rq.logutils import setup_loghandlers
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue