refactor: Use pymysql over mariadb client

This is supposed to be a temporary switch to make the parent PR easier
to digest. The MariaDB client has some issues with its releases and with
system dependencies.

This commit may be reverted to enable mariadb client again.
This commit is contained in:
Gavin D'souza 2022-07-21 11:17:02 +05:30
parent 486f26a1ff
commit 006ebcbede
5 changed files with 55 additions and 204 deletions

View file

@ -48,7 +48,6 @@ __title__ = "Frappe Framework"
controllers = {}
local = Local()
STANDARD_USERS = ("Guest", "Administrator")
DISABLE_DATABASE_CONNECTION_POOLING = None
_dev_server = int(sbool(os.environ.get("DEV_SERVER", False)))
_qb_patched = {}

View file

@ -69,8 +69,6 @@ def new_site(
"Create a new site"
from frappe.installer import _new_site
frappe.DISABLE_DATABASE_CONNECTION_POOLING = True
frappe.init(site=site, new_site=True)
_new_site(

View file

@ -1,114 +1,87 @@
import re
from collections import defaultdict
from decimal import Decimal
from typing import TYPE_CHECKING
import mariadb
from mariadb.constants import ERR, FIELD_TYPE
from pymysql.converters import escape_sequence, escape_string
import pymysql
from pymysql.constants import ER, FIELD_TYPE
from pymysql.converters import conversions, escape_string
import frappe
from frappe.database.database import Database, QueryValues
from frappe.database.database import Database
from frappe.database.mariadb.schema import MariaDBTable
from frappe.utils import UnicodeWithAttrs, get_datetime, get_table_name
from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name
if TYPE_CHECKING:
from mariadb import ConnectionPool
_FIND_ITER_PATTERN = re.compile("%s")
_PARAM_COMP = re.compile(r"%\([\w]*\)s")
_SITE_POOLS = defaultdict(frappe._dict)
_MAX_POOL_SIZE = 64
_POOL_SIZE = 1
# _POOL_SIZE is selected "arbitrarily" to avoid overloading the server while being mindful of multitenancy.
# The initial size of the connection pool will be _POOL_SIZE for each site. Replica setups will have a separate pool.
# This means each site with a replica setup can have 2 active pools of size _POOL_SIZE each. Each pool may
# expand up to _MAX_POOL_SIZE as per requirement. This cannot be a function of @@global.max_connections or the
# no. of sites, since there may be multiple processes holding connections, and this defines the size for each
# of those processes/workers. Check MariaDBConnectionUtil for connection & pool management.
def is_connection_pooling_enabled() -> bool:
"""Set `frappe.DISABLE_CONNECTION_POOLING` to enable/disable connection pooling for all on current
process. This will override config key `disable_database_connection_pooling`. Set key
`disable_database_connection_pooling` in site config for persistent settings across workers."""
if frappe.DISABLE_DATABASE_CONNECTION_POOLING is not None:
return not frappe.DISABLE_DATABASE_CONNECTION_POOLING
return not frappe.local.conf.disable_database_connection_pooling
class MariaDBExceptionUtil:
ProgrammingError = mariadb.ProgrammingError
TableMissingError = mariadb.ProgrammingError
OperationalError = mariadb.OperationalError
InternalError = mariadb.InternalError
SQLError = mariadb.ProgrammingError
DataError = mariadb.DataError
ProgrammingError = pymysql.ProgrammingError
TableMissingError = pymysql.ProgrammingError
OperationalError = pymysql.OperationalError
InternalError = pymysql.InternalError
SQLError = pymysql.ProgrammingError
DataError = pymysql.DataError
# match ER_SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/
SequenceGeneratorLimitExceeded = mariadb.OperationalError
SequenceGeneratorLimitExceeded = pymysql.OperationalError
SequenceGeneratorLimitExceeded.errno = 4084
@staticmethod
def is_deadlocked(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_LOCK_DEADLOCK
def is_deadlocked(e: pymysql.Error) -> bool:
return e.args[0] == ER.LOCK_DEADLOCK
@staticmethod
def is_timedout(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_LOCK_WAIT_TIMEOUT
def is_timedout(e: pymysql.Error) -> bool:
return e.args[0] == ER.LOCK_WAIT_TIMEOUT
@staticmethod
def is_table_missing(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_NO_SUCH_TABLE
def is_table_missing(e: pymysql.Error) -> bool:
return e.args[0] == ER.NO_SUCH_TABLE
@staticmethod
def is_missing_table(e: mariadb.Error) -> bool:
def is_missing_table(e: pymysql.Error) -> bool:
return MariaDBDatabase.is_table_missing(e)
@staticmethod
def is_missing_column(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_BAD_FIELD_ERROR
def is_missing_column(e: pymysql.Error) -> bool:
return e.args[0] == ER.BAD_FIELD_ERROR
@staticmethod
def is_duplicate_fieldname(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_DUP_FIELDNAME
def is_duplicate_fieldname(e: pymysql.Error) -> bool:
return e.args[0] == ER.DUP_FIELDNAME
@staticmethod
def is_duplicate_entry(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_DUP_ENTRY
def is_duplicate_entry(e: pymysql.Error) -> bool:
return e.args[0] == ER.DUP_ENTRY
@staticmethod
def is_access_denied(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_ACCESS_DENIED_ERROR
def is_access_denied(e: pymysql.Error) -> bool:
return e.args[0] == ER.ACCESS_DENIED_ERROR
@staticmethod
def cant_drop_field_or_key(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_CANT_DROP_FIELD_OR_KEY
def cant_drop_field_or_key(e: pymysql.Error) -> bool:
return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY
@staticmethod
def is_syntax_error(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_PARSE_ERROR
def is_syntax_error(e: pymysql.Error) -> bool:
return e.args[0] == ER.PARSE_ERROR
@staticmethod
def is_data_too_long(e: mariadb.Error) -> bool:
return getattr(e, "errno", None) == ERR.ER_DATA_TOO_LONG
def is_data_too_long(e: pymysql.Error) -> bool:
return e.args[0] == ER.DATA_TOO_LONG
@staticmethod
def is_primary_key_violation(e: mariadb.Error) -> bool:
def is_primary_key_violation(e: pymysql.Error) -> bool:
return (
MariaDBDatabase.is_duplicate_entry(e)
and "PRIMARY" in e.errmsg
and isinstance(e, mariadb.IntegrityError)
and "PRIMARY" in cstr(e.args[1])
and isinstance(e, pymysql.IntegrityError)
)
@staticmethod
def is_unique_key_violation(e: mariadb.Error) -> bool:
def is_unique_key_violation(e: pymysql.Error) -> bool:
return (
MariaDBDatabase.is_duplicate_entry(e)
and "Duplicate" in e.errmsg
and isinstance(e, mariadb.IntegrityError)
and "Duplicate" in cstr(e.args[1])
and isinstance(e, pymysql.IntegrityError)
)
@ -118,90 +91,21 @@ class MariaDBConnectionUtil:
conn.auto_reconnect = True
return conn
def _get_connection(self) -> "mariadb.Connection":
"""Return MariaDB connection object.
If frappe.conf.disable_database_connection_pooling is set, return a new connection
object and close existing pool if exists. Else, return a connection from the pool.
"""
global _SITE_POOLS
# don't pool root connections
if self.user == "root":
return self.create_connection()
if not is_connection_pooling_enabled():
self.close_connection_pools()
return self.create_connection()
if frappe.local.site not in _SITE_POOLS:
site_pool = self.create_connection_pool()
else:
site_pool = self.get_connection_pool()
try:
conn = site_pool.get_connection()
except mariadb.PoolError:
# PoolError is raised when the pool is exhausted
conn = self.create_connection()
try:
site_pool.add_connection(conn)
# log this via frappe.logger & continue - site needs bigger pool...over _POOL_SIZE
except mariadb.PoolError:
# PoolError is raised when size limit is reached
# log this via frappe.logger & continue - site needs a much bigger pool...over _MAX_POOL_SIZE
pass
return conn
def close_connection_pools(self):
if frappe.local.site in _SITE_POOLS:
pools = _SITE_POOLS[frappe.local.site]
for pool in pools.values():
try:
pool.close()
except Exception:
pass
_SITE_POOLS.pop(frappe.local.site, None)
def get_pool_name(self) -> str:
pool_type = "read-only" if self.read_only else "default"
return f"{frappe.local.site}-{pool_type}"
def get_connection_pool(self) -> "ConnectionPool":
"""Return MariaDB connection pool object.
If `read_only` is True, return a read only pool.
"""
return _SITE_POOLS[frappe.local.site]["read_only" if self.read_only else "default"]
def create_connection_pool(self):
pool = mariadb.ConnectionPool(
pool_name=self.get_pool_name(),
pool_size=_MAX_POOL_SIZE,
pool_reset_connection=False,
)
pool.set_config(**self.get_connection_settings())
if self.read_only:
_SITE_POOLS[frappe.local.site].read_only = pool
else:
_SITE_POOLS[frappe.local.site].default = pool
for _ in range(_POOL_SIZE):
pool.add_connection()
return pool
def _get_connection(self):
"""Return MariaDB connection object."""
return self.create_connection()
def create_connection(self):
return mariadb.connect(**self.get_connection_settings())
return pymysql.connect(**self.get_connection_settings())
def get_connection_settings(self) -> dict:
conn_settings = {
"host": self.host,
"user": self.user,
"password": self.password,
"converter": self.CONVERSION_MAP,
"conv": self.CONVERSION_MAP,
"charset": "utf8mb4",
"use_unicode": True,
}
if self.user != "root":
@ -215,63 +119,15 @@ class MariaDBConnectionUtil:
if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key:
ssl_params = {
"ssl": True,
"ssl_ca": frappe.conf.db_ssl_ca,
"ssl_cert": frappe.conf.db_ssl_cert,
"ssl_key": frappe.conf.db_ssl_key,
"ca": frappe.conf.db_ssl_ca,
"cert": frappe.conf.db_ssl_cert,
"key": frappe.conf.db_ssl_key,
}
conn_settings.update(ssl_params)
return conn_settings
class MariaDBCursorPatchUtil:
"""Patch mariadb.cursor.Cursor to handle things not supported by pinned version of MariaDB client."""
def _transform_query(self, query: str, values: QueryValues) -> tuple:
"""Transform the query to handle things not supported by pinned version of MariaDB client.
Transformations:
- Escape sequences in values
"""
_values = []
if isinstance(values, (tuple, list)):
for val in values:
if isinstance(val, (tuple, list)):
_values.append(escape_sequence(val, charset=self._conn.character_set))
else:
_values.append(val)
values = _values
else:
for token in _PARAM_COMP.findall(query):
key = token[2:-2]
try:
val = values[key]
except KeyError:
raise self.ProgrammingError(f"Missing value for key '{key}'")
if isinstance(val, (tuple, list)):
values[key] = escape_sequence(val, charset=self._conn.character_set)
return query, values or []
def _transform_result(self, result: list[tuple]) -> list[tuple]:
# ref: https://jira.mariadb.org/projects/CONPY/issues/CONPY-213
_result = []
for row in result:
_row = []
for el in row:
if isinstance(el, Decimal):
el = float(el)
elif isinstance(el, UnicodeWithAttrs):
el = escape_string(el)
_row.append(el)
_result.append(tuple(_row))
return _result
class MariaDBDatabase(
MariaDBCursorPatchUtil, MariaDBConnectionUtil, MariaDBExceptionUtil, Database
):
class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database):
REGEX_CHARACTER = "regexp"
# NOTE: using a very small cache - as during backup, if the sequence was used in anyform,
@ -282,7 +138,7 @@ class MariaDBDatabase(
# using the system after a restore.
# issue link: https://jira.mariadb.org/browse/MDEV-21786
SEQUENCE_CACHE = 50
CONVERSION_MAP = {
CONVERSION_MAP = conversions | {
FIELD_TYPE.NEWDECIMAL: float,
FIELD_TYPE.DATETIME: get_datetime,
UnicodeWithAttrs: escape_string,
@ -342,7 +198,8 @@ class MariaDBDatabase(
return db_size[0].get("database_size")
def log_query(self, query, values, debug, explain):
self.last_query = super().log_query(query, values, debug, explain)
self.last_query = self._cursor._last_executed
self._log_query(query, debug, explain)
return self.last_query
@staticmethod
@ -368,11 +225,11 @@ class MariaDBDatabase(
# column type
@staticmethod
def is_type_number(code):
return code == mariadb.NUMBER
return code == pymysql.NUMBER
@staticmethod
def is_type_datetime(code):
return code == mariadb.DATETIME
return code == pymysql.DATETIME
def rename_table(self, old_name: str, new_name: str) -> list | tuple:
old_name = get_table_name(old_name)

View file

@ -106,6 +106,4 @@ if __name__ == "__main__":
if not frappe._dev_server:
warnings.simplefilter("ignore")
frappe.DISABLE_DATABASE_CONNECTION_POOLING = not int(os.environ.get("DATABASE_POOLING", "0"))
main()

View file

@ -38,7 +38,6 @@ dependencies = [
"html5lib~=1.1",
"ipython~=8.4.0",
"ldap3~=2.9",
"mariadb~=1.1.2",
"markdown2~=2.4.0",
"maxminddb-geolite2==2018.703",
"num2words~=0.5.10",