diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 29d48831f7..3e220bb9fc 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -103,7 +103,7 @@ runs: sudo apt -qq update sudo apt -qq remove mysql-server mysql-client - sudo apt -qq install libcups2-dev redis-server mariadb-client + sudo apt -qq install libcups2-dev redis-server mariadb-client libmariadb-dev wget -q -O /tmp/wkhtmltox.deb https://github.com/wkhtmltopdf/packaging/releases/download/0.12.6.1-2/wkhtmltox_0.12.6.1-2.jammy_amd64.deb sudo apt install /tmp/wkhtmltox.deb diff --git a/.github/helper/db/mariadb.json b/.github/helper/db/mariadb.json index 4b1f98e78f..e6bf518d77 100644 --- a/.github/helper/db/mariadb.json +++ b/.github/helper/db/mariadb.json @@ -14,6 +14,7 @@ "root_login": "root", "root_password": "db_root", "host_name": "http://test_site:8000", + "use_mysqlclient": 1, "monitor": 1, "server_script_enabled": true -} \ No newline at end of file +} diff --git a/frappe/__init__.py b/frappe/__init__.py index bd8068741b..f323f58f39 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -66,7 +66,8 @@ if TYPE_CHECKING: # pragma: no cover from werkzeug.wrappers import Request - from frappe.database.mariadb.database import MariaDBDatabase + from frappe.database.mariadb.database import MariaDBDatabase as PyMariaDBDatabase + from frappe.database.mariadb.mysqlclient import MariaDBDatabase from frappe.database.postgres.database import PostgresDatabase from frappe.email.doctype.email_queue.email_queue import EmailQueue from frappe.model.document import Document @@ -160,7 +161,7 @@ ResponseDict: TypeAlias = _dict[str, Any] # type: ignore[no-any-explicit] FlagsDict: TypeAlias = _dict[str, Any] # type: ignore[no-any-explicit] FormDict: TypeAlias = _dict[str, str] -db: LocalProxy[Union["MariaDBDatabase", "PostgresDatabase"]] = local("db") +db: LocalProxy[Union["PyMariaDBDatabase", "MariaDBDatabase", "PostgresDatabase"]] = local("db") qb: LocalProxy[Union["MariaDB", "Postgres"]] = local("qb") conf: LocalProxy[ConfType] = local("conf") form_dict: LocalProxy[FormDict] = local("form_dict") @@ -181,7 +182,7 @@ lang: LocalProxy[str] = local("lang") if TYPE_CHECKING: # pragma: no cover # trick because some type checkers fail to follow "RedisWrapper", etc (written as string literal) # trough a generic wrapper; seems to be a bug - db: MariaDBDatabase | PostgresDatabase + db: PyMariaDBDatabase | MariaDBDatabase | PostgresDatabase qb: MariaDB | Postgres conf: ConfType form_dict: FormDict @@ -287,6 +288,7 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us "Instead, explicitly invoke frappe.init(site) prior to calling frappe.connect(), if initializing the site is necessary.", ) init(site) + if db_name: from frappe.deprecation_dumpster import deprecation_warning @@ -297,18 +299,24 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us "Instead, explicitly invoke frappe.init(site) with the right config prior to calling frappe.connect(), if necessary.", ) - assert db_name or local.conf.db_user, "site must be fully initialized, db_user missing" - assert db_name or local.conf.db_name, "site must be fully initialized, db_name missing" - assert local.conf.db_password, "site must be fully initialized, db_password missing" + conf = local.conf + db_user = conf.db_user or db_name + db_name_ = conf.db_name or db_name + db_password = conf.db_password + + assert db_user, "site must be fully initialized, db_user missing" + assert db_name_, "site must be fully initialized, db_name missing" + assert db_password, "site must be fully initialized, db_password missing" local.db = get_db( - socket=local.conf.db_socket, - host=local.conf.db_host, - port=local.conf.db_port, - user=local.conf.db_user or db_name, - password=local.conf.db_password, - cur_db_name=local.conf.db_name or db_name, + socket=conf.db_socket, + host=conf.db_host, + port=conf.db_port, + user=db_user, + password=db_password, + cur_db_name=db_name_, ) + if set_admin_as_user: set_user("Administrator") diff --git a/frappe/database/__init__.py b/frappe/database/__init__.py index cf911f11e2..77cb050f7a 100644 --- a/frappe/database/__init__.py +++ b/frappe/database/__init__.py @@ -50,12 +50,20 @@ def drop_user_and_database(db_name, db_user): def get_db(socket=None, host=None, user=None, password=None, port=None, cur_db_name=None): import frappe - if frappe.conf.db_type == "postgres": + conf = frappe.local.conf + + if conf.db_type == "postgres": import frappe.database.postgres.database return frappe.database.postgres.database.PostgresDatabase( socket, host, user, password, port, cur_db_name ) + elif conf.use_mysqlclient: + import frappe.database.mariadb.mysqlclient + + return frappe.database.mariadb.mysqlclient.MariaDBDatabase( + socket, host, user, password, port, cur_db_name + ) else: import frappe.database.mariadb.database diff --git a/frappe/database/database.py b/frappe/database/database.py index 3a62e7fa14..5fbf9d5400 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -37,6 +37,8 @@ from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, ge from frappe.utils import cast as cast_fieldtype if TYPE_CHECKING: + from MySQLdb.connections import Connection as MySQLdbConnection + from MySQLdb.cursors import Cursor as MySQLdbCursor from psycopg2 import connection as PostgresConnection from psycopg2 import cursor as PostgresCursor from pymysql.connections import Connection as MariadbConnection @@ -117,8 +119,8 @@ class Database: def connect(self): """Connects to a database as set in `site_config.json`.""" - self._conn: MariadbConnection | PostgresConnection = self.get_connection() - self._cursor: MariadbCursor | PostgresCursor = self._conn.cursor() + self._conn: MySQLdbConnection | MariadbConnection | PostgresConnection = self.get_connection() + self._cursor: MySQLdbCursor | MariadbCursor | PostgresCursor = self._conn.cursor() try: if execution_timeout := get_query_execution_timeout(): @@ -202,9 +204,15 @@ class Database: debug = debug or getattr(self, "debug", False) query = str(query) + if not run: return query + if explain: + if debug and is_query_type(query, "select"): + self.explain_query(query, values) + return + # remove whitespace / indentation from start and end of query query = query.strip() @@ -277,7 +285,7 @@ class Database: time_end = time() frappe.log(f"Execution time: {time_end - time_start:.2f} sec") - self.log_query(query, values, debug, explain) + self.log_query(query, values, debug) if auto_commit: self.commit() @@ -333,7 +341,6 @@ class Database: self, mogrified_query: str, debug: bool = False, - explain: bool = False, unmogrified_query: str = "", ) -> None: """Takes the query and logs it to various interfaces according to the settings.""" @@ -346,8 +353,6 @@ class Database: if debug: _query = _query or str(mogrified_query) - if explain and is_query_type(_query, "select"): - self.explain_query(_query) frappe.log(_query) if conf.logging == 2: @@ -364,14 +369,9 @@ class Database: _query = _query or str(mogrified_query) self.log_touched_tables(_query) - def log_query( - self, query: str, values: QueryValues = None, debug: bool = False, explain: bool = False - ) -> str: - # TODO: Use mogrify until MariaDB Connector/C 1.1 is released and we can fetch something - # like cursor._transformed_statement from the cursor object. We can also avoid setting - # mogrified_query if we don't need to log it. + def log_query(self, query: str, values: QueryValues = None, debug: bool = False) -> str: mogrified_query = self.lazy_mogrify(query, values) - self._log_query(mogrified_query, debug, explain, unmogrified_query=query) + self._log_query(mogrified_query, debug, query) return mogrified_query def mogrify(self, query: Query, values: QueryValues): diff --git a/frappe/database/mariadb/database.py b/frappe/database/mariadb/database.py index af9bce1e88..614df08399 100644 --- a/frappe/database/mariadb/database.py +++ b/frappe/database/mariadb/database.py @@ -1,4 +1,3 @@ -import re from contextlib import contextmanager import pymysql @@ -10,8 +9,6 @@ from frappe.database.database import Database from frappe.database.mariadb.schema import MariaDBTable from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name -_PARAM_COMP = re.compile(r"%\([\w]*\)s") - class MariaDBExceptionUtil: ProgrammingError = pymysql.ProgrammingError @@ -214,10 +211,11 @@ class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database): return db_size[0].get("database_size") - def log_query(self, query, values, debug, explain): - self.last_query = self._cursor._executed - self._log_query(self.last_query, debug, explain, query) - return self.last_query + def log_query(self, query, values, debug): + mogrified_query = self._cursor._executed + self.last_query = mogrified_query + self._log_query(mogrified_query, debug, query) + return mogrified_query def _clean_up(self): # PERF: Erase internal references of pymysql to trigger GC as soon as diff --git a/frappe/database/mariadb/mysqlclient.py b/frappe/database/mariadb/mysqlclient.py new file mode 100644 index 0000000000..0f132822b1 --- /dev/null +++ b/frappe/database/mariadb/mysqlclient.py @@ -0,0 +1,592 @@ +import datetime +from contextlib import contextmanager + +import MySQLdb +from MySQLdb._mysql import escape_string +from MySQLdb.constants import ER, FIELD_TYPE +from MySQLdb.converters import conversions + +import frappe +from frappe.database.database import Database +from frappe.database.mariadb.schema import MariaDBTable +from frappe.utils import get_datetime, get_table_name + +ER_STATEMENT_TIMEOUT = 1969 + + +class MariaDBExceptionUtil: + ProgrammingError = MySQLdb.ProgrammingError + TableMissingError = MySQLdb.ProgrammingError + OperationalError = MySQLdb.OperationalError + InternalError = MySQLdb.InternalError + SQLError = MySQLdb.ProgrammingError + DataError = MySQLdb.DataError + + # match SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/ + SequenceGeneratorLimitExceeded = MySQLdb.OperationalError + + @staticmethod + def is_deadlocked(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.LOCK_DEADLOCK + + @staticmethod + def is_timedout(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.LOCK_WAIT_TIMEOUT + + @staticmethod + def is_read_only_mode_error(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + + @staticmethod + def is_table_missing(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.NO_SUCH_TABLE + + @staticmethod + def is_missing_table(e: MySQLdb.Error) -> bool: + return MariaDBDatabase.is_table_missing(e) + + @staticmethod + def is_missing_column(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.BAD_FIELD_ERROR + + @staticmethod + def is_duplicate_fieldname(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.DUP_FIELDNAME + + @staticmethod + def is_duplicate_entry(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.DUP_ENTRY + + @staticmethod + def is_access_denied(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.ACCESS_DENIED_ERROR + + @staticmethod + def cant_drop_field_or_key(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY + + @staticmethod + def is_syntax_error(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.PARSE_ERROR + + @staticmethod + def is_statement_timeout(e: MySQLdb.Error) -> bool: + return e.args[0] == ER_STATEMENT_TIMEOUT + + @staticmethod + def is_data_too_long(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.DATA_TOO_LONG + + @staticmethod + def is_db_table_size_limit(e: MySQLdb.Error) -> bool: + return e.args[0] == ER.TOO_BIG_ROWSIZE + + @staticmethod + def is_primary_key_violation(e: MySQLdb.Error) -> bool: + return ( + isinstance(e, MySQLdb.IntegrityError) + and MariaDBExceptionUtil.is_duplicate_entry(e) + and "PRIMARY" in e.args[1] + ) + + @staticmethod + def is_unique_key_violation(e: MySQLdb.Error) -> bool: + return ( + isinstance(e, MySQLdb.IntegrityError) + and MariaDBExceptionUtil.is_duplicate_entry(e) + and "Duplicate" in e.args[1] + ) + + @staticmethod + def is_interface_error(e: MySQLdb.Error): + return isinstance(e, MySQLdb.InterfaceError) + + +class MariaDBConnectionUtil: + def get_connection(self): + conn = self._get_connection() + conn.auto_reconnect = True + return conn + + def _get_connection(self) -> "MySQLdb.Connection": + return self.create_connection() + + def create_connection(self): + return MySQLdb.connect(**self.get_connection_settings()) + + def set_execution_timeout(self, seconds: int): + self.sql("set session max_statement_time = %s", int(seconds)) + + def get_connection_settings(self) -> dict: + conn_settings = { + "user": self.user, + "conv": self.CONVERSION_MAP, + "charset": "utf8mb4", + "use_unicode": True, + } + + if self.cur_db_name: + conn_settings["database"] = self.cur_db_name + + if self.socket: + conn_settings["unix_socket"] = self.socket + else: + conn_settings["host"] = self.host + if self.port: + conn_settings["port"] = int(self.port) + + if self.password: + conn_settings["password"] = self.password + + if frappe.conf.local_infile: + conn_settings["local_infile"] = frappe.conf.local_infile + + if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key: + conn_settings["ssl"] = { + "ca": frappe.conf.db_ssl_ca, + "cert": frappe.conf.db_ssl_cert, + "key": frappe.conf.db_ssl_key, + } + + return conn_settings + + +### Converters + + +def escape_frozenset(obj, mapping=None): + return frappe.local.db._conn.literal(tuple(obj)) + + +# adapted from pymysql +def escape_timedelta(obj, mapping=None): + _seconds = obj.seconds + + if obj.microseconds: + fmt = "'{0:02d}:{1:02d}:{2:02d}.{3:06d}'" + else: + fmt = "'{0:02d}:{1:02d}:{2:02d}'" + return fmt.format( + (_seconds // 3600) % 24 + obj.days * 24, # hours + (_seconds // 60) % 60, # minutes + _seconds % 60, # seconds + obj.microseconds, # microseconds + ) + + +# adapted from pymysql +def escape_dict(obj, mapping=None): + raise TypeError("dict can not be used as parameter") + + +class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database): + REGEX_CHARACTER = "regexp" + CONVERSION_MAP = conversions | { + FIELD_TYPE.NEWDECIMAL: float, + FIELD_TYPE.DATETIME: get_datetime, + dict: escape_dict, + frozenset: escape_frozenset, + datetime.timedelta: escape_timedelta, # not handled as desired by MySQLdb + # no need to specify UnicodeWithAttrs, as it subclasses str - which is handled + } + default_port = "3306" + MAX_ROW_SIZE_LIMIT = 65_535 # bytes + + def setup_type_map(self): + self.db_type = "mariadb" + self.type_map = { + "Currency": ("decimal", "21,9"), + "Int": ("int", None), + "Long Int": ("bigint", "20"), + "Float": ("decimal", "21,9"), + "Percent": ("decimal", "21,9"), + "Check": ("tinyint", None), + "Small Text": ("text", ""), + "Long Text": ("longtext", ""), + "Code": ("longtext", ""), + "Text Editor": ("longtext", ""), + "Markdown Editor": ("longtext", ""), + "HTML Editor": ("longtext", ""), + "Date": ("date", ""), + "Datetime": ("datetime", "6"), + "Time": ("time", "6"), + "Text": ("text", ""), + "Data": ("varchar", self.VARCHAR_LEN), + "Link": ("varchar", self.VARCHAR_LEN), + "Dynamic Link": ("varchar", self.VARCHAR_LEN), + "Password": ("text", ""), + "Select": ("varchar", self.VARCHAR_LEN), + "Rating": ("decimal", "3,2"), + "Read Only": ("varchar", self.VARCHAR_LEN), + "Attach": ("text", ""), + "Attach Image": ("text", ""), + "Signature": ("longtext", ""), + "Color": ("varchar", self.VARCHAR_LEN), + "Barcode": ("longtext", ""), + "Geolocation": ("longtext", ""), + "Duration": ("decimal", "21,9"), + "Icon": ("varchar", self.VARCHAR_LEN), + "Phone": ("varchar", self.VARCHAR_LEN), + "Autocomplete": ("varchar", self.VARCHAR_LEN), + "JSON": ("json", ""), + } + + def get_database_size(self): + """Return database size in MB.""" + db_size = self.sql( + """ + SELECT `table_schema` as `database_name`, + SUM(`data_length` + `index_length`) / 1024 / 1024 AS `database_size` + FROM information_schema.tables WHERE `table_schema` = %s GROUP BY `table_schema` + """, + self.cur_db_name, + as_dict=True, + ) + + return db_size[0].get("database_size") + + def log_query(self, query, values, debug): + mogrified_query = self._cursor._executed.decode() + self.last_query = mogrified_query + self._log_query(mogrified_query, debug, query) + return mogrified_query + + def _clean_up(self): + # PERF: Erase internal references to trigger GC as soon as + # results are consumed. + self._cursor._rows = None + + @staticmethod + def escape(s, percent=True): + """Escape quotes and percent in given string.""" + + s = frappe.as_unicode(escape_string(s)).replace("`", "\\`") if s else "" + + # NOTE separating % escape, because % escape should only be done when using LIKE operator + # or when you use python format string to generate query that already has a %s + # for example: sql("select name from `tabUser` where name=%s and {0}".format(conditions), something) + # defaulting it to True, as this is the most frequent use case + # ideally we shouldn't have to use ESCAPE and strive to pass values via the values argument of sql + if percent: + s = s.replace("%", "%%") + + return "'" + s + "'" + + # column type + @staticmethod + def is_type_number(code): + return code == MySQLdb.NUMBER + + @staticmethod + def is_type_datetime(code): + return code == MySQLdb.DATETIME + + def rename_table(self, old_name: str, new_name: str) -> list | tuple: + old_name = get_table_name(old_name) + new_name = get_table_name(new_name) + return self.sql(f"RENAME TABLE `{old_name}` TO `{new_name}`") + + def describe(self, doctype: str) -> list | tuple: + table_name = get_table_name(doctype) + return self.sql(f"DESC `{table_name}`") + + def change_column_type( + self, doctype: str, column: str, type: str, nullable: bool = False + ) -> list | tuple: + table_name = get_table_name(doctype) + null_constraint = "NOT NULL" if not nullable else "" + return self.sql_ddl(f"ALTER TABLE `{table_name}` MODIFY `{column}` {type} {null_constraint}") + + def rename_column(self, doctype: str, old_column_name, new_column_name): + current_data_type = self.get_column_type(doctype, old_column_name) + + table_name = get_table_name(doctype) + + frappe.db.sql_ddl( + f"""ALTER TABLE `{table_name}` + CHANGE COLUMN `{old_column_name}` + `{new_column_name}` + {current_data_type}""" + # ^ Mariadb requires passing current data type again even if there's no change + # This requirement is gone from v10.5 + ) + + def create_auth_table(self): + self.sql_ddl( + """create table if not exists `__Auth` ( + `doctype` VARCHAR(140) NOT NULL, + `name` VARCHAR(255) NOT NULL, + `fieldname` VARCHAR(140) NOT NULL, + `password` TEXT NOT NULL, + `encrypted` TINYINT NOT NULL DEFAULT 0, + PRIMARY KEY (`doctype`, `name`, `fieldname`) + ) ENGINE=InnoDB ROW_FORMAT=DYNAMIC CHARACTER SET=utf8mb4 COLLATE=utf8mb4_unicode_ci""" + ) + + def create_global_search_table(self): + if "__global_search" not in self.get_tables(): + self.sql( + f"""create table __global_search( + doctype varchar(100), + name varchar({self.VARCHAR_LEN}), + title varchar({self.VARCHAR_LEN}), + content text, + fulltext(content), + route varchar({self.VARCHAR_LEN}), + published TINYINT not null default 0, + unique `doctype_name` (doctype, name)) + COLLATE=utf8mb4_unicode_ci + ENGINE=MyISAM + CHARACTER SET=utf8mb4""" + ) + + def create_user_settings_table(self): + self.sql_ddl( + """create table if not exists __UserSettings ( + `user` VARCHAR(180) NOT NULL, + `doctype` VARCHAR(180) NOT NULL, + `data` TEXT, + UNIQUE(user, doctype) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8""" + ) + + @staticmethod + def get_on_duplicate_update(): + return "ON DUPLICATE key UPDATE " + + def get_table_columns_description(self, table_name): + """Return list of columns with descriptions.""" + return self.sql( + f"""select + column_name as 'name', + column_type as 'type', + column_default as 'default', + COALESCE( + (select 1 + from information_schema.statistics + where table_name="{table_name}" + and column_name=columns.column_name + and NON_UNIQUE=1 + and Seq_in_index = 1 + limit 1 + ), 0) as 'index', + column_key = 'UNI' as 'unique', + (is_nullable = 'NO') AS 'not_nullable' + from information_schema.columns as columns + where table_name = '{table_name}' + and table_schema = '{frappe.db.cur_db_name}' """, + as_dict=1, + ) + + def get_column_type(self, doctype, column): + """Return column type from database.""" + information_schema = frappe.qb.Schema("information_schema") + table = get_table_name(doctype) + + return ( + frappe.qb.from_(information_schema.columns) + .select(information_schema.columns.column_type) + .where( + (information_schema.columns.table_name == table) + & (information_schema.columns.column_name == column) + & (information_schema.columns.table_schema == self.cur_db_name) + ) + .run(pluck=True)[0] + ) + + def has_index(self, table_name, index_name): + return self.sql( + f"""SHOW INDEX FROM `{table_name}` + WHERE Key_name='{index_name}'""" + ) + + def get_column_index(self, table_name: str, fieldname: str, unique: bool = False) -> frappe._dict | None: + """Check if column exists for a specific fields in specified order. + + This differs from db.has_index because it doesn't rely on index name but columns inside an + index. + """ + + indexes = self.sql( + f"""SHOW INDEX FROM `{table_name}` + WHERE Column_name = "{fieldname}" + AND Seq_in_index = 1 + AND Non_unique={int(not unique)} + AND Index_type != 'FULLTEXT' + """, + as_dict=True, + ) + + # Same index can be part of clustered index which contains more fields + # We don't want those. + for index in indexes: + clustered_index = self.sql( + f"""SHOW INDEX FROM `{table_name}` + WHERE Key_name = "{index.Key_name}" + AND Seq_in_index = 2 + """, + as_dict=True, + ) + if not clustered_index: + return index + + def add_index(self, doctype: str, fields: list, index_name: str | None = None): + """Creates an index with given fields if not already created. + Index name will be `fieldname1_fieldname2_index`""" + from frappe.custom.doctype.property_setter.property_setter import make_property_setter + + index_name = index_name or self.get_index_name(fields) + table_name = get_table_name(doctype) + if not self.has_index(table_name, index_name): + self.commit() + self.sql( + """ALTER TABLE `{}` + ADD INDEX IF NOT EXISTS `{}`({})""".format(table_name, index_name, ", ".join(fields)) + ) + # Ensure that DB migration doesn't clear this index, assuming this is manually added + # via code or console. + if len(fields) == 1 and not (frappe.flags.in_install or frappe.flags.in_migrate): + make_property_setter( + doctype, + fields[0], + property="search_index", + value="1", + property_type="Check", + for_doctype=False, # Applied on docfield + ) + + def add_unique(self, doctype, fields, constraint_name=None): + if isinstance(fields, str): + fields = [fields] + if not constraint_name: + constraint_name = "unique_" + "_".join(fields) + + if not self.sql( + """select CONSTRAINT_NAME from information_schema.TABLE_CONSTRAINTS + where table_name=%s and constraint_type='UNIQUE' and CONSTRAINT_NAME=%s""", + ("tab" + doctype, constraint_name), + ): + self.commit() + self.sql( + """alter table `tab{}` + add unique `{}`({})""".format(doctype, constraint_name, ", ".join(fields)) + ) + + def updatedb(self, doctype, meta=None): + """ + Syncs a `DocType` to the table + * creates if required + * updates columns + * updates indices + """ + res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,)) + if not res: + raise Exception(f"Wrong doctype {doctype} in updatedb") + + if not res[0][0]: + db_table = MariaDBTable(doctype, meta) + db_table.validate() + + db_table.sync() + self.commit() + + def get_database_list(self): + return self.sql("SHOW DATABASES", pluck=True) + + def get_tables(self, cached=True): + """Return list of tables.""" + to_query = not cached + + if cached: + tables = frappe.client_cache.get_value("db_tables") + to_query = not tables + + if to_query: + information_schema = frappe.qb.Schema("information_schema") + + tables = ( + frappe.qb.from_(information_schema.tables) + .select(information_schema.tables.table_name) + .where(information_schema.tables.table_schema == frappe.db.cur_db_name) + .run(pluck=True) + ) + frappe.client_cache.set_value("db_tables", tables) + + return tables + + def get_row_size(self, doctype: str) -> int: + """Get estimated max row size of any table in bytes.""" + + # Query reused from this answer: https://dba.stackexchange.com/a/313889/274503 + # Modification: get values for particular table instead of full summary. + # Reference: https://mariadb.com/kb/en/data-type-storage-requirements/ + + est_row_size = frappe.db.sql( + """ + SELECT SUM(col_sizes.col_size) AS EST_MAX_ROW_SIZE + FROM ( + SELECT + cols.COLUMN_NAME, + CASE cols.DATA_TYPE + WHEN 'tinyint' THEN 1 + WHEN 'smallint' THEN 2 + WHEN 'mediumint' THEN 3 + WHEN 'int' THEN 4 + WHEN 'bigint' THEN 8 + WHEN 'float' THEN IF(cols.NUMERIC_PRECISION > 24, 8, 4) + WHEN 'double' THEN 8 + WHEN 'decimal' THEN ((cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE) DIV 9)*4 + (cols.NUMERIC_SCALE DIV 9)*4 + CEIL(MOD(cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE,9)/2) + CEIL(MOD(cols.NUMERIC_SCALE,9)/2) + WHEN 'bit' THEN (cols.NUMERIC_PRECISION + 7) DIV 8 + WHEN 'year' THEN 1 + WHEN 'date' THEN 3 + WHEN 'time' THEN 3 + CEIL(cols.DATETIME_PRECISION /2) + WHEN 'datetime' THEN 5 + CEIL(cols.DATETIME_PRECISION /2) + WHEN 'timestamp' THEN 4 + CEIL(cols.DATETIME_PRECISION /2) + WHEN 'char' THEN cols.CHARACTER_OCTET_LENGTH + WHEN 'binary' THEN cols.CHARACTER_OCTET_LENGTH + WHEN 'varchar' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH + WHEN 'varbinary' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH + WHEN 'tinyblob' THEN 9 + WHEN 'tinytext' THEN 9 + WHEN 'blob' THEN 10 + WHEN 'text' THEN 10 + WHEN 'mediumblob' THEN 11 + WHEN 'mediumtext' THEN 11 + WHEN 'longblob' THEN 12 + WHEN 'longtext' THEN 12 + WHEN 'enum' THEN 2 + WHEN 'set' THEN 8 + ELSE 0 + END AS col_size + FROM INFORMATION_SCHEMA.COLUMNS cols + WHERE cols.TABLE_NAME = %s + ) AS col_sizes;""", + (get_table_name(doctype),), + ) + + if est_row_size: + return int(est_row_size[0][0]) + + @contextmanager + def unbuffered_cursor(self): + from MySQLdb.cursors import SSCursor + + try: + if not self._conn: + self.connect() + + original_cursor = self._cursor + new_cursor = self._cursor = self._conn.cursor(SSCursor) + yield + finally: + self._cursor = original_cursor + new_cursor.close() + + def estimate_count(self, doctype: str): + """Get estimated count of total rows in a table.""" + from frappe.utils.data import cint + + table = get_table_name(doctype) + + count = self.sql("select table_rows from information_schema.tables where table_name = %s", table) + return cint(count[0][0]) if count else 0 diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 0401fae25f..dade5e2298 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1382,21 +1382,17 @@ class TestDbConnectWithEnvCredentials(IntegrationTestCase): frappe.init(self.current_site, force=True) frappe.connect() - with self.assertRaises(Exception) as cm: + with self.assertRaises(frappe.db.OperationalError) as cm: frappe.db.connect() - self.assertTrue(re.search(r"(host name|server on) [\"']iqx.local[\"']", str(cm.exception))) - # with wrong user name with set_env_variable("FRAPPE_DB_USER", "uname"): frappe.init(self.current_site, force=True) frappe.connect() - with self.assertRaises(Exception) as cm: + with self.assertRaises(frappe.db.OperationalError) as cm: frappe.db.connect() - self.assertTrue(re.search(r"user [\"']uname[\"']", str(cm.exception))) - # with wrong password with set_env_variable("FRAPPE_DB_PASSWORD", "pass"): frappe.init(self.current_site, force=True) @@ -1414,11 +1410,9 @@ class TestDbConnectWithEnvCredentials(IntegrationTestCase): frappe.init(self.current_site, force=True) frappe.connect() - with self.assertRaises(Exception) as cm: + with self.assertRaises(frappe.db.OperationalError) as cm: frappe.db.connect() - self.assertTrue(re.search("(port 1111 failed|Errno 111)", str(cm.exception))) - # now with configured settings without any influences from env # finally connect should work without any error (when no wrong credentials are given via ENV) frappe.init(self.current_site, force=True) diff --git a/pyproject.toml b/pyproject.toml index c002e170a8..103252d872 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "PyMySQL==1.1.1", "pypdf~=3.17.0", "PyPika==0.48.9", + "mysqlclient~=2.2.7", "PyQRCode~=1.2.1", "PyYAML~=6.0.1", "RestrictedPython~=8.0", @@ -134,7 +135,6 @@ test = [ "Faker~=18.10.1", "hypothesis~=6.77.0", "freezegun~=1.5.1", - "pdbpp~=0.10.3", ] [build-system]