If you're reading thousands of rows from MySQL, the default behaviour is to read all of them into memory at once. One common use case for reading large result sets is reporting, where a lot of data is read and then processed in Python. Each row is not used again after it has been processed, yet it keeps consuming memory until the entire function exits. SSCursor (server-side cursor) allows fetching one row at a time. Note: this is slower than fetching everything at once AND carries a risk of connection loss, so don't use it as a crutch. If possible, rewrite the code so that the processing is done in SQL.
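To illustrate the difference, here is a minimal sketch (not part of the file below) using pymysql directly; the connection parameters, the query, and the process() function are hypothetical placeholders:

import pymysql
from pymysql.cursors import SSCursor

conn = pymysql.connect(host="localhost", user="frappe", password="secret", database="site_db")
try:
	# Default (buffered) cursor: the entire result set is loaded into memory at once.
	# with conn.cursor() as cursor:
	#     cursor.execute("SELECT name, grand_total FROM `tabSales Invoice`")
	#     rows = cursor.fetchall()

	# Server-side cursor: rows are streamed from the server one at a time.
	with conn.cursor(SSCursor) as cursor:
		cursor.execute("SELECT name, grand_total FROM `tabSales Invoice`")
		for row in cursor:  # fetches row-by-row; keep the connection alive while iterating
			process(row)  # hypothetical per-row processing
finally:
	conn.close()

The unbuffered_cursor context manager added at the end of this file does the equivalent swap on frappe's own connection, temporarily replacing the buffered cursor with an SSCursor.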
import re
from contextlib import contextmanager

import pymysql
from pymysql.constants import ER, FIELD_TYPE
from pymysql.converters import conversions, escape_string

import frappe
from frappe.database.database import Database
from frappe.database.mariadb.schema import MariaDBTable
from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name

_PARAM_COMP = re.compile(r"%\([\w]*\)s")


class MariaDBExceptionUtil:
	ProgrammingError = pymysql.ProgrammingError
	TableMissingError = pymysql.ProgrammingError
	OperationalError = pymysql.OperationalError
	InternalError = pymysql.InternalError
	SQLError = pymysql.ProgrammingError
	DataError = pymysql.DataError

	# match ER_SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/
	SequenceGeneratorLimitExceeded = pymysql.OperationalError
	SequenceGeneratorLimitExceeded.errno = 4084

	@staticmethod
	def is_deadlocked(e: pymysql.Error) -> bool:
		return e.args[0] == ER.LOCK_DEADLOCK

	@staticmethod
	def is_timedout(e: pymysql.Error) -> bool:
		return e.args[0] == ER.LOCK_WAIT_TIMEOUT

	@staticmethod
	def is_read_only_mode_error(e: pymysql.Error) -> bool:
		return e.args[0] == 1792

	@staticmethod
	def is_table_missing(e: pymysql.Error) -> bool:
		return e.args[0] == ER.NO_SUCH_TABLE

	@staticmethod
	def is_missing_table(e: pymysql.Error) -> bool:
		return MariaDBDatabase.is_table_missing(e)

	@staticmethod
	def is_missing_column(e: pymysql.Error) -> bool:
		return e.args[0] == ER.BAD_FIELD_ERROR

	@staticmethod
	def is_duplicate_fieldname(e: pymysql.Error) -> bool:
		return e.args[0] == ER.DUP_FIELDNAME

	@staticmethod
	def is_duplicate_entry(e: pymysql.Error) -> bool:
		return e.args[0] == ER.DUP_ENTRY

	@staticmethod
	def is_access_denied(e: pymysql.Error) -> bool:
		return e.args[0] == ER.ACCESS_DENIED_ERROR

	@staticmethod
	def cant_drop_field_or_key(e: pymysql.Error) -> bool:
		return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY

	@staticmethod
	def is_syntax_error(e: pymysql.Error) -> bool:
		return e.args[0] == ER.PARSE_ERROR

	@staticmethod
	def is_statement_timeout(e: pymysql.Error) -> bool:
		return e.args[0] == 1969

	@staticmethod
	def is_data_too_long(e: pymysql.Error) -> bool:
		return e.args[0] == ER.DATA_TOO_LONG

	@staticmethod
	def is_db_table_size_limit(e: pymysql.Error) -> bool:
		return e.args[0] == ER.TOO_BIG_ROWSIZE

	@staticmethod
	def is_primary_key_violation(e: pymysql.Error) -> bool:
		return (
			MariaDBDatabase.is_duplicate_entry(e)
			and "PRIMARY" in cstr(e.args[1])
			and isinstance(e, pymysql.IntegrityError)
		)

	@staticmethod
	def is_unique_key_violation(e: pymysql.Error) -> bool:
		return (
			MariaDBDatabase.is_duplicate_entry(e)
			and "Duplicate" in cstr(e.args[1])
			and isinstance(e, pymysql.IntegrityError)
		)


class MariaDBConnectionUtil:
	def get_connection(self):
		conn = self._get_connection()
		conn.auto_reconnect = True
		return conn

	def _get_connection(self):
		"""Return MariaDB connection object."""
		return self.create_connection()

	def create_connection(self):
		return pymysql.connect(**self.get_connection_settings())

	def set_execution_timeout(self, seconds: int):
		self.sql("set session max_statement_time = %s", int(seconds))

	def get_connection_settings(self) -> dict:
		conn_settings = {
			"host": self.host,
			"user": self.user,
			"password": self.password,
			"conv": self.CONVERSION_MAP,
			"charset": "utf8mb4",
			"use_unicode": True,
		}

		if self.user not in (frappe.flags.root_login, "root"):
			conn_settings["database"] = self.cur_db_name

		if self.port:
			conn_settings["port"] = int(self.port)

		if frappe.conf.local_infile:
			conn_settings["local_infile"] = frappe.conf.local_infile

		if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key:
			conn_settings["ssl"] = {
				"ca": frappe.conf.db_ssl_ca,
				"cert": frappe.conf.db_ssl_cert,
				"key": frappe.conf.db_ssl_key,
			}
		return conn_settings


class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database):
	REGEX_CHARACTER = "regexp"
	CONVERSION_MAP = conversions | {
		FIELD_TYPE.NEWDECIMAL: float,
		FIELD_TYPE.DATETIME: get_datetime,
		UnicodeWithAttrs: escape_string,
	}
	default_port = "3306"
	MAX_ROW_SIZE_LIMIT = 65_535  # bytes

	def setup_type_map(self):
		self.db_type = "mariadb"
		self.type_map = {
			"Currency": ("decimal", "21,9"),
			"Int": ("int", "11"),
			"Long Int": ("bigint", "20"),
			"Float": ("decimal", "21,9"),
			"Percent": ("decimal", "21,9"),
			"Check": ("int", "1"),
			"Small Text": ("text", ""),
			"Long Text": ("longtext", ""),
			"Code": ("longtext", ""),
			"Text Editor": ("longtext", ""),
			"Markdown Editor": ("longtext", ""),
			"HTML Editor": ("longtext", ""),
			"Date": ("date", ""),
			"Datetime": ("datetime", "6"),
			"Time": ("time", "6"),
			"Text": ("text", ""),
			"Data": ("varchar", self.VARCHAR_LEN),
			"Link": ("varchar", self.VARCHAR_LEN),
			"Dynamic Link": ("varchar", self.VARCHAR_LEN),
			"Password": ("text", ""),
			"Select": ("varchar", self.VARCHAR_LEN),
			"Rating": ("decimal", "3,2"),
			"Read Only": ("varchar", self.VARCHAR_LEN),
			"Attach": ("text", ""),
			"Attach Image": ("text", ""),
			"Signature": ("longtext", ""),
			"Color": ("varchar", self.VARCHAR_LEN),
			"Barcode": ("longtext", ""),
			"Geolocation": ("longtext", ""),
			"Duration": ("decimal", "21,9"),
			"Icon": ("varchar", self.VARCHAR_LEN),
			"Phone": ("varchar", self.VARCHAR_LEN),
			"Autocomplete": ("varchar", self.VARCHAR_LEN),
			"JSON": ("json", ""),
		}

	def get_database_size(self):
		"""Return database size in MB."""
		db_size = self.sql(
			"""
			SELECT `table_schema` as `database_name`,
			SUM(`data_length` + `index_length`) / 1024 / 1024 AS `database_size`
			FROM information_schema.tables WHERE `table_schema` = %s GROUP BY `table_schema`
			""",
			self.cur_db_name,
			as_dict=True,
		)

		return db_size[0].get("database_size")

	def log_query(self, query, values, debug, explain):
		self.last_query = self._cursor._executed
		self._log_query(self.last_query, debug, explain, query)
		return self.last_query

	def _clean_up(self):
		# PERF: Erase internal references of pymysql to trigger GC as soon as
		# results are consumed.
		self._cursor._result = None
		self._cursor._rows = None
		self._cursor.connection._result = None

	@staticmethod
	def escape(s, percent=True):
		"""Escape quotes and percent in given string."""
		# Update: We've scrapped PyMySQL in favour of MariaDB's official Python client
		# Also, given we're promoting use of the PyPika builder via frappe.qb, the use
		# of this method should be limited.

		# pymysql expects unicode argument to escape_string with Python 3
		s = frappe.as_unicode(escape_string(frappe.as_unicode(s)), "utf-8").replace("`", "\\`")

		# NOTE separating % escape, because % escape should only be done when using LIKE operator
		# or when you use python format string to generate query that already has a %s
		# for example: sql("select name from `tabUser` where name=%s and {0}".format(conditions), something)
		# defaulting it to True, as this is the most frequent use case
		# ideally we shouldn't have to use ESCAPE and strive to pass values via the values argument of sql
		if percent:
			s = s.replace("%", "%%")

		return "'" + s + "'"

	# column type
	@staticmethod
	def is_type_number(code):
		return code == pymysql.NUMBER

	@staticmethod
	def is_type_datetime(code):
		return code == pymysql.DATETIME

	def rename_table(self, old_name: str, new_name: str) -> list | tuple:
		old_name = get_table_name(old_name)
		new_name = get_table_name(new_name)
		return self.sql(f"RENAME TABLE `{old_name}` TO `{new_name}`")

	def describe(self, doctype: str) -> list | tuple:
		table_name = get_table_name(doctype)
		return self.sql(f"DESC `{table_name}`")

	def change_column_type(
		self, doctype: str, column: str, type: str, nullable: bool = False
	) -> list | tuple:
		table_name = get_table_name(doctype)
		null_constraint = "NOT NULL" if not nullable else ""
		return self.sql_ddl(f"ALTER TABLE `{table_name}` MODIFY `{column}` {type} {null_constraint}")

	def rename_column(self, doctype: str, old_column_name, new_column_name):
		current_data_type = self.get_column_type(doctype, old_column_name)

		table_name = get_table_name(doctype)

		frappe.db.sql_ddl(
			f"""ALTER TABLE `{table_name}`
				CHANGE COLUMN `{old_column_name}`
				`{new_column_name}`
				{current_data_type}"""
			# ^ MariaDB requires passing the current data type again even if there's no change
			# This requirement is gone from v10.5
		)

	def create_auth_table(self):
		self.sql_ddl(
			"""create table if not exists `__Auth` (
				`doctype` VARCHAR(140) NOT NULL,
				`name` VARCHAR(255) NOT NULL,
				`fieldname` VARCHAR(140) NOT NULL,
				`password` TEXT NOT NULL,
				`encrypted` INT(1) NOT NULL DEFAULT 0,
				PRIMARY KEY (`doctype`, `name`, `fieldname`)
			) ENGINE=InnoDB ROW_FORMAT=DYNAMIC CHARACTER SET=utf8mb4 COLLATE=utf8mb4_unicode_ci"""
		)

	def create_global_search_table(self):
		if "__global_search" not in self.get_tables():
			self.sql(
				"""create table __global_search(
				doctype varchar(100),
				name varchar({0}),
				title varchar({0}),
				content text,
				fulltext(content),
				route varchar({0}),
				published int(1) not null default 0,
				unique `doctype_name` (doctype, name))
				COLLATE=utf8mb4_unicode_ci
				ENGINE=MyISAM
				CHARACTER SET=utf8mb4""".format(
					self.VARCHAR_LEN
				)
			)

	def create_user_settings_table(self):
		self.sql_ddl(
			"""create table if not exists __UserSettings (
			`user` VARCHAR(180) NOT NULL,
			`doctype` VARCHAR(180) NOT NULL,
			`data` TEXT,
			UNIQUE(user, doctype)
			) ENGINE=InnoDB DEFAULT CHARSET=utf8"""
		)

	@staticmethod
	def get_on_duplicate_update():
		return "ON DUPLICATE key UPDATE "

	def get_table_columns_description(self, table_name):
		"""Return list of columns with descriptions."""
		return self.sql(
			"""select
			column_name as 'name',
			column_type as 'type',
			column_default as 'default',
			COALESCE(
				(select 1
				from information_schema.statistics
				where table_name="{table_name}"
					and column_name=columns.column_name
					and NON_UNIQUE=1
					and Seq_in_index = 1
					limit 1
			), 0) as 'index',
			column_key = 'UNI' as 'unique',
			(is_nullable = 'NO') AS 'not_nullable'
			from information_schema.columns as columns
			where table_name = '{table_name}' """.format(
				table_name=table_name
			),
			as_dict=1,
		)

	def get_column_type(self, doctype, column):
		"""Return column type from database."""
		information_schema = frappe.qb.Schema("information_schema")
		table = get_table_name(doctype)

		return (
			frappe.qb.from_(information_schema.columns)
			.select(information_schema.columns.column_type)
			.where(
				(information_schema.columns.table_name == table)
				& (information_schema.columns.column_name == column)
			)
			.run(pluck=True)[0]
		)

	def has_index(self, table_name, index_name):
		return self.sql(
			"""SHOW INDEX FROM `{table_name}`
			WHERE Key_name='{index_name}'""".format(
				table_name=table_name, index_name=index_name
			)
		)

	def get_column_index(
		self, table_name: str, fieldname: str, unique: bool = False
	) -> frappe._dict | None:
		"""Check if an index exists on the specified field in the specified position.

		This differs from db.has_index because it doesn't rely on the index name but on the
		columns inside an index.
		"""

		indexes = self.sql(
			f"""SHOW INDEX FROM `{table_name}`
				WHERE Column_name = "{fieldname}"
					AND Seq_in_index = 1
					AND Non_unique={int(not unique)}
				""",
			as_dict=True,
		)

		# The same column can be the first field of a composite index which contains more fields.
		# We don't want those.
		for index in indexes:
			clustered_index = self.sql(
				f"""SHOW INDEX FROM `{table_name}`
					WHERE Key_name = "{index.Key_name}"
						AND Seq_in_index = 2
					""",
				as_dict=True,
			)
			if not clustered_index:
				return index

	def add_index(self, doctype: str, fields: list, index_name: str = None):
		"""Creates an index with given fields if not already created.
		Index name will be `fieldname1_fieldname2_index`"""
		index_name = index_name or self.get_index_name(fields)
		table_name = get_table_name(doctype)
		if not self.has_index(table_name, index_name):
			self.commit()
			self.sql(
				"""ALTER TABLE `%s`
				ADD INDEX `%s`(%s)"""
				% (table_name, index_name, ", ".join(fields))
			)

	def add_unique(self, doctype, fields, constraint_name=None):
		if isinstance(fields, str):
			fields = [fields]
		if not constraint_name:
			constraint_name = "unique_" + "_".join(fields)

		if not self.sql(
			"""select CONSTRAINT_NAME from information_schema.TABLE_CONSTRAINTS
			where table_name=%s and constraint_type='UNIQUE' and CONSTRAINT_NAME=%s""",
			("tab" + doctype, constraint_name),
		):
			self.commit()
			self.sql(
				"""alter table `tab%s`
				add unique `%s`(%s)"""
				% (doctype, constraint_name, ", ".join(fields))
			)

	def updatedb(self, doctype, meta=None):
		"""
		Syncs a `DocType` to the table
		* creates if required
		* updates columns
		* updates indices
		"""
		res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
		if not res:
			raise Exception(f"Wrong doctype {doctype} in updatedb")

		if not res[0][0]:
			db_table = MariaDBTable(doctype, meta)
			db_table.validate()

			self.commit()
			db_table.sync()
			self.commit()

	def get_database_list(self):
		return self.sql("SHOW DATABASES", pluck=True)

	def get_tables(self, cached=True):
		"""Return list of tables."""
		to_query = not cached

		if cached:
			tables = frappe.cache.get_value("db_tables")
			to_query = not tables

		if to_query:
			information_schema = frappe.qb.Schema("information_schema")

			tables = (
				frappe.qb.from_(information_schema.tables)
				.select(information_schema.tables.table_name)
				.where(information_schema.tables.table_schema != "information_schema")
				.run(pluck=True)
			)
			frappe.cache.set_value("db_tables", tables)

		return tables

	def get_row_size(self, doctype: str) -> int:
		"""Get estimated max row size of any table in bytes."""

		# Query reused from this answer: https://dba.stackexchange.com/a/313889/274503
		# Modification: get values for particular table instead of full summary.
		# Reference: https://mariadb.com/kb/en/data-type-storage-requirements/

		est_row_size = frappe.db.sql(
			"""
			SELECT SUM(col_sizes.col_size) AS EST_MAX_ROW_SIZE
			FROM (
				SELECT
					cols.COLUMN_NAME,
					CASE cols.DATA_TYPE
						WHEN 'tinyint' THEN 1
						WHEN 'smallint' THEN 2
						WHEN 'mediumint' THEN 3
						WHEN 'int' THEN 4
						WHEN 'bigint' THEN 8
						WHEN 'float' THEN IF(cols.NUMERIC_PRECISION > 24, 8, 4)
						WHEN 'double' THEN 8
						WHEN 'decimal' THEN ((cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE) DIV 9)*4 + (cols.NUMERIC_SCALE DIV 9)*4 + CEIL(MOD(cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE,9)/2) + CEIL(MOD(cols.NUMERIC_SCALE,9)/2)
						WHEN 'bit' THEN (cols.NUMERIC_PRECISION + 7) DIV 8
						WHEN 'year' THEN 1
						WHEN 'date' THEN 3
						WHEN 'time' THEN 3 + CEIL(cols.DATETIME_PRECISION /2)
						WHEN 'datetime' THEN 5 + CEIL(cols.DATETIME_PRECISION /2)
						WHEN 'timestamp' THEN 4 + CEIL(cols.DATETIME_PRECISION /2)
						WHEN 'char' THEN cols.CHARACTER_OCTET_LENGTH
						WHEN 'binary' THEN cols.CHARACTER_OCTET_LENGTH
						WHEN 'varchar' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH
						WHEN 'varbinary' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH
						WHEN 'tinyblob' THEN 9
						WHEN 'tinytext' THEN 9
						WHEN 'blob' THEN 10
						WHEN 'text' THEN 10
						WHEN 'mediumblob' THEN 11
						WHEN 'mediumtext' THEN 11
						WHEN 'longblob' THEN 12
						WHEN 'longtext' THEN 12
						WHEN 'enum' THEN 2
						WHEN 'set' THEN 8
						ELSE 0
					END AS col_size
				FROM INFORMATION_SCHEMA.COLUMNS cols
				WHERE cols.TABLE_NAME = %s
			) AS col_sizes;""",
			(get_table_name(doctype),),
		)

		if est_row_size:
			return int(est_row_size[0][0])

	@contextmanager
	def unbuffered_cursor(self):
		from pymysql.cursors import SSCursor

		try:
			original_cursor = self._cursor
			new_cursor = self._cursor = self._conn.cursor(SSCursor)
			yield
		finally:
			self._cursor = original_cursor
			new_cursor.close()