From 006ebcbedeb41fb21babf0a6c11a709a0a3525eb Mon Sep 17 00:00:00 2001 From: Gavin D'souza Date: Thu, 21 Jul 2022 11:17:02 +0530 Subject: [PATCH] refactor: Use pymysql over mariadb client This is supposed to be a temporary switch to make the parent PR easier to digest. MariaDB client has some issues with release, and system dependencies. This commit may be reverted to enable mariadb client again. --- frappe/__init__.py | 1 - frappe/commands/site.py | 2 - frappe/database/mariadb/database.py | 253 ++++++---------------------- frappe/utils/bench_helper.py | 2 - pyproject.toml | 1 - 5 files changed, 55 insertions(+), 204 deletions(-) diff --git a/frappe/__init__.py b/frappe/__init__.py index 3d3288c16d..20fddb0267 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -48,7 +48,6 @@ __title__ = "Frappe Framework" controllers = {} local = Local() STANDARD_USERS = ("Guest", "Administrator") -DISABLE_DATABASE_CONNECTION_POOLING = None _dev_server = int(sbool(os.environ.get("DEV_SERVER", False))) _qb_patched = {} diff --git a/frappe/commands/site.py b/frappe/commands/site.py index 62451b6013..e3c7de32a3 100644 --- a/frappe/commands/site.py +++ b/frappe/commands/site.py @@ -69,8 +69,6 @@ def new_site( "Create a new site" from frappe.installer import _new_site - frappe.DISABLE_DATABASE_CONNECTION_POOLING = True - frappe.init(site=site, new_site=True) _new_site( diff --git a/frappe/database/mariadb/database.py b/frappe/database/mariadb/database.py index 71317b5884..e89168194e 100644 --- a/frappe/database/mariadb/database.py +++ b/frappe/database/mariadb/database.py @@ -1,114 +1,87 @@ import re -from collections import defaultdict -from decimal import Decimal -from typing import TYPE_CHECKING -import mariadb -from mariadb.constants import ERR, FIELD_TYPE -from pymysql.converters import escape_sequence, escape_string +import pymysql +from pymysql.constants import ER, FIELD_TYPE +from pymysql.converters import conversions, escape_string import frappe -from frappe.database.database import Database, QueryValues +from frappe.database.database import Database from frappe.database.mariadb.schema import MariaDBTable -from frappe.utils import UnicodeWithAttrs, get_datetime, get_table_name +from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name -if TYPE_CHECKING: - from mariadb import ConnectionPool - -_FIND_ITER_PATTERN = re.compile("%s") _PARAM_COMP = re.compile(r"%\([\w]*\)s") -_SITE_POOLS = defaultdict(frappe._dict) -_MAX_POOL_SIZE = 64 -_POOL_SIZE = 1 - -# _POOL_SIZE is selected "arbitrarily" to avoid overloading the server and being mindful of multitenancy -# init size of connection pool will be _POOL_SIZE for each site. Replica setups will have separate pool. -# This means each site with a replica setup can have 2 active pools of size _POOL_SIZE each. Each pool may -# expand up to _MAX_POOL_SIZE as per requirement. This cannot be a function of @@global.max_connections, -# no. of sites since there may be multiple processes holding connections; and this defines the size for each -# of those processes/workers. Check MariaDBConnectionUtil for connection & pool management. - - -def is_connection_pooling_enabled() -> bool: - """Set `frappe.DISABLE_CONNECTION_POOLING` to enable/disable connection pooling for all on current - process. This will override config key `disable_database_connection_pooling`. Set key - `disable_database_connection_pooling` in site config for persistent settings across workers.""" - - if frappe.DISABLE_DATABASE_CONNECTION_POOLING is not None: - return not frappe.DISABLE_DATABASE_CONNECTION_POOLING - return not frappe.local.conf.disable_database_connection_pooling class MariaDBExceptionUtil: - ProgrammingError = mariadb.ProgrammingError - TableMissingError = mariadb.ProgrammingError - OperationalError = mariadb.OperationalError - InternalError = mariadb.InternalError - SQLError = mariadb.ProgrammingError - DataError = mariadb.DataError + ProgrammingError = pymysql.ProgrammingError + TableMissingError = pymysql.ProgrammingError + OperationalError = pymysql.OperationalError + InternalError = pymysql.InternalError + SQLError = pymysql.ProgrammingError + DataError = pymysql.DataError # match ER_SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/ - SequenceGeneratorLimitExceeded = mariadb.OperationalError + SequenceGeneratorLimitExceeded = pymysql.OperationalError SequenceGeneratorLimitExceeded.errno = 4084 @staticmethod - def is_deadlocked(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_LOCK_DEADLOCK + def is_deadlocked(e: pymysql.Error) -> bool: + return e.args[0] == ER.LOCK_DEADLOCK @staticmethod - def is_timedout(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_LOCK_WAIT_TIMEOUT + def is_timedout(e: pymysql.Error) -> bool: + return e.args[0] == ER.LOCK_WAIT_TIMEOUT @staticmethod - def is_table_missing(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_NO_SUCH_TABLE + def is_table_missing(e: pymysql.Error) -> bool: + return e.args[0] == ER.NO_SUCH_TABLE @staticmethod - def is_missing_table(e: mariadb.Error) -> bool: + def is_missing_table(e: pymysql.Error) -> bool: return MariaDBDatabase.is_table_missing(e) @staticmethod - def is_missing_column(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_BAD_FIELD_ERROR + def is_missing_column(e: pymysql.Error) -> bool: + return e.args[0] == ER.BAD_FIELD_ERROR @staticmethod - def is_duplicate_fieldname(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_DUP_FIELDNAME + def is_duplicate_fieldname(e: pymysql.Error) -> bool: + return e.args[0] == ER.DUP_FIELDNAME @staticmethod - def is_duplicate_entry(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_DUP_ENTRY + def is_duplicate_entry(e: pymysql.Error) -> bool: + return e.args[0] == ER.DUP_ENTRY @staticmethod - def is_access_denied(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_ACCESS_DENIED_ERROR + def is_access_denied(e: pymysql.Error) -> bool: + return e.args[0] == ER.ACCESS_DENIED_ERROR @staticmethod - def cant_drop_field_or_key(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_CANT_DROP_FIELD_OR_KEY + def cant_drop_field_or_key(e: pymysql.Error) -> bool: + return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY @staticmethod - def is_syntax_error(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_PARSE_ERROR + def is_syntax_error(e: pymysql.Error) -> bool: + return e.args[0] == ER.PARSE_ERROR @staticmethod - def is_data_too_long(e: mariadb.Error) -> bool: - return getattr(e, "errno", None) == ERR.ER_DATA_TOO_LONG + def is_data_too_long(e: pymysql.Error) -> bool: + return e.args[0] == ER.DATA_TOO_LONG @staticmethod - def is_primary_key_violation(e: mariadb.Error) -> bool: + def is_primary_key_violation(e: pymysql.Error) -> bool: return ( MariaDBDatabase.is_duplicate_entry(e) - and "PRIMARY" in e.errmsg - and isinstance(e, mariadb.IntegrityError) + and "PRIMARY" in cstr(e.args[1]) + and isinstance(e, pymysql.IntegrityError) ) @staticmethod - def is_unique_key_violation(e: mariadb.Error) -> bool: + def is_unique_key_violation(e: pymysql.Error) -> bool: return ( MariaDBDatabase.is_duplicate_entry(e) - and "Duplicate" in e.errmsg - and isinstance(e, mariadb.IntegrityError) + and "Duplicate" in cstr(e.args[1]) + and isinstance(e, pymysql.IntegrityError) ) @@ -118,90 +91,21 @@ class MariaDBConnectionUtil: conn.auto_reconnect = True return conn - def _get_connection(self) -> "mariadb.Connection": - """Return MariaDB connection object. - - If frappe.conf.disable_database_connection_pooling is set, return a new connection - object and close existing pool if exists. Else, return a connection from the pool. - """ - global _SITE_POOLS - - # don't pool root connections - if self.user == "root": - return self.create_connection() - - if not is_connection_pooling_enabled(): - self.close_connection_pools() - return self.create_connection() - - if frappe.local.site not in _SITE_POOLS: - site_pool = self.create_connection_pool() - else: - site_pool = self.get_connection_pool() - - try: - conn = site_pool.get_connection() - except mariadb.PoolError: - # PoolError is raised when the pool is exhausted - conn = self.create_connection() - try: - site_pool.add_connection(conn) - # log this via frappe.logger & continue - site needs bigger pool...over _POOL_SIZE - except mariadb.PoolError: - # PoolError is raised when size limit is reached - # log this via frappe.logger & continue - site needs a much bigger pool...over _MAX_POOL_SIZE - pass - - return conn - - def close_connection_pools(self): - if frappe.local.site in _SITE_POOLS: - pools = _SITE_POOLS[frappe.local.site] - for pool in pools.values(): - try: - pool.close() - except Exception: - pass - _SITE_POOLS.pop(frappe.local.site, None) - - def get_pool_name(self) -> str: - pool_type = "read-only" if self.read_only else "default" - return f"{frappe.local.site}-{pool_type}" - - def get_connection_pool(self) -> "ConnectionPool": - """Return MariaDB connection pool object. - - If `read_only` is True, return a read only pool. - """ - return _SITE_POOLS[frappe.local.site]["read_only" if self.read_only else "default"] - - def create_connection_pool(self): - pool = mariadb.ConnectionPool( - pool_name=self.get_pool_name(), - pool_size=_MAX_POOL_SIZE, - pool_reset_connection=False, - ) - pool.set_config(**self.get_connection_settings()) - - if self.read_only: - _SITE_POOLS[frappe.local.site].read_only = pool - else: - _SITE_POOLS[frappe.local.site].default = pool - - for _ in range(_POOL_SIZE): - pool.add_connection() - - return pool + def _get_connection(self): + """Return MariaDB connection object.""" + return self.create_connection() def create_connection(self): - return mariadb.connect(**self.get_connection_settings()) + return pymysql.connect(**self.get_connection_settings()) def get_connection_settings(self) -> dict: conn_settings = { "host": self.host, "user": self.user, "password": self.password, - "converter": self.CONVERSION_MAP, + "conv": self.CONVERSION_MAP, + "charset": "utf8mb4", + "use_unicode": True, } if self.user != "root": @@ -215,63 +119,15 @@ class MariaDBConnectionUtil: if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key: ssl_params = { - "ssl": True, - "ssl_ca": frappe.conf.db_ssl_ca, - "ssl_cert": frappe.conf.db_ssl_cert, - "ssl_key": frappe.conf.db_ssl_key, + "ca": frappe.conf.db_ssl_ca, + "cert": frappe.conf.db_ssl_cert, + "key": frappe.conf.db_ssl_key, } conn_settings.update(ssl_params) return conn_settings -class MariaDBCursorPatchUtil: - """Patch mariadb.cursor.Cursor to handle things not supported by pinned version of MariaDB client.""" - - def _transform_query(self, query: str, values: QueryValues) -> tuple: - """Transform the query to handle things not supported by pinned version of MariaDB client. - - Transformations: - - Escape sequences in values - """ - _values = [] - - if isinstance(values, (tuple, list)): - for val in values: - if isinstance(val, (tuple, list)): - _values.append(escape_sequence(val, charset=self._conn.character_set)) - else: - _values.append(val) - values = _values - else: - for token in _PARAM_COMP.findall(query): - key = token[2:-2] - try: - val = values[key] - except KeyError: - raise self.ProgrammingError(f"Missing value for key '{key}'") - if isinstance(val, (tuple, list)): - values[key] = escape_sequence(val, charset=self._conn.character_set) - - return query, values or [] - - def _transform_result(self, result: list[tuple]) -> list[tuple]: - # ref: https://jira.mariadb.org/projects/CONPY/issues/CONPY-213 - _result = [] - for row in result: - _row = [] - for el in row: - if isinstance(el, Decimal): - el = float(el) - elif isinstance(el, UnicodeWithAttrs): - el = escape_string(el) - _row.append(el) - _result.append(tuple(_row)) - return _result - - -class MariaDBDatabase( - MariaDBCursorPatchUtil, MariaDBConnectionUtil, MariaDBExceptionUtil, Database -): +class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database): REGEX_CHARACTER = "regexp" # NOTE: using a very small cache - as during backup, if the sequence was used in anyform, @@ -282,7 +138,7 @@ class MariaDBDatabase( # using the system after a restore. # issue link: https://jira.mariadb.org/browse/MDEV-21786 SEQUENCE_CACHE = 50 - CONVERSION_MAP = { + CONVERSION_MAP = conversions | { FIELD_TYPE.NEWDECIMAL: float, FIELD_TYPE.DATETIME: get_datetime, UnicodeWithAttrs: escape_string, @@ -342,7 +198,8 @@ class MariaDBDatabase( return db_size[0].get("database_size") def log_query(self, query, values, debug, explain): - self.last_query = super().log_query(query, values, debug, explain) + self.last_query = self._cursor._last_executed + self._log_query(query, debug, explain) return self.last_query @staticmethod @@ -368,11 +225,11 @@ class MariaDBDatabase( # column type @staticmethod def is_type_number(code): - return code == mariadb.NUMBER + return code == pymysql.NUMBER @staticmethod def is_type_datetime(code): - return code == mariadb.DATETIME + return code == pymysql.DATETIME def rename_table(self, old_name: str, new_name: str) -> list | tuple: old_name = get_table_name(old_name) diff --git a/frappe/utils/bench_helper.py b/frappe/utils/bench_helper.py index 10ace1b1b6..a0b011acc1 100644 --- a/frappe/utils/bench_helper.py +++ b/frappe/utils/bench_helper.py @@ -106,6 +106,4 @@ if __name__ == "__main__": if not frappe._dev_server: warnings.simplefilter("ignore") - frappe.DISABLE_DATABASE_CONNECTION_POOLING = not int(os.environ.get("DATABASE_POOLING", "0")) - main() diff --git a/pyproject.toml b/pyproject.toml index c3ef944b85..5eeb6f46dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ dependencies = [ "html5lib~=1.1", "ipython~=8.4.0", "ldap3~=2.9", - "mariadb~=1.1.2", "markdown2~=2.4.0", "maxminddb-geolite2==2018.703", "num2words~=0.5.10",