feat: re-introduce mysqlclient 🚀 (#31719)
* feat: mysqlclient * fix: update error attrs * fix: decode mogrified query to unicode * fix: do some cleanup * chore: disable cleanup for now * fix: remove unnecessary call to as_unicode * test: skip perf test for now * fix: fallback to empty str * fix: unbuffered cursor support * fix: update converters and other changes * fix: add cleanup back * perf: improve timedelta converter * fix: dont attempt to run query when explain flag is set * test: cleanup tests * chore: remove commented code * perf: store conf as local var * chore: ensure sequence --------- Co-authored-by: Ankush Menat <ankush@frappe.io>
This commit is contained in:
parent
0e62819741
commit
b2cab51849
9 changed files with 646 additions and 45 deletions
2
.github/actions/setup/action.yml
vendored
2
.github/actions/setup/action.yml
vendored
|
|
@ -103,7 +103,7 @@ runs:
|
|||
|
||||
sudo apt -qq update
|
||||
sudo apt -qq remove mysql-server mysql-client
|
||||
sudo apt -qq install libcups2-dev redis-server mariadb-client
|
||||
sudo apt -qq install libcups2-dev redis-server mariadb-client libmariadb-dev
|
||||
|
||||
wget -q -O /tmp/wkhtmltox.deb https://github.com/wkhtmltopdf/packaging/releases/download/0.12.6.1-2/wkhtmltox_0.12.6.1-2.jammy_amd64.deb
|
||||
sudo apt install /tmp/wkhtmltox.deb
|
||||
|
|
|
|||
3
.github/helper/db/mariadb.json
vendored
3
.github/helper/db/mariadb.json
vendored
|
|
@ -14,6 +14,7 @@
|
|||
"root_login": "root",
|
||||
"root_password": "db_root",
|
||||
"host_name": "http://test_site:8000",
|
||||
"use_mysqlclient": 1,
|
||||
"monitor": 1,
|
||||
"server_script_enabled": true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ if TYPE_CHECKING: # pragma: no cover
|
|||
|
||||
from werkzeug.wrappers import Request
|
||||
|
||||
from frappe.database.mariadb.database import MariaDBDatabase
|
||||
from frappe.database.mariadb.database import MariaDBDatabase as PyMariaDBDatabase
|
||||
from frappe.database.mariadb.mysqlclient import MariaDBDatabase
|
||||
from frappe.database.postgres.database import PostgresDatabase
|
||||
from frappe.email.doctype.email_queue.email_queue import EmailQueue
|
||||
from frappe.model.document import Document
|
||||
|
|
@ -160,7 +161,7 @@ ResponseDict: TypeAlias = _dict[str, Any] # type: ignore[no-any-explicit]
|
|||
FlagsDict: TypeAlias = _dict[str, Any] # type: ignore[no-any-explicit]
|
||||
FormDict: TypeAlias = _dict[str, str]
|
||||
|
||||
db: LocalProxy[Union["MariaDBDatabase", "PostgresDatabase"]] = local("db")
|
||||
db: LocalProxy[Union["PyMariaDBDatabase", "MariaDBDatabase", "PostgresDatabase"]] = local("db")
|
||||
qb: LocalProxy[Union["MariaDB", "Postgres"]] = local("qb")
|
||||
conf: LocalProxy[ConfType] = local("conf")
|
||||
form_dict: LocalProxy[FormDict] = local("form_dict")
|
||||
|
|
@ -181,7 +182,7 @@ lang: LocalProxy[str] = local("lang")
|
|||
if TYPE_CHECKING: # pragma: no cover
|
||||
# trick because some type checkers fail to follow "RedisWrapper", etc (written as string literal)
|
||||
# trough a generic wrapper; seems to be a bug
|
||||
db: MariaDBDatabase | PostgresDatabase
|
||||
db: PyMariaDBDatabase | MariaDBDatabase | PostgresDatabase
|
||||
qb: MariaDB | Postgres
|
||||
conf: ConfType
|
||||
form_dict: FormDict
|
||||
|
|
@ -287,6 +288,7 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us
|
|||
"Instead, explicitly invoke frappe.init(site) prior to calling frappe.connect(), if initializing the site is necessary.",
|
||||
)
|
||||
init(site)
|
||||
|
||||
if db_name:
|
||||
from frappe.deprecation_dumpster import deprecation_warning
|
||||
|
||||
|
|
@ -297,18 +299,24 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us
|
|||
"Instead, explicitly invoke frappe.init(site) with the right config prior to calling frappe.connect(), if necessary.",
|
||||
)
|
||||
|
||||
assert db_name or local.conf.db_user, "site must be fully initialized, db_user missing"
|
||||
assert db_name or local.conf.db_name, "site must be fully initialized, db_name missing"
|
||||
assert local.conf.db_password, "site must be fully initialized, db_password missing"
|
||||
conf = local.conf
|
||||
db_user = conf.db_user or db_name
|
||||
db_name_ = conf.db_name or db_name
|
||||
db_password = conf.db_password
|
||||
|
||||
assert db_user, "site must be fully initialized, db_user missing"
|
||||
assert db_name_, "site must be fully initialized, db_name missing"
|
||||
assert db_password, "site must be fully initialized, db_password missing"
|
||||
|
||||
local.db = get_db(
|
||||
socket=local.conf.db_socket,
|
||||
host=local.conf.db_host,
|
||||
port=local.conf.db_port,
|
||||
user=local.conf.db_user or db_name,
|
||||
password=local.conf.db_password,
|
||||
cur_db_name=local.conf.db_name or db_name,
|
||||
socket=conf.db_socket,
|
||||
host=conf.db_host,
|
||||
port=conf.db_port,
|
||||
user=db_user,
|
||||
password=db_password,
|
||||
cur_db_name=db_name_,
|
||||
)
|
||||
|
||||
if set_admin_as_user:
|
||||
set_user("Administrator")
|
||||
|
||||
|
|
|
|||
|
|
@ -50,12 +50,20 @@ def drop_user_and_database(db_name, db_user):
|
|||
def get_db(socket=None, host=None, user=None, password=None, port=None, cur_db_name=None):
|
||||
import frappe
|
||||
|
||||
if frappe.conf.db_type == "postgres":
|
||||
conf = frappe.local.conf
|
||||
|
||||
if conf.db_type == "postgres":
|
||||
import frappe.database.postgres.database
|
||||
|
||||
return frappe.database.postgres.database.PostgresDatabase(
|
||||
socket, host, user, password, port, cur_db_name
|
||||
)
|
||||
elif conf.use_mysqlclient:
|
||||
import frappe.database.mariadb.mysqlclient
|
||||
|
||||
return frappe.database.mariadb.mysqlclient.MariaDBDatabase(
|
||||
socket, host, user, password, port, cur_db_name
|
||||
)
|
||||
else:
|
||||
import frappe.database.mariadb.database
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, ge
|
|||
from frappe.utils import cast as cast_fieldtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from MySQLdb.connections import Connection as MySQLdbConnection
|
||||
from MySQLdb.cursors import Cursor as MySQLdbCursor
|
||||
from psycopg2 import connection as PostgresConnection
|
||||
from psycopg2 import cursor as PostgresCursor
|
||||
from pymysql.connections import Connection as MariadbConnection
|
||||
|
|
@ -117,8 +119,8 @@ class Database:
|
|||
|
||||
def connect(self):
|
||||
"""Connects to a database as set in `site_config.json`."""
|
||||
self._conn: MariadbConnection | PostgresConnection = self.get_connection()
|
||||
self._cursor: MariadbCursor | PostgresCursor = self._conn.cursor()
|
||||
self._conn: MySQLdbConnection | MariadbConnection | PostgresConnection = self.get_connection()
|
||||
self._cursor: MySQLdbCursor | MariadbCursor | PostgresCursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
if execution_timeout := get_query_execution_timeout():
|
||||
|
|
@ -202,9 +204,15 @@ class Database:
|
|||
|
||||
debug = debug or getattr(self, "debug", False)
|
||||
query = str(query)
|
||||
|
||||
if not run:
|
||||
return query
|
||||
|
||||
if explain:
|
||||
if debug and is_query_type(query, "select"):
|
||||
self.explain_query(query, values)
|
||||
return
|
||||
|
||||
# remove whitespace / indentation from start and end of query
|
||||
query = query.strip()
|
||||
|
||||
|
|
@ -277,7 +285,7 @@ class Database:
|
|||
time_end = time()
|
||||
frappe.log(f"Execution time: {time_end - time_start:.2f} sec")
|
||||
|
||||
self.log_query(query, values, debug, explain)
|
||||
self.log_query(query, values, debug)
|
||||
|
||||
if auto_commit:
|
||||
self.commit()
|
||||
|
|
@ -333,7 +341,6 @@ class Database:
|
|||
self,
|
||||
mogrified_query: str,
|
||||
debug: bool = False,
|
||||
explain: bool = False,
|
||||
unmogrified_query: str = "",
|
||||
) -> None:
|
||||
"""Takes the query and logs it to various interfaces according to the settings."""
|
||||
|
|
@ -346,8 +353,6 @@ class Database:
|
|||
|
||||
if debug:
|
||||
_query = _query or str(mogrified_query)
|
||||
if explain and is_query_type(_query, "select"):
|
||||
self.explain_query(_query)
|
||||
frappe.log(_query)
|
||||
|
||||
if conf.logging == 2:
|
||||
|
|
@ -364,14 +369,9 @@ class Database:
|
|||
_query = _query or str(mogrified_query)
|
||||
self.log_touched_tables(_query)
|
||||
|
||||
def log_query(
|
||||
self, query: str, values: QueryValues = None, debug: bool = False, explain: bool = False
|
||||
) -> str:
|
||||
# TODO: Use mogrify until MariaDB Connector/C 1.1 is released and we can fetch something
|
||||
# like cursor._transformed_statement from the cursor object. We can also avoid setting
|
||||
# mogrified_query if we don't need to log it.
|
||||
def log_query(self, query: str, values: QueryValues = None, debug: bool = False) -> str:
|
||||
mogrified_query = self.lazy_mogrify(query, values)
|
||||
self._log_query(mogrified_query, debug, explain, unmogrified_query=query)
|
||||
self._log_query(mogrified_query, debug, query)
|
||||
return mogrified_query
|
||||
|
||||
def mogrify(self, query: Query, values: QueryValues):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import re
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pymysql
|
||||
|
|
@ -10,8 +9,6 @@ from frappe.database.database import Database
|
|||
from frappe.database.mariadb.schema import MariaDBTable
|
||||
from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name
|
||||
|
||||
_PARAM_COMP = re.compile(r"%\([\w]*\)s")
|
||||
|
||||
|
||||
class MariaDBExceptionUtil:
|
||||
ProgrammingError = pymysql.ProgrammingError
|
||||
|
|
@ -214,10 +211,11 @@ class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database):
|
|||
|
||||
return db_size[0].get("database_size")
|
||||
|
||||
def log_query(self, query, values, debug, explain):
|
||||
self.last_query = self._cursor._executed
|
||||
self._log_query(self.last_query, debug, explain, query)
|
||||
return self.last_query
|
||||
def log_query(self, query, values, debug):
|
||||
mogrified_query = self._cursor._executed
|
||||
self.last_query = mogrified_query
|
||||
self._log_query(mogrified_query, debug, query)
|
||||
return mogrified_query
|
||||
|
||||
def _clean_up(self):
|
||||
# PERF: Erase internal references of pymysql to trigger GC as soon as
|
||||
|
|
|
|||
592
frappe/database/mariadb/mysqlclient.py
Normal file
592
frappe/database/mariadb/mysqlclient.py
Normal file
|
|
@ -0,0 +1,592 @@
|
|||
import datetime
|
||||
from contextlib import contextmanager
|
||||
|
||||
import MySQLdb
|
||||
from MySQLdb._mysql import escape_string
|
||||
from MySQLdb.constants import ER, FIELD_TYPE
|
||||
from MySQLdb.converters import conversions
|
||||
|
||||
import frappe
|
||||
from frappe.database.database import Database
|
||||
from frappe.database.mariadb.schema import MariaDBTable
|
||||
from frappe.utils import get_datetime, get_table_name
|
||||
|
||||
ER_STATEMENT_TIMEOUT = 1969
|
||||
|
||||
|
||||
class MariaDBExceptionUtil:
|
||||
ProgrammingError = MySQLdb.ProgrammingError
|
||||
TableMissingError = MySQLdb.ProgrammingError
|
||||
OperationalError = MySQLdb.OperationalError
|
||||
InternalError = MySQLdb.InternalError
|
||||
SQLError = MySQLdb.ProgrammingError
|
||||
DataError = MySQLdb.DataError
|
||||
|
||||
# match SEQUENCE_RUN_OUT - https://mariadb.com/kb/en/mariadb-error-codes/
|
||||
SequenceGeneratorLimitExceeded = MySQLdb.OperationalError
|
||||
|
||||
@staticmethod
|
||||
def is_deadlocked(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.LOCK_DEADLOCK
|
||||
|
||||
@staticmethod
|
||||
def is_timedout(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.LOCK_WAIT_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def is_read_only_mode_error(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
|
||||
|
||||
@staticmethod
|
||||
def is_table_missing(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.NO_SUCH_TABLE
|
||||
|
||||
@staticmethod
|
||||
def is_missing_table(e: MySQLdb.Error) -> bool:
|
||||
return MariaDBDatabase.is_table_missing(e)
|
||||
|
||||
@staticmethod
|
||||
def is_missing_column(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.BAD_FIELD_ERROR
|
||||
|
||||
@staticmethod
|
||||
def is_duplicate_fieldname(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.DUP_FIELDNAME
|
||||
|
||||
@staticmethod
|
||||
def is_duplicate_entry(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.DUP_ENTRY
|
||||
|
||||
@staticmethod
|
||||
def is_access_denied(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.ACCESS_DENIED_ERROR
|
||||
|
||||
@staticmethod
|
||||
def cant_drop_field_or_key(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY
|
||||
|
||||
@staticmethod
|
||||
def is_syntax_error(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.PARSE_ERROR
|
||||
|
||||
@staticmethod
|
||||
def is_statement_timeout(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER_STATEMENT_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def is_data_too_long(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.DATA_TOO_LONG
|
||||
|
||||
@staticmethod
|
||||
def is_db_table_size_limit(e: MySQLdb.Error) -> bool:
|
||||
return e.args[0] == ER.TOO_BIG_ROWSIZE
|
||||
|
||||
@staticmethod
|
||||
def is_primary_key_violation(e: MySQLdb.Error) -> bool:
|
||||
return (
|
||||
isinstance(e, MySQLdb.IntegrityError)
|
||||
and MariaDBExceptionUtil.is_duplicate_entry(e)
|
||||
and "PRIMARY" in e.args[1]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_unique_key_violation(e: MySQLdb.Error) -> bool:
|
||||
return (
|
||||
isinstance(e, MySQLdb.IntegrityError)
|
||||
and MariaDBExceptionUtil.is_duplicate_entry(e)
|
||||
and "Duplicate" in e.args[1]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_interface_error(e: MySQLdb.Error):
|
||||
return isinstance(e, MySQLdb.InterfaceError)
|
||||
|
||||
|
||||
class MariaDBConnectionUtil:
|
||||
def get_connection(self):
|
||||
conn = self._get_connection()
|
||||
conn.auto_reconnect = True
|
||||
return conn
|
||||
|
||||
def _get_connection(self) -> "MySQLdb.Connection":
|
||||
return self.create_connection()
|
||||
|
||||
def create_connection(self):
|
||||
return MySQLdb.connect(**self.get_connection_settings())
|
||||
|
||||
def set_execution_timeout(self, seconds: int):
|
||||
self.sql("set session max_statement_time = %s", int(seconds))
|
||||
|
||||
def get_connection_settings(self) -> dict:
|
||||
conn_settings = {
|
||||
"user": self.user,
|
||||
"conv": self.CONVERSION_MAP,
|
||||
"charset": "utf8mb4",
|
||||
"use_unicode": True,
|
||||
}
|
||||
|
||||
if self.cur_db_name:
|
||||
conn_settings["database"] = self.cur_db_name
|
||||
|
||||
if self.socket:
|
||||
conn_settings["unix_socket"] = self.socket
|
||||
else:
|
||||
conn_settings["host"] = self.host
|
||||
if self.port:
|
||||
conn_settings["port"] = int(self.port)
|
||||
|
||||
if self.password:
|
||||
conn_settings["password"] = self.password
|
||||
|
||||
if frappe.conf.local_infile:
|
||||
conn_settings["local_infile"] = frappe.conf.local_infile
|
||||
|
||||
if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key:
|
||||
conn_settings["ssl"] = {
|
||||
"ca": frappe.conf.db_ssl_ca,
|
||||
"cert": frappe.conf.db_ssl_cert,
|
||||
"key": frappe.conf.db_ssl_key,
|
||||
}
|
||||
|
||||
return conn_settings
|
||||
|
||||
|
||||
### Converters
|
||||
|
||||
|
||||
def escape_frozenset(obj, mapping=None):
|
||||
return frappe.local.db._conn.literal(tuple(obj))
|
||||
|
||||
|
||||
# adapted from pymysql
|
||||
def escape_timedelta(obj, mapping=None):
|
||||
_seconds = obj.seconds
|
||||
|
||||
if obj.microseconds:
|
||||
fmt = "'{0:02d}:{1:02d}:{2:02d}.{3:06d}'"
|
||||
else:
|
||||
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
|
||||
return fmt.format(
|
||||
(_seconds // 3600) % 24 + obj.days * 24, # hours
|
||||
(_seconds // 60) % 60, # minutes
|
||||
_seconds % 60, # seconds
|
||||
obj.microseconds, # microseconds
|
||||
)
|
||||
|
||||
|
||||
# adapted from pymysql
|
||||
def escape_dict(obj, mapping=None):
|
||||
raise TypeError("dict can not be used as parameter")
|
||||
|
||||
|
||||
class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database):
|
||||
REGEX_CHARACTER = "regexp"
|
||||
CONVERSION_MAP = conversions | {
|
||||
FIELD_TYPE.NEWDECIMAL: float,
|
||||
FIELD_TYPE.DATETIME: get_datetime,
|
||||
dict: escape_dict,
|
||||
frozenset: escape_frozenset,
|
||||
datetime.timedelta: escape_timedelta, # not handled as desired by MySQLdb
|
||||
# no need to specify UnicodeWithAttrs, as it subclasses str - which is handled
|
||||
}
|
||||
default_port = "3306"
|
||||
MAX_ROW_SIZE_LIMIT = 65_535 # bytes
|
||||
|
||||
def setup_type_map(self):
|
||||
self.db_type = "mariadb"
|
||||
self.type_map = {
|
||||
"Currency": ("decimal", "21,9"),
|
||||
"Int": ("int", None),
|
||||
"Long Int": ("bigint", "20"),
|
||||
"Float": ("decimal", "21,9"),
|
||||
"Percent": ("decimal", "21,9"),
|
||||
"Check": ("tinyint", None),
|
||||
"Small Text": ("text", ""),
|
||||
"Long Text": ("longtext", ""),
|
||||
"Code": ("longtext", ""),
|
||||
"Text Editor": ("longtext", ""),
|
||||
"Markdown Editor": ("longtext", ""),
|
||||
"HTML Editor": ("longtext", ""),
|
||||
"Date": ("date", ""),
|
||||
"Datetime": ("datetime", "6"),
|
||||
"Time": ("time", "6"),
|
||||
"Text": ("text", ""),
|
||||
"Data": ("varchar", self.VARCHAR_LEN),
|
||||
"Link": ("varchar", self.VARCHAR_LEN),
|
||||
"Dynamic Link": ("varchar", self.VARCHAR_LEN),
|
||||
"Password": ("text", ""),
|
||||
"Select": ("varchar", self.VARCHAR_LEN),
|
||||
"Rating": ("decimal", "3,2"),
|
||||
"Read Only": ("varchar", self.VARCHAR_LEN),
|
||||
"Attach": ("text", ""),
|
||||
"Attach Image": ("text", ""),
|
||||
"Signature": ("longtext", ""),
|
||||
"Color": ("varchar", self.VARCHAR_LEN),
|
||||
"Barcode": ("longtext", ""),
|
||||
"Geolocation": ("longtext", ""),
|
||||
"Duration": ("decimal", "21,9"),
|
||||
"Icon": ("varchar", self.VARCHAR_LEN),
|
||||
"Phone": ("varchar", self.VARCHAR_LEN),
|
||||
"Autocomplete": ("varchar", self.VARCHAR_LEN),
|
||||
"JSON": ("json", ""),
|
||||
}
|
||||
|
||||
def get_database_size(self):
|
||||
"""Return database size in MB."""
|
||||
db_size = self.sql(
|
||||
"""
|
||||
SELECT `table_schema` as `database_name`,
|
||||
SUM(`data_length` + `index_length`) / 1024 / 1024 AS `database_size`
|
||||
FROM information_schema.tables WHERE `table_schema` = %s GROUP BY `table_schema`
|
||||
""",
|
||||
self.cur_db_name,
|
||||
as_dict=True,
|
||||
)
|
||||
|
||||
return db_size[0].get("database_size")
|
||||
|
||||
def log_query(self, query, values, debug):
|
||||
mogrified_query = self._cursor._executed.decode()
|
||||
self.last_query = mogrified_query
|
||||
self._log_query(mogrified_query, debug, query)
|
||||
return mogrified_query
|
||||
|
||||
def _clean_up(self):
|
||||
# PERF: Erase internal references to trigger GC as soon as
|
||||
# results are consumed.
|
||||
self._cursor._rows = None
|
||||
|
||||
@staticmethod
|
||||
def escape(s, percent=True):
|
||||
"""Escape quotes and percent in given string."""
|
||||
|
||||
s = frappe.as_unicode(escape_string(s)).replace("`", "\\`") if s else ""
|
||||
|
||||
# NOTE separating % escape, because % escape should only be done when using LIKE operator
|
||||
# or when you use python format string to generate query that already has a %s
|
||||
# for example: sql("select name from `tabUser` where name=%s and {0}".format(conditions), something)
|
||||
# defaulting it to True, as this is the most frequent use case
|
||||
# ideally we shouldn't have to use ESCAPE and strive to pass values via the values argument of sql
|
||||
if percent:
|
||||
s = s.replace("%", "%%")
|
||||
|
||||
return "'" + s + "'"
|
||||
|
||||
# column type
|
||||
@staticmethod
|
||||
def is_type_number(code):
|
||||
return code == MySQLdb.NUMBER
|
||||
|
||||
@staticmethod
|
||||
def is_type_datetime(code):
|
||||
return code == MySQLdb.DATETIME
|
||||
|
||||
def rename_table(self, old_name: str, new_name: str) -> list | tuple:
|
||||
old_name = get_table_name(old_name)
|
||||
new_name = get_table_name(new_name)
|
||||
return self.sql(f"RENAME TABLE `{old_name}` TO `{new_name}`")
|
||||
|
||||
def describe(self, doctype: str) -> list | tuple:
|
||||
table_name = get_table_name(doctype)
|
||||
return self.sql(f"DESC `{table_name}`")
|
||||
|
||||
def change_column_type(
|
||||
self, doctype: str, column: str, type: str, nullable: bool = False
|
||||
) -> list | tuple:
|
||||
table_name = get_table_name(doctype)
|
||||
null_constraint = "NOT NULL" if not nullable else ""
|
||||
return self.sql_ddl(f"ALTER TABLE `{table_name}` MODIFY `{column}` {type} {null_constraint}")
|
||||
|
||||
def rename_column(self, doctype: str, old_column_name, new_column_name):
|
||||
current_data_type = self.get_column_type(doctype, old_column_name)
|
||||
|
||||
table_name = get_table_name(doctype)
|
||||
|
||||
frappe.db.sql_ddl(
|
||||
f"""ALTER TABLE `{table_name}`
|
||||
CHANGE COLUMN `{old_column_name}`
|
||||
`{new_column_name}`
|
||||
{current_data_type}"""
|
||||
# ^ Mariadb requires passing current data type again even if there's no change
|
||||
# This requirement is gone from v10.5
|
||||
)
|
||||
|
||||
def create_auth_table(self):
|
||||
self.sql_ddl(
|
||||
"""create table if not exists `__Auth` (
|
||||
`doctype` VARCHAR(140) NOT NULL,
|
||||
`name` VARCHAR(255) NOT NULL,
|
||||
`fieldname` VARCHAR(140) NOT NULL,
|
||||
`password` TEXT NOT NULL,
|
||||
`encrypted` TINYINT NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (`doctype`, `name`, `fieldname`)
|
||||
) ENGINE=InnoDB ROW_FORMAT=DYNAMIC CHARACTER SET=utf8mb4 COLLATE=utf8mb4_unicode_ci"""
|
||||
)
|
||||
|
||||
def create_global_search_table(self):
|
||||
if "__global_search" not in self.get_tables():
|
||||
self.sql(
|
||||
f"""create table __global_search(
|
||||
doctype varchar(100),
|
||||
name varchar({self.VARCHAR_LEN}),
|
||||
title varchar({self.VARCHAR_LEN}),
|
||||
content text,
|
||||
fulltext(content),
|
||||
route varchar({self.VARCHAR_LEN}),
|
||||
published TINYINT not null default 0,
|
||||
unique `doctype_name` (doctype, name))
|
||||
COLLATE=utf8mb4_unicode_ci
|
||||
ENGINE=MyISAM
|
||||
CHARACTER SET=utf8mb4"""
|
||||
)
|
||||
|
||||
def create_user_settings_table(self):
|
||||
self.sql_ddl(
|
||||
"""create table if not exists __UserSettings (
|
||||
`user` VARCHAR(180) NOT NULL,
|
||||
`doctype` VARCHAR(180) NOT NULL,
|
||||
`data` TEXT,
|
||||
UNIQUE(user, doctype)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8"""
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_on_duplicate_update():
|
||||
return "ON DUPLICATE key UPDATE "
|
||||
|
||||
def get_table_columns_description(self, table_name):
|
||||
"""Return list of columns with descriptions."""
|
||||
return self.sql(
|
||||
f"""select
|
||||
column_name as 'name',
|
||||
column_type as 'type',
|
||||
column_default as 'default',
|
||||
COALESCE(
|
||||
(select 1
|
||||
from information_schema.statistics
|
||||
where table_name="{table_name}"
|
||||
and column_name=columns.column_name
|
||||
and NON_UNIQUE=1
|
||||
and Seq_in_index = 1
|
||||
limit 1
|
||||
), 0) as 'index',
|
||||
column_key = 'UNI' as 'unique',
|
||||
(is_nullable = 'NO') AS 'not_nullable'
|
||||
from information_schema.columns as columns
|
||||
where table_name = '{table_name}'
|
||||
and table_schema = '{frappe.db.cur_db_name}' """,
|
||||
as_dict=1,
|
||||
)
|
||||
|
||||
def get_column_type(self, doctype, column):
|
||||
"""Return column type from database."""
|
||||
information_schema = frappe.qb.Schema("information_schema")
|
||||
table = get_table_name(doctype)
|
||||
|
||||
return (
|
||||
frappe.qb.from_(information_schema.columns)
|
||||
.select(information_schema.columns.column_type)
|
||||
.where(
|
||||
(information_schema.columns.table_name == table)
|
||||
& (information_schema.columns.column_name == column)
|
||||
& (information_schema.columns.table_schema == self.cur_db_name)
|
||||
)
|
||||
.run(pluck=True)[0]
|
||||
)
|
||||
|
||||
def has_index(self, table_name, index_name):
|
||||
return self.sql(
|
||||
f"""SHOW INDEX FROM `{table_name}`
|
||||
WHERE Key_name='{index_name}'"""
|
||||
)
|
||||
|
||||
def get_column_index(self, table_name: str, fieldname: str, unique: bool = False) -> frappe._dict | None:
|
||||
"""Check if column exists for a specific fields in specified order.
|
||||
|
||||
This differs from db.has_index because it doesn't rely on index name but columns inside an
|
||||
index.
|
||||
"""
|
||||
|
||||
indexes = self.sql(
|
||||
f"""SHOW INDEX FROM `{table_name}`
|
||||
WHERE Column_name = "{fieldname}"
|
||||
AND Seq_in_index = 1
|
||||
AND Non_unique={int(not unique)}
|
||||
AND Index_type != 'FULLTEXT'
|
||||
""",
|
||||
as_dict=True,
|
||||
)
|
||||
|
||||
# Same index can be part of clustered index which contains more fields
|
||||
# We don't want those.
|
||||
for index in indexes:
|
||||
clustered_index = self.sql(
|
||||
f"""SHOW INDEX FROM `{table_name}`
|
||||
WHERE Key_name = "{index.Key_name}"
|
||||
AND Seq_in_index = 2
|
||||
""",
|
||||
as_dict=True,
|
||||
)
|
||||
if not clustered_index:
|
||||
return index
|
||||
|
||||
def add_index(self, doctype: str, fields: list, index_name: str | None = None):
|
||||
"""Creates an index with given fields if not already created.
|
||||
Index name will be `fieldname1_fieldname2_index`"""
|
||||
from frappe.custom.doctype.property_setter.property_setter import make_property_setter
|
||||
|
||||
index_name = index_name or self.get_index_name(fields)
|
||||
table_name = get_table_name(doctype)
|
||||
if not self.has_index(table_name, index_name):
|
||||
self.commit()
|
||||
self.sql(
|
||||
"""ALTER TABLE `{}`
|
||||
ADD INDEX IF NOT EXISTS `{}`({})""".format(table_name, index_name, ", ".join(fields))
|
||||
)
|
||||
# Ensure that DB migration doesn't clear this index, assuming this is manually added
|
||||
# via code or console.
|
||||
if len(fields) == 1 and not (frappe.flags.in_install or frappe.flags.in_migrate):
|
||||
make_property_setter(
|
||||
doctype,
|
||||
fields[0],
|
||||
property="search_index",
|
||||
value="1",
|
||||
property_type="Check",
|
||||
for_doctype=False, # Applied on docfield
|
||||
)
|
||||
|
||||
def add_unique(self, doctype, fields, constraint_name=None):
|
||||
if isinstance(fields, str):
|
||||
fields = [fields]
|
||||
if not constraint_name:
|
||||
constraint_name = "unique_" + "_".join(fields)
|
||||
|
||||
if not self.sql(
|
||||
"""select CONSTRAINT_NAME from information_schema.TABLE_CONSTRAINTS
|
||||
where table_name=%s and constraint_type='UNIQUE' and CONSTRAINT_NAME=%s""",
|
||||
("tab" + doctype, constraint_name),
|
||||
):
|
||||
self.commit()
|
||||
self.sql(
|
||||
"""alter table `tab{}`
|
||||
add unique `{}`({})""".format(doctype, constraint_name, ", ".join(fields))
|
||||
)
|
||||
|
||||
def updatedb(self, doctype, meta=None):
|
||||
"""
|
||||
Syncs a `DocType` to the table
|
||||
* creates if required
|
||||
* updates columns
|
||||
* updates indices
|
||||
"""
|
||||
res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
|
||||
if not res:
|
||||
raise Exception(f"Wrong doctype {doctype} in updatedb")
|
||||
|
||||
if not res[0][0]:
|
||||
db_table = MariaDBTable(doctype, meta)
|
||||
db_table.validate()
|
||||
|
||||
db_table.sync()
|
||||
self.commit()
|
||||
|
||||
def get_database_list(self):
|
||||
return self.sql("SHOW DATABASES", pluck=True)
|
||||
|
||||
def get_tables(self, cached=True):
|
||||
"""Return list of tables."""
|
||||
to_query = not cached
|
||||
|
||||
if cached:
|
||||
tables = frappe.client_cache.get_value("db_tables")
|
||||
to_query = not tables
|
||||
|
||||
if to_query:
|
||||
information_schema = frappe.qb.Schema("information_schema")
|
||||
|
||||
tables = (
|
||||
frappe.qb.from_(information_schema.tables)
|
||||
.select(information_schema.tables.table_name)
|
||||
.where(information_schema.tables.table_schema == frappe.db.cur_db_name)
|
||||
.run(pluck=True)
|
||||
)
|
||||
frappe.client_cache.set_value("db_tables", tables)
|
||||
|
||||
return tables
|
||||
|
||||
def get_row_size(self, doctype: str) -> int:
|
||||
"""Get estimated max row size of any table in bytes."""
|
||||
|
||||
# Query reused from this answer: https://dba.stackexchange.com/a/313889/274503
|
||||
# Modification: get values for particular table instead of full summary.
|
||||
# Reference: https://mariadb.com/kb/en/data-type-storage-requirements/
|
||||
|
||||
est_row_size = frappe.db.sql(
|
||||
"""
|
||||
SELECT SUM(col_sizes.col_size) AS EST_MAX_ROW_SIZE
|
||||
FROM (
|
||||
SELECT
|
||||
cols.COLUMN_NAME,
|
||||
CASE cols.DATA_TYPE
|
||||
WHEN 'tinyint' THEN 1
|
||||
WHEN 'smallint' THEN 2
|
||||
WHEN 'mediumint' THEN 3
|
||||
WHEN 'int' THEN 4
|
||||
WHEN 'bigint' THEN 8
|
||||
WHEN 'float' THEN IF(cols.NUMERIC_PRECISION > 24, 8, 4)
|
||||
WHEN 'double' THEN 8
|
||||
WHEN 'decimal' THEN ((cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE) DIV 9)*4 + (cols.NUMERIC_SCALE DIV 9)*4 + CEIL(MOD(cols.NUMERIC_PRECISION - cols.NUMERIC_SCALE,9)/2) + CEIL(MOD(cols.NUMERIC_SCALE,9)/2)
|
||||
WHEN 'bit' THEN (cols.NUMERIC_PRECISION + 7) DIV 8
|
||||
WHEN 'year' THEN 1
|
||||
WHEN 'date' THEN 3
|
||||
WHEN 'time' THEN 3 + CEIL(cols.DATETIME_PRECISION /2)
|
||||
WHEN 'datetime' THEN 5 + CEIL(cols.DATETIME_PRECISION /2)
|
||||
WHEN 'timestamp' THEN 4 + CEIL(cols.DATETIME_PRECISION /2)
|
||||
WHEN 'char' THEN cols.CHARACTER_OCTET_LENGTH
|
||||
WHEN 'binary' THEN cols.CHARACTER_OCTET_LENGTH
|
||||
WHEN 'varchar' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH
|
||||
WHEN 'varbinary' THEN IF(cols.CHARACTER_OCTET_LENGTH > 255, 2, 1) + cols.CHARACTER_OCTET_LENGTH
|
||||
WHEN 'tinyblob' THEN 9
|
||||
WHEN 'tinytext' THEN 9
|
||||
WHEN 'blob' THEN 10
|
||||
WHEN 'text' THEN 10
|
||||
WHEN 'mediumblob' THEN 11
|
||||
WHEN 'mediumtext' THEN 11
|
||||
WHEN 'longblob' THEN 12
|
||||
WHEN 'longtext' THEN 12
|
||||
WHEN 'enum' THEN 2
|
||||
WHEN 'set' THEN 8
|
||||
ELSE 0
|
||||
END AS col_size
|
||||
FROM INFORMATION_SCHEMA.COLUMNS cols
|
||||
WHERE cols.TABLE_NAME = %s
|
||||
) AS col_sizes;""",
|
||||
(get_table_name(doctype),),
|
||||
)
|
||||
|
||||
if est_row_size:
|
||||
return int(est_row_size[0][0])
|
||||
|
||||
@contextmanager
|
||||
def unbuffered_cursor(self):
|
||||
from MySQLdb.cursors import SSCursor
|
||||
|
||||
try:
|
||||
if not self._conn:
|
||||
self.connect()
|
||||
|
||||
original_cursor = self._cursor
|
||||
new_cursor = self._cursor = self._conn.cursor(SSCursor)
|
||||
yield
|
||||
finally:
|
||||
self._cursor = original_cursor
|
||||
new_cursor.close()
|
||||
|
||||
def estimate_count(self, doctype: str):
|
||||
"""Get estimated count of total rows in a table."""
|
||||
from frappe.utils.data import cint
|
||||
|
||||
table = get_table_name(doctype)
|
||||
|
||||
count = self.sql("select table_rows from information_schema.tables where table_name = %s", table)
|
||||
return cint(count[0][0]) if count else 0
|
||||
|
|
@ -1382,21 +1382,17 @@ class TestDbConnectWithEnvCredentials(IntegrationTestCase):
|
|||
frappe.init(self.current_site, force=True)
|
||||
frappe.connect()
|
||||
|
||||
with self.assertRaises(Exception) as cm:
|
||||
with self.assertRaises(frappe.db.OperationalError) as cm:
|
||||
frappe.db.connect()
|
||||
|
||||
self.assertTrue(re.search(r"(host name|server on) [\"']iqx.local[\"']", str(cm.exception)))
|
||||
|
||||
# with wrong user name
|
||||
with set_env_variable("FRAPPE_DB_USER", "uname"):
|
||||
frappe.init(self.current_site, force=True)
|
||||
frappe.connect()
|
||||
|
||||
with self.assertRaises(Exception) as cm:
|
||||
with self.assertRaises(frappe.db.OperationalError) as cm:
|
||||
frappe.db.connect()
|
||||
|
||||
self.assertTrue(re.search(r"user [\"']uname[\"']", str(cm.exception)))
|
||||
|
||||
# with wrong password
|
||||
with set_env_variable("FRAPPE_DB_PASSWORD", "pass"):
|
||||
frappe.init(self.current_site, force=True)
|
||||
|
|
@ -1414,11 +1410,9 @@ class TestDbConnectWithEnvCredentials(IntegrationTestCase):
|
|||
frappe.init(self.current_site, force=True)
|
||||
frappe.connect()
|
||||
|
||||
with self.assertRaises(Exception) as cm:
|
||||
with self.assertRaises(frappe.db.OperationalError) as cm:
|
||||
frappe.db.connect()
|
||||
|
||||
self.assertTrue(re.search("(port 1111 failed|Errno 111)", str(cm.exception)))
|
||||
|
||||
# now with configured settings without any influences from env
|
||||
# finally connect should work without any error (when no wrong credentials are given via ENV)
|
||||
frappe.init(self.current_site, force=True)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ dependencies = [
|
|||
"PyMySQL==1.1.1",
|
||||
"pypdf~=3.17.0",
|
||||
"PyPika==0.48.9",
|
||||
"mysqlclient~=2.2.7",
|
||||
"PyQRCode~=1.2.1",
|
||||
"PyYAML~=6.0.1",
|
||||
"RestrictedPython~=8.0",
|
||||
|
|
@ -134,7 +135,6 @@ test = [
|
|||
"Faker~=18.10.1",
|
||||
"hypothesis~=6.77.0",
|
||||
"freezegun~=1.5.1",
|
||||
"pdbpp~=0.10.3",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue