* Updated `change_column_type` to allow making columns nullable. * Breaking change: on Postgres the method previously left columns nullable by default; it now applies NOT NULL unless `nullable=True`, consistent with MariaDB.
303 lines
9.2 KiB
Python
303 lines
9.2 KiB
Python
from typing import List, Tuple, Union
|
|
|
|
import pymysql
|
|
from pymysql.constants import ER, FIELD_TYPE
|
|
from pymysql.converters import conversions, escape_string
|
|
|
|
import frappe
|
|
from frappe.database.database import Database
|
|
from frappe.database.mariadb.schema import MariaDBTable
|
|
from frappe.utils import UnicodeWithAttrs, cstr, get_datetime, get_table_name
|
|
|
|
|
|
class MariaDBDatabase(Database):
	"""Frappe `Database` implementation backed by MariaDB/MySQL through pymysql.

	Provides the MariaDB-specific pieces: the fieldtype -> column type map,
	connection setup, value escaping, error-code classification and DDL helpers.
	"""

	# pymysql exception classes exposed under the names the generic layer uses.
	ProgrammingError = pymysql.err.ProgrammingError
	TableMissingError = pymysql.err.ProgrammingError
	OperationalError = pymysql.err.OperationalError
	InternalError = pymysql.err.InternalError
	SQLError = pymysql.err.ProgrammingError
	DataError = pymysql.err.DataError

	# SQL operator used for regular-expression matching (MariaDB `REGEXP`).
	REGEX_CHARACTER = 'regexp'

	def setup_type_map(self):
		"""Set `db_type` and map each Frappe fieldtype to a (column type, size) pair.

		The second element is the length/precision placed in parentheses after
		the column type ('' when the type takes none).
		"""
		self.db_type = 'mariadb'
		self.type_map = {
			'Currency': ('decimal', '21,9'),
			'Int': ('int', '11'),
			'Long Int': ('bigint', '20'),
			'Float': ('decimal', '21,9'),
			'Percent': ('decimal', '21,9'),
			'Check': ('int', '1'),
			'Small Text': ('text', ''),
			'Long Text': ('longtext', ''),
			'Code': ('longtext', ''),
			'Text Editor': ('longtext', ''),
			'Markdown Editor': ('longtext', ''),
			'HTML Editor': ('longtext', ''),
			'Date': ('date', ''),
			'Datetime': ('datetime', '6'),
			'Time': ('time', '6'),
			'Text': ('text', ''),
			'Data': ('varchar', self.VARCHAR_LEN),
			'Link': ('varchar', self.VARCHAR_LEN),
			'Dynamic Link': ('varchar', self.VARCHAR_LEN),
			'Password': ('text', ''),
			'Select': ('varchar', self.VARCHAR_LEN),
			'Rating': ('decimal', '3,2'),
			'Read Only': ('varchar', self.VARCHAR_LEN),
			'Attach': ('text', ''),
			'Attach Image': ('text', ''),
			'Signature': ('longtext', ''),
			'Color': ('varchar', self.VARCHAR_LEN),
			'Barcode': ('longtext', ''),
			'Geolocation': ('longtext', ''),
			'Duration': ('decimal', '21,9'),
			'Icon': ('varchar', self.VARCHAR_LEN)
		}

	def get_connection(self):
		"""Open and return a new pymysql connection for this database.

		SSL is used only when `db_ssl_ca`, `db_ssl_cert` and `db_ssl_key` are all
		set in `frappe.conf`. Unless connecting as `root`, the database named
		after `self.user` is selected on the new connection.
		"""
		ssl_params = None
		if frappe.conf.db_ssl_ca and frappe.conf.db_ssl_cert and frappe.conf.db_ssl_key:
			ssl_params = {
				'ca': frappe.conf.db_ssl_ca,
				'cert': frappe.conf.db_ssl_cert,
				'key': frappe.conf.db_ssl_key
			}

		# Decode DECIMAL as float and DATETIME via frappe's parser, and encode
		# UnicodeWithAttrs exactly like str. NOTE: this mutates pymysql's
		# module-level `conversions` dict in place (shared process-wide).
		conversions.update({
			FIELD_TYPE.NEWDECIMAL: float,
			FIELD_TYPE.DATETIME: get_datetime,
			UnicodeWithAttrs: conversions[str]
		})

		conn = pymysql.connect(
			user=self.user or '',
			password=self.password or '',
			host=self.host,
			port=self.port,
			charset='utf8mb4',
			use_unicode=True,
			ssl=ssl_params,
			conv=conversions,
			local_infile=frappe.conf.local_infile
		)

		if self.user != 'root':
			conn.select_db(self.user)

		return conn

	def get_database_size(self):
		"""Return the database size in MB (sum of data and index length of its tables).

		Returns 0 when information_schema reports no tables for `self.db_name`.
		"""
		db_size = self.sql('''
			SELECT `table_schema` as `database_name`,
			SUM(`data_length` + `index_length`) / 1024 / 1024 AS `database_size`
			FROM information_schema.tables WHERE `table_schema` = %s GROUP BY `table_schema`
			''', self.db_name, as_dict=True)

		# GROUP BY produces no row at all for a schema without tables; indexing
		# [0] would raise IndexError in that case.
		if not db_size:
			return 0

		return db_size[0].get('database_size')

	@staticmethod
	def escape(s, percent=True):
		"""Escape quotes and percent in given string.

		`s` is escaped with pymysql's `escape_string`, backticks are
		backslash-escaped, and the result is wrapped in single quotes.

		:param s: value to escape (converted to unicode first).
		:param percent: when True, also double every `%` (see note below).
		"""
		# pymysql expects unicode argument to escape_string with Python 3
		s = frappe.as_unicode(escape_string(frappe.as_unicode(s)), "utf-8").replace("`", "\\`")

		# NOTE separating % escape, because % escape should only be done when using LIKE operator
		# or when you use python format string to generate query that already has a %s
		# for example: sql("select name from `tabUser` where name=%s and {0}".format(conditions), something)
		# defaulting it to True, as this is the most frequent use case
		# ideally we shouldn't have to use ESCAPE and strive to pass values via the values argument of sql
		if percent:
			s = s.replace("%", "%%")

		return "'" + s + "'"

	# column type
	@staticmethod
	def is_type_number(code):
		"""Return True if the column type `code` belongs to pymysql's NUMBER type set."""
		return code == pymysql.NUMBER

	@staticmethod
	def is_type_datetime(code):
		"""Return True if the column type `code` matches pymysql's DATE or DATETIME sets."""
		return code in (pymysql.DATE, pymysql.DATETIME)

	def rename_table(self, old_name: str, new_name: str) -> Union[List, Tuple]:
		"""Rename the table for doctype `old_name` to the table for `new_name`."""
		old_name = get_table_name(old_name)
		new_name = get_table_name(new_name)
		return self.sql(f"RENAME TABLE `{old_name}` TO `{new_name}`")

	def describe(self, doctype: str) -> Union[List, Tuple]:
		"""Return the `DESC` output (column definitions) of the doctype's table."""
		table_name = get_table_name(doctype)
		return self.sql(f"DESC `{table_name}`")

	def change_column_type(self, doctype: str, column: str, type: str, nullable: bool = False) -> Union[List, Tuple]:
		"""Change `column` of the doctype's table to SQL type `type`.

		The column is made NOT NULL unless `nullable` is True, in which case no
		null constraint is emitted (MariaDB then allows NULL).
		"""
		# NOTE(review): `type` shadows the builtin; kept because it is part of the
		# public keyword interface.
		table_name = get_table_name(doctype)
		null_constraint = "NOT NULL" if not nullable else ""
		return self.sql(f"ALTER TABLE `{table_name}` MODIFY `{column}` {type} {null_constraint}")

	# exception types
	@staticmethod
	def is_deadlocked(e):
		"""Return True if `e` carries MariaDB error code ER.LOCK_DEADLOCK."""
		return e.args[0] == ER.LOCK_DEADLOCK

	@staticmethod
	def is_timedout(e):
		"""Return True if `e` carries MariaDB error code ER.LOCK_WAIT_TIMEOUT."""
		return e.args[0] == ER.LOCK_WAIT_TIMEOUT

	@staticmethod
	def is_table_missing(e):
		"""Return True if `e` carries MariaDB error code ER.NO_SUCH_TABLE."""
		return e.args[0] == ER.NO_SUCH_TABLE

	@staticmethod
	def is_missing_column(e):
		"""Return True if `e` carries MariaDB error code ER.BAD_FIELD_ERROR."""
		return e.args[0] == ER.BAD_FIELD_ERROR

	@staticmethod
	def is_duplicate_fieldname(e):
		"""Return True if `e` carries MariaDB error code ER.DUP_FIELDNAME."""
		return e.args[0] == ER.DUP_FIELDNAME

	@staticmethod
	def is_duplicate_entry(e):
		"""Return True if `e` carries MariaDB error code ER.DUP_ENTRY."""
		return e.args[0] == ER.DUP_ENTRY

	@staticmethod
	def is_access_denied(e):
		"""Return True if `e` carries MariaDB error code ER.ACCESS_DENIED_ERROR."""
		return e.args[0] == ER.ACCESS_DENIED_ERROR

	@staticmethod
	def cant_drop_field_or_key(e):
		"""Return True if `e` carries MariaDB error code ER.CANT_DROP_FIELD_OR_KEY."""
		return e.args[0] == ER.CANT_DROP_FIELD_OR_KEY

	@staticmethod
	def is_syntax_error(e):
		"""Return True if `e` carries MariaDB error code ER.PARSE_ERROR."""
		return e.args[0] == ER.PARSE_ERROR

	@staticmethod
	def is_data_too_long(e):
		"""Return True if `e` carries MariaDB error code ER.DATA_TOO_LONG."""
		return e.args[0] == ER.DATA_TOO_LONG

	def is_primary_key_violation(self, e):
		"""Return True for a duplicate-entry error whose message mentions 'PRIMARY'."""
		return self.is_duplicate_entry(e) and 'PRIMARY' in cstr(e.args[1])

	def is_unique_key_violation(self, e):
		"""Return True for a duplicate-entry error whose message contains 'Duplicate'.

		NOTE(review): the ER.DUP_ENTRY message presumably always starts with
		"Duplicate entry", so this likely also matches PRIMARY key violations.
		"""
		return self.is_duplicate_entry(e) and 'Duplicate' in cstr(e.args[1])

	def create_auth_table(self):
		"""Create the `__Auth` table if it does not exist.

		The table is keyed by (doctype, name, fieldname) and has an `encrypted` flag.
		"""
		self.sql_ddl("""create table if not exists `__Auth` (
				`doctype` VARCHAR(140) NOT NULL,
				`name` VARCHAR(255) NOT NULL,
				`fieldname` VARCHAR(140) NOT NULL,
				`password` TEXT NOT NULL,
				`encrypted` INT(1) NOT NULL DEFAULT 0,
				PRIMARY KEY (`doctype`, `name`, `fieldname`)
			) ENGINE=InnoDB ROW_FORMAT=DYNAMIC CHARACTER SET=utf8mb4 COLLATE=utf8mb4_unicode_ci""")

	def create_global_search_table(self):
		"""Create the MyISAM `__global_search` table (FULLTEXT on `content`) if it is missing."""
		if '__global_search' not in self.get_tables():
			self.sql('''create table __global_search(
				doctype varchar(100),
				name varchar({0}),
				title varchar({0}),
				content text,
				fulltext(content),
				route varchar({0}),
				published int(1) not null default 0,
				unique `doctype_name` (doctype, name))
				COLLATE=utf8mb4_unicode_ci
				ENGINE=MyISAM
				CHARACTER SET=utf8mb4'''.format(self.VARCHAR_LEN))

	def create_user_settings_table(self):
		"""Create the `__UserSettings` table, unique per (user, doctype), if it does not exist."""
		self.sql_ddl("""create table if not exists __UserSettings (
			`user` VARCHAR(180) NOT NULL,
			`doctype` VARCHAR(180) NOT NULL,
			`data` TEXT,
			UNIQUE(user, doctype)
			) ENGINE=InnoDB DEFAULT CHARSET=utf8""")

	def create_help_table(self):
		"""Create the MyISAM `help` table with FULLTEXT indexes on `title` and `content`."""
		self.sql('''create table help(
				path varchar(255),
				content text,
				title text,
				intro text,
				full_path text,
				fulltext(title),
				fulltext(content),
				index (path))
				COLLATE=utf8mb4_unicode_ci
				ENGINE=MyISAM
				CHARACTER SET=utf8mb4''')

	@staticmethod
	def get_on_duplicate_update(key=None):
		"""Return MariaDB's upsert clause prefix; `key` is unused here."""
		return 'ON DUPLICATE key UPDATE '

	def get_table_columns_description(self, table_name):
		"""Returns list of column and its description.

		Each row (as dict) has `name`, `type`, `default`, `index` (column_key is
		'MUL') and `unique` (column_key is 'UNI').
		"""
		# table_name is bound as a parameter instead of being formatted into SQL.
		# NOTE(review): not filtered by table_schema, so same-named tables in
		# other databases on the server are included as well.
		return self.sql('''select
			column_name as 'name',
			column_type as 'type',
			column_default as 'default',
			column_key = 'MUL' as 'index',
			column_key = 'UNI' as 'unique'
			from information_schema.columns
			where table_name = %s ''', (table_name,), as_dict=1)

	def has_index(self, table_name, index_name):
		"""Return the `SHOW INDEX` rows of `table_name` whose Key_name equals `index_name`.

		The result is empty (falsy) when no such index exists.
		"""
		# Identifiers cannot be bound, so the table name is interpolated; the
		# index name is a value and is bound as a parameter.
		return self.sql(
			f"SHOW INDEX FROM `{table_name}` WHERE Key_name=%s",
			(index_name,),
		)

	def add_index(self, doctype: str, fields: List, index_name: str = None):
		"""Creates an index with given fields if not already created.

		Index name will be `fieldname1_fieldname2_index` unless `index_name` is given.
		"""
		index_name = index_name or self.get_index_name(fields)
		table_name = get_table_name(doctype)
		if not self.has_index(table_name, index_name):
			# Commit pending work before the DDL statement (MariaDB DDL presumably
			# commits implicitly anyway).
			self.commit()
			self.sql("""ALTER TABLE `%s`
				ADD INDEX `%s`(%s)""" % (table_name, index_name, ", ".join(fields)))

	def add_unique(self, doctype, fields, constraint_name=None):
		"""Add a UNIQUE constraint on `fields` of `tab<doctype>` if it is not present.

		:param fields: a single fieldname or a list of fieldnames.
		:param constraint_name: defaults to `unique_<field1>_<field2>...`.
		"""
		if isinstance(fields, str):
			fields = [fields]
		if not constraint_name:
			constraint_name = "unique_" + "_".join(fields)

		if not self.sql("""select CONSTRAINT_NAME from information_schema.TABLE_CONSTRAINTS
			where table_name=%s and constraint_type='UNIQUE' and CONSTRAINT_NAME=%s""",
			('tab' + doctype, constraint_name)):
			self.commit()
			self.sql("""alter table `tab%s`
					add unique `%s`(%s)""" % (doctype, constraint_name, ", ".join(fields)))

	def updatedb(self, doctype, meta=None):
		"""
		Syncs a `DocType` to the table
		* creates if required
		* updates columns
		* updates indices

		Single doctypes (`issingle` set) are skipped.

		:raises Exception: if `doctype` has no row in `tabDocType`.
		"""
		res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
		if not res:
			raise Exception('Wrong doctype {0} in updatedb'.format(doctype))

		if not res[0][0]:
			db_table = MariaDBTable(doctype, meta)
			db_table.validate()

			# Close the current transaction before running schema changes, then
			# open a fresh one afterwards.
			self.commit()
			db_table.sync()
			self.begin()

	def get_database_list(self, target):
		"""Return the names of all databases visible to this connection (`target` is unused)."""
		return [d[0] for d in self.sql("SHOW DATABASES;")]
|