refactor: Use pymysql over mariadb client
This is supposed to be a temporary switch to make the parent PR easier to digest. MariaDB client has some issues with release, and system dependencies. This commit may be reverted to enable mariadb client again.
This commit is contained in:
parent
486f26a1ff
commit
006ebcbede
5 changed files with 55 additions and 204 deletions
|
|
@ -48,7 +48,6 @@ __title__ = "Frappe Framework"
|
|||
controllers = {}
|
||||
local = Local()
|
||||
STANDARD_USERS = ("Guest", "Administrator")
|
||||
DISABLE_DATABASE_CONNECTION_POOLING = None
|
||||
|
||||
_dev_server = int(sbool(os.environ.get("DEV_SERVER", False)))
|
||||
_qb_patched = {}
|
||||
|
|
|
|||
|
|
@ -69,8 +69,6 @@ def new_site(
|
|||
"Create a new site"
|
||||
from frappe.installer import _new_site
|
||||
|
||||
frappe.DISABLE_DATABASE_CONNECTION_POOLING = True
|
||||
|
||||
frappe.init(site=site, new_site=True)
|
||||
|
||||
_new_site(
|
||||
|
|
|
|||
|
|
@ -1,114 +1,87 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import mariadb
|
||||
from mariadb.constants import ERR, FIELD_TYPE
|
||||
from pymysql.converters import escape_sequence, escape_string
|
||||
import pymysql
|
||||
from pymysql.constants import ER, FIELD_TYPE
|
||||
from pymysql.converters import conversions, escape_string
|
||||
|
||||
import frappe
|
||||
from frappe.database.database import Database, QueryValues
|
||||
from frappe.database.database import Database
|
||||
from frappe.database.mariadb.schema import MariaDBTable
|
||||
from frappe.utils import UnicodeWithAttrs, get_datetime, get_table_name
|
||||
from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mariadb import ConnectionPool
|
||||
|
||||
_FIND_ITER_PATTERN = re.compile("%s")
|
||||
_PARAM_COMP = re.compile(r"%\([\w]*\)s")
|
||||
_SITE_POOLS = defaultdict(frappe._dict)
|
||||
_MAX_POOL_SIZE = 64
|
||||
_POOL_SIZE = 1
|
||||
|
||||
# _POOL_SIZE is selected "arbitrarily" to avoid overloading the server and being mindful of multitenancy
|
||||
# init size of connection pool will be _POOL_SIZE for each site. Replica setups will have separate pool.
|
||||
# This means each site with a replica setup can have 2 active pools of size _POOL_SIZE each. Each pool may
|
||||
# expand up to _MAX_POOL_SIZE as per requirement. This cannot be a function of @@global.max_connections,
|
||||
# no. of sites since there may be multiple processes holding connections; and this defines the size for each
|
||||
# of those processes/workers. Check MariaDBConnectionUtil for connection & pool management.
|
||||
|
||||
|
||||
def is_connection_pooling_enabled() -> bool:
|
||||
"""Set `frappe.DISABLE_CONNECTION_POOLING` to enable/disable connection pooling for all on current
|
||||
process. This will override config key `disable_database_connection_pooling`. Set key
|
||||
`disable_database_connection_pooling` in site config for persistent settings across workers."""
|
||||
|
||||
if frappe.DISABLE_DATABASE_CONNECTION_POOLING is not None:
|
||||
return not frappe.DISABLE_DATABASE_CONNECTION_POOLING
|
||||
return not frappe.local.conf.disable_database_connection_pooling
|
||||
|
||||
|
||||
class MariaDBExceptionUtil:
|
||||
ProgrammingError = mariadb.ProgrammingError
|
||||
TableMissingError = mariadb.ProgrammingError
|
||||
OperationalError = mariadb.OperationalError
|
||||
InternalError = mariadb.InternalError
|
||||
SQLError = mariadb.ProgrammingError
|
||||
DataError = mariadb.DataError
|
||||
ProgrammingError = pymysql.ProgrammingError
|
||||
TableMissingError = pymysql.ProgrammingError
|
||||
OperationalError = pymysql.OperationalError
|
||||
InternalError = pymysql.InternalError
|
||||
SQLError = pymysql.ProgrammingError
|
||||
DataError = pymysql.DataError
|
||||
|
||||
# match ER_SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/
|
||||
SequenceGeneratorLimitExceeded = mariadb.OperationalError
|
||||
SequenceGeneratorLimitExceeded = pymysql.OperationalError
|
||||
SequenceGeneratorLimitExceeded.errno = 4084
|
||||
|
||||
@staticmethod
|
||||
def is_deadlocked(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_LOCK_DEADLOCK
|
||||
def is_deadlocked(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.LOCK_DEADLOCK
|
||||
|
||||
@staticmethod
|
||||
def is_timedout(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_LOCK_WAIT_TIMEOUT
|
||||
def is_timedout(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.LOCK_WAIT_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def is_table_missing(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_NO_SUCH_TABLE
|
||||
def is_table_missing(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.NO_SUCH_TABLE
|
||||
|
||||
@staticmethod
|
||||
def is_missing_table(e: mariadb.Error) -> bool:
|
||||
def is_missing_table(e: pymysql.Error) -> bool:
|
||||
return MariaDBDatabase.is_table_missing(e)
|
||||
|
||||
@staticmethod
|
||||
def is_missing_column(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_BAD_FIELD_ERROR
|
||||
def is_missing_column(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.BAD_FIELD_ERROR
|
||||
|
||||
@staticmethod
|
||||
def is_duplicate_fieldname(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_DUP_FIELDNAME
|
||||
def is_duplicate_fieldname(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.DUP_FIELDNAME
|
||||
|
||||
@staticmethod
|
||||
def is_duplicate_entry(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_DUP_ENTRY
|
||||
def is_duplicate_entry(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.DUP_ENTRY
|
||||
|
||||
@staticmethod
|
||||
def is_access_denied(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_ACCESS_DENIED_ERROR
|
||||
def is_access_denied(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.ACCESS_DENIED_ERROR
|
||||
|
||||
@staticmethod
|
||||
def cant_drop_field_or_key(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_CANT_DROP_FIELD_OR_KEY
|
||||
def cant_drop_field_or_key(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY
|
||||
|
||||
@staticmethod
|
||||
def is_syntax_error(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_PARSE_ERROR
|
||||
def is_syntax_error(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.PARSE_ERROR
|
||||
|
||||
@staticmethod
|
||||
def is_data_too_long(e: mariadb.Error) -> bool:
|
||||
return getattr(e, "errno", None) == ERR.ER_DATA_TOO_LONG
|
||||
def is_data_too_long(e: pymysql.Error) -> bool:
|
||||
return e.args[0] == ER.DATA_TOO_LONG
|
||||
|
||||
@staticmethod
|
||||
def is_primary_key_violation(e: mariadb.Error) -> bool:
|
||||
def is_primary_key_violation(e: pymysql.Error) -> bool:
|
||||
return (
|
||||
MariaDBDatabase.is_duplicate_entry(e)
|
||||
and "PRIMARY" in e.errmsg
|
||||
and isinstance(e, mariadb.IntegrityError)
|
||||
and "PRIMARY" in cstr(e.args[1])
|
||||
and isinstance(e, pymysql.IntegrityError)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_unique_key_violation(e: mariadb.Error) -> bool:
|
||||
def is_unique_key_violation(e: pymysql.Error) -> bool:
|
||||
return (
|
||||
MariaDBDatabase.is_duplicate_entry(e)
|
||||
and "Duplicate" in e.errmsg
|
||||
and isinstance(e, mariadb.IntegrityError)
|
||||
and "Duplicate" in cstr(e.args[1])
|
||||
and isinstance(e, pymysql.IntegrityError)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -118,90 +91,21 @@ class MariaDBConnectionUtil:
|
|||
conn.auto_reconnect = True
|
||||
return conn
|
||||
|
||||
def _get_connection(self) -> "mariadb.Connection":
|
||||
"""Return MariaDB connection object.
|
||||
|
||||
If frappe.conf.disable_database_connection_pooling is set, return a new connection
|
||||
object and close existing pool if exists. Else, return a connection from the pool.
|
||||
"""
|
||||
global _SITE_POOLS
|
||||
|
||||
# don't pool root connections
|
||||
if self.user == "root":
|
||||
return self.create_connection()
|
||||
|
||||
if not is_connection_pooling_enabled():
|
||||
self.close_connection_pools()
|
||||
return self.create_connection()
|
||||
|
||||
if frappe.local.site not in _SITE_POOLS:
|
||||
site_pool = self.create_connection_pool()
|
||||
else:
|
||||
site_pool = self.get_connection_pool()
|
||||
|
||||
try:
|
||||
conn = site_pool.get_connection()
|
||||
except mariadb.PoolError:
|
||||
# PoolError is raised when the pool is exhausted
|
||||
conn = self.create_connection()
|
||||
try:
|
||||
site_pool.add_connection(conn)
|
||||
# log this via frappe.logger & continue - site needs bigger pool...over _POOL_SIZE
|
||||
except mariadb.PoolError:
|
||||
# PoolError is raised when size limit is reached
|
||||
# log this via frappe.logger & continue - site needs a much bigger pool...over _MAX_POOL_SIZE
|
||||
pass
|
||||
|
||||
return conn
|
||||
|
||||
def close_connection_pools(self):
|
||||
if frappe.local.site in _SITE_POOLS:
|
||||
pools = _SITE_POOLS[frappe.local.site]
|
||||
for pool in pools.values():
|
||||
try:
|
||||
pool.close()
|
||||
except Exception:
|
||||
pass
|
||||
_SITE_POOLS.pop(frappe.local.site, None)
|
||||
|
||||
def get_pool_name(self) -> str:
|
||||
pool_type = "read-only" if self.read_only else "default"
|
||||
return f"{frappe.local.site}-{pool_type}"
|
||||
|
||||
def get_connection_pool(self) -> "ConnectionPool":
|
||||
"""Return MariaDB connection pool object.
|
||||
|
||||
If `read_only` is True, return a read only pool.
|
||||
"""
|
||||
return _SITE_POOLS[frappe.local.site]["read_only" if self.read_only else "default"]
|
||||
|
||||
def create_connection_pool(self):
|
||||
pool = mariadb.ConnectionPool(
|
||||
pool_name=self.get_pool_name(),
|
||||
pool_size=_MAX_POOL_SIZE,
|
||||
pool_reset_connection=False,
|
||||
)
|
||||
pool.set_config(**self.get_connection_settings())
|
||||
|
||||
if self.read_only:
|
||||
_SITE_POOLS[frappe.local.site].read_only = pool
|
||||
else:
|
||||
_SITE_POOLS[frappe.local.site].default = pool
|
||||
|
||||
for _ in range(_POOL_SIZE):
|
||||
pool.add_connection()
|
||||
|
||||
return pool
|
||||
def _get_connection(self):
|
||||
"""Return MariaDB connection object."""
|
||||
return self.create_connection()
|
||||
|
||||
def create_connection(self):
|
||||
return mariadb.connect(**self.get_connection_settings())
|
||||
return pymysql.connect(**self.get_connection_settings())
|
||||
|
||||
def get_connection_settings(self) -> dict:
|
||||
conn_settings = {
|
||||
"host": self.host,
|
||||
"user": self.user,
|
||||
"password": self.password,
|
||||
"converter": self.CONVERSION_MAP,
|
||||
"conv": self.CONVERSION_MAP,
|
||||
"charset": "utf8mb4",
|
||||
"use_unicode": True,
|
||||
}
|
||||
|
||||
if self.user != "root":
|
||||
|
|
@ -215,63 +119,15 @@ class MariaDBConnectionUtil:
|
|||
|
||||
if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key:
|
||||
ssl_params = {
|
||||
"ssl": True,
|
||||
"ssl_ca": frappe.conf.db_ssl_ca,
|
||||
"ssl_cert": frappe.conf.db_ssl_cert,
|
||||
"ssl_key": frappe.conf.db_ssl_key,
|
||||
"ca": frappe.conf.db_ssl_ca,
|
||||
"cert": frappe.conf.db_ssl_cert,
|
||||
"key": frappe.conf.db_ssl_key,
|
||||
}
|
||||
conn_settings.update(ssl_params)
|
||||
return conn_settings
|
||||
|
||||
|
||||
class MariaDBCursorPatchUtil:
|
||||
"""Patch mariadb.cursor.Cursor to handle things not supported by pinned version of MariaDB client."""
|
||||
|
||||
def _transform_query(self, query: str, values: QueryValues) -> tuple:
|
||||
"""Transform the query to handle things not supported by pinned version of MariaDB client.
|
||||
|
||||
Transformations:
|
||||
- Escape sequences in values
|
||||
"""
|
||||
_values = []
|
||||
|
||||
if isinstance(values, (tuple, list)):
|
||||
for val in values:
|
||||
if isinstance(val, (tuple, list)):
|
||||
_values.append(escape_sequence(val, charset=self._conn.character_set))
|
||||
else:
|
||||
_values.append(val)
|
||||
values = _values
|
||||
else:
|
||||
for token in _PARAM_COMP.findall(query):
|
||||
key = token[2:-2]
|
||||
try:
|
||||
val = values[key]
|
||||
except KeyError:
|
||||
raise self.ProgrammingError(f"Missing value for key '{key}'")
|
||||
if isinstance(val, (tuple, list)):
|
||||
values[key] = escape_sequence(val, charset=self._conn.character_set)
|
||||
|
||||
return query, values or []
|
||||
|
||||
def _transform_result(self, result: list[tuple]) -> list[tuple]:
|
||||
# ref: https://jira.mariadb.org/projects/CONPY/issues/CONPY-213
|
||||
_result = []
|
||||
for row in result:
|
||||
_row = []
|
||||
for el in row:
|
||||
if isinstance(el, Decimal):
|
||||
el = float(el)
|
||||
elif isinstance(el, UnicodeWithAttrs):
|
||||
el = escape_string(el)
|
||||
_row.append(el)
|
||||
_result.append(tuple(_row))
|
||||
return _result
|
||||
|
||||
|
||||
class MariaDBDatabase(
|
||||
MariaDBCursorPatchUtil, MariaDBConnectionUtil, MariaDBExceptionUtil, Database
|
||||
):
|
||||
class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database):
|
||||
REGEX_CHARACTER = "regexp"
|
||||
|
||||
# NOTE: using a very small cache - as during backup, if the sequence was used in anyform,
|
||||
|
|
@ -282,7 +138,7 @@ class MariaDBDatabase(
|
|||
# using the system after a restore.
|
||||
# issue link: https://jira.mariadb.org/browse/MDEV-21786
|
||||
SEQUENCE_CACHE = 50
|
||||
CONVERSION_MAP = {
|
||||
CONVERSION_MAP = conversions | {
|
||||
FIELD_TYPE.NEWDECIMAL: float,
|
||||
FIELD_TYPE.DATETIME: get_datetime,
|
||||
UnicodeWithAttrs: escape_string,
|
||||
|
|
@ -342,7 +198,8 @@ class MariaDBDatabase(
|
|||
return db_size[0].get("database_size")
|
||||
|
||||
def log_query(self, query, values, debug, explain):
|
||||
self.last_query = super().log_query(query, values, debug, explain)
|
||||
self.last_query = self._cursor._last_executed
|
||||
self._log_query(query, debug, explain)
|
||||
return self.last_query
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -368,11 +225,11 @@ class MariaDBDatabase(
|
|||
# column type
|
||||
@staticmethod
|
||||
def is_type_number(code):
|
||||
return code == mariadb.NUMBER
|
||||
return code == pymysql.NUMBER
|
||||
|
||||
@staticmethod
|
||||
def is_type_datetime(code):
|
||||
return code == mariadb.DATETIME
|
||||
return code == pymysql.DATETIME
|
||||
|
||||
def rename_table(self, old_name: str, new_name: str) -> list | tuple:
|
||||
old_name = get_table_name(old_name)
|
||||
|
|
|
|||
|
|
@ -106,6 +106,4 @@ if __name__ == "__main__":
|
|||
if not frappe._dev_server:
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
frappe.DISABLE_DATABASE_CONNECTION_POOLING = not int(os.environ.get("DATABASE_POOLING", "0"))
|
||||
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ dependencies = [
|
|||
"html5lib~=1.1",
|
||||
"ipython~=8.4.0",
|
||||
"ldap3~=2.9",
|
||||
"mariadb~=1.1.2",
|
||||
"markdown2~=2.4.0",
|
||||
"maxminddb-geolite2==2018.703",
|
||||
"num2words~=0.5.10",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue