import re
from typing import List, Tuple, Union

import psycopg2
import psycopg2.extensions
from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ
from psycopg2.errorcodes import STRING_DATA_RIGHT_TRUNCATION

import frappe
from frappe.database.database import Database
from frappe.database.postgres.schema import PostgresTable
from frappe.utils import cstr, get_table_name

# cast decimals as floats
DEC2FLOAT = psycopg2.extensions.new_type(
    psycopg2.extensions.DECIMAL.values,
    'DEC2FLOAT',
    lambda value, curs: float(value) if value is not None else None)

psycopg2.extensions.register_type(DEC2FLOAT)


class PostgresDatabase(Database):
    """PostgreSQL implementation of the `Database` interface.

    Every query passed to `sql` is first rewritten by `modify_query`, so
    callers may use backtick-quoted identifiers and `locate(...)`.
    """

    ProgrammingError = psycopg2.ProgrammingError
    TableMissingError = psycopg2.ProgrammingError
    OperationalError = psycopg2.OperationalError
    InternalError = psycopg2.InternalError
    SQLError = psycopg2.ProgrammingError
    DataError = psycopg2.DataError
    InterfaceError = psycopg2.InterfaceError
    REGEX_CHARACTER = '~'

    def setup_type_map(self):
        """Set `db_type` and the fieldtype -> (column type, length) map."""
        self.db_type = 'postgres'
        self.type_map = {
            'Currency': ('decimal', '21,9'),
            'Int': ('bigint', None),
            'Long Int': ('bigint', None),
            'Float': ('decimal', '21,9'),
            'Percent': ('decimal', '21,9'),
            'Check': ('smallint', None),
            'Small Text': ('text', ''),
            'Long Text': ('text', ''),
            'Code': ('text', ''),
            'Text Editor': ('text', ''),
            'Markdown Editor': ('text', ''),
            'HTML Editor': ('text', ''),
            'Date': ('date', ''),
            'Datetime': ('timestamp', None),
            'Time': ('time', '6'),
            'Text': ('text', ''),
            'Data': ('varchar', self.VARCHAR_LEN),
            'Link': ('varchar', self.VARCHAR_LEN),
            'Dynamic Link': ('varchar', self.VARCHAR_LEN),
            'Password': ('text', ''),
            'Select': ('varchar', self.VARCHAR_LEN),
            'Rating': ('decimal', '3,2'),
            'Read Only': ('varchar', self.VARCHAR_LEN),
            'Attach': ('text', ''),
            'Attach Image': ('text', ''),
            'Signature': ('text', ''),
            'Color': ('varchar', self.VARCHAR_LEN),
            'Barcode': ('text', ''),
            'Geolocation': ('text', ''),
            'Duration': ('decimal', '21,9'),
            'Icon': ('varchar', self.VARCHAR_LEN)
        }

    def get_connection(self):
        """Open a psycopg2 connection at REPEATABLE READ isolation."""
        # NOTE(review): dbname is filled from self.user, not self.db_name --
        # presumably the database and its role share a name; confirm before
        # changing.
        conn = psycopg2.connect("host='{}' dbname='{}' user='{}' password='{}' port={}".format(
            self.host, self.user, self.user, self.password, self.port
        ))
        conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)

        return conn

    def escape(self, s, percent=True):
        """Escape quotes and percent in given string."""
        if isinstance(s, bytes):
            s = s.decode('utf-8')

        # MariaDB's driver treats None as an empty string
        # So Postgres should do the same
        if s is None:
            s = ''

        if percent:
            s = s.replace("%", "%%")

        s = s.encode('utf-8')

        return str(psycopg2.extensions.QuotedString(s))

    def get_database_size(self):
        """Returns database size in MB"""
        db_size = self.sql(
            "SELECT (pg_database_size(%s) / 1024 / 1024) as database_size",
            self.db_name, as_dict=True)
        return db_size[0].get('database_size')

    # pylint: disable=W0221
    def sql(self, *args, **kwargs):
        """Run a query after rewriting it with `modify_query`.

        The query is taken from the first positional argument, or from the
        `query` keyword when no positional arguments are given. Everything
        else is forwarded unchanged to the parent `sql`.
        """
        if args:
            args = (modify_query(args[0]),) + args[1:]
        elif kwargs.get('query'):
            kwargs['query'] = modify_query(kwargs.get('query'))

        return super().sql(*args, **kwargs)

    def get_tables(self, cached=True):
        """Return the names of all base tables in the configured schema.

        The schema is `frappe.conf.db_schema`, defaulting to "public".
        `cached` is accepted but not used by this implementation.
        """
        return [d[0] for d in self.sql(
            "select table_name from information_schema.tables "
            "where table_catalog='{0}' "
            "and table_type = 'BASE TABLE' "
            "and table_schema='{1}'".format(
                frappe.conf.db_name, frappe.conf.get("db_schema", "public")))]

    def format_date(self, date):
        """Return `date` as a 'YYYY-MM-DD' string ('0001-01-01' if falsy)."""
        if not date:
            return '0001-01-01'

        if not isinstance(date, str):
            date = date.strftime('%Y-%m-%d')

        return date

    # column type
    @staticmethod
    def is_type_number(code):
        return code == psycopg2.NUMBER

    @staticmethod
    def is_type_datetime(code):
        return code == psycopg2.DATETIME

    # exception type
    #
    # The predicates below read `pgcode` through getattr so they answer
    # False, instead of raising, for exceptions that carry no `pgcode`
    # attribute or whose `pgcode` is None.
    @staticmethod
    def is_deadlocked(e):
        return getattr(e, 'pgcode', None) == '40P01'

    @staticmethod
    def is_timedout(e):
        # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError
        return isinstance(e, psycopg2.extensions.QueryCanceledError)

    @staticmethod
    def is_syntax_error(e):
        return isinstance(e, psycopg2.errors.SyntaxError)

    @staticmethod
    def is_table_missing(e):
        return getattr(e, 'pgcode', None) == '42P01'

    @staticmethod
    def is_missing_column(e):
        return getattr(e, 'pgcode', None) == '42703'

    @staticmethod
    def is_access_denied(e):
        return getattr(e, 'pgcode', None) == '42501'

    @staticmethod
    def cant_drop_field_or_key(e):
        # `or ''` covers both a missing attribute and pgcode being None.
        return (getattr(e, 'pgcode', None) or '').startswith('23')

    @staticmethod
    def is_duplicate_entry(e):
        return getattr(e, 'pgcode', None) == '23505'

    @staticmethod
    def is_primary_key_violation(e):
        if getattr(e, 'pgcode', None) != '23505':
            return False
        message = cstr(e.args[0]) if e.args else ''
        return '_pkey' in message

    @staticmethod
    def is_unique_key_violation(e):
        if getattr(e, 'pgcode', None) != '23505':
            return False
        message = cstr(e.args[0]) if e.args else ''
        return '_key' in message

    @staticmethod
    def is_duplicate_fieldname(e):
        return getattr(e, 'pgcode', None) == '42701'

    @staticmethod
    def is_data_too_long(e):
        return getattr(e, 'pgcode', None) == STRING_DATA_RIGHT_TRUNCATION

    def rename_table(self, old_name: str, new_name: str) -> Union[List, Tuple]:
        """Rename the table of doctype `old_name` to that of `new_name`."""
        old_name = get_table_name(old_name)
        new_name = get_table_name(new_name)
        return self.sql(f"ALTER TABLE `{old_name}` RENAME TO `{new_name}`")

    def describe(self, doctype: str) -> Union[List, Tuple]:
        """Return the column names of the table backing `doctype`."""
        table_name = get_table_name(doctype)
        # Bound as a parameter so names containing quotes cannot break the
        # statement.
        return self.sql(
            "SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = %s",
            (table_name,))

    def change_column_type(self, doctype: str, column: str, type: str, nullable: bool = False) -> Union[List, Tuple]:
        """Change a column's type and set or drop its NOT NULL constraint.

        `type` is interpolated verbatim as the new column type. When
        `nullable` is false the column becomes NOT NULL; otherwise NOT NULL
        is dropped.
        """
        table_name = get_table_name(doctype)
        null_constraint = "SET NOT NULL" if not nullable else "DROP NOT NULL"
        return self.sql(
            f'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {type}, '
            f'ALTER COLUMN "{column}" {null_constraint}')

    def create_auth_table(self):
        """Create the "__Auth" table if it does not exist."""
        self.sql_ddl(
            'create table if not exists "__Auth" ( '
            '"doctype" VARCHAR(140) NOT NULL, '
            '"name" VARCHAR(255) NOT NULL, '
            '"fieldname" VARCHAR(140) NOT NULL, '
            '"password" TEXT NOT NULL, '
            '"encrypted" INT NOT NULL DEFAULT 0, '
            'PRIMARY KEY ("doctype", "name", "fieldname") )')

    def create_global_search_table(self):
        """Create "__global_search" unless `get_tables` already lists it."""
        if '__global_search' not in self.get_tables():
            self.sql((
                'create table "__global_search"( '
                'doctype varchar(100), '
                'name varchar({0}), '
                'title varchar({0}), '
                'content text, '
                'route varchar({0}), '
                'published int not null default 0, '
                'unique (doctype, name))'
            ).format(self.VARCHAR_LEN))

    def create_user_settings_table(self):
        """Create the "__UserSettings" table if it does not exist."""
        self.sql_ddl(
            'create table if not exists "__UserSettings" ( '
            '"user" VARCHAR(180) NOT NULL, '
            '"doctype" VARCHAR(180) NOT NULL, '
            '"data" TEXT, '
            'UNIQUE ("user", "doctype") )')

    def create_help_table(self):
        """Create the "help" table and its index on "path".

        The CREATE TABLE statement has no IF NOT EXISTS guard; only the
        index creation does.
        """
        self.sql(
            'CREATE TABLE "help"( '
            '"path" varchar(255), '
            '"content" text, '
            '"title" text, '
            '"intro" text, '
            '"full_path" text)')
        self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''')

    def updatedb(self, doctype, meta=None):
        """
        Syncs a `DocType` to the table
           * creates if required
           * updates columns
           * updates indices

        Raises `Exception` when no `tabDocType` row exists for `doctype`.
        Doctypes whose `issingle` flag is set are left untouched.
        """
        # The doctype name is bound as a parameter so a name containing a
        # quote cannot break the lookup.
        res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
        if not res:
            raise Exception('Wrong doctype {0} in updatedb'.format(doctype))

        if not res[0][0]:
            db_table = PostgresTable(doctype, meta)
            db_table.validate()

            self.commit()
            db_table.sync()
            self.begin()

    @staticmethod
    def get_on_duplicate_update(key='name'):
        """Return the `ON CONFLICT (...) DO UPDATE SET ` clause prefix.

        `key` is a column name, or a list of column names that are joined
        into the conflict target.
        """
        if isinstance(key, list):
            key = '", "'.join(key)
        return 'ON CONFLICT ("{key}") DO UPDATE SET '.format(
            key=key
        )

    def check_implicit_commit(self, query):
        pass  # postgres can run DDL in transactions without implicit commits

    def has_index(self, table_name, index_name):
        """Return a non-empty result if `index_name` exists on `table_name`."""
        return self.sql(
            "SELECT 1 FROM pg_indexes WHERE tablename=%s and indexname=%s limit 1",
            (table_name, index_name))

    def add_index(self, doctype: str, fields: List, index_name: str = None):
        """Creates an index with given fields if not already created.

        Index name will be `fieldname1_fieldname2_index` unless `index_name`
        is given. Any parenthesised suffix on a field (e.g. `name(10)`) is
        removed before the field is used as a column name.
        """
        table_name = get_table_name(doctype)
        index_name = index_name or self.get_index_name(fields)
        fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)

        self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')

    def add_unique(self, doctype, fields, constraint_name=None):
        """Add a UNIQUE constraint over `fields` unless it already exists.

        `fields` may be one column name or a list of them. The constraint
        name defaults to "unique_" followed by the fields joined with "_".
        A commit is issued before the ALTER TABLE.
        """
        if isinstance(fields, str):
            fields = [fields]
        if not constraint_name:
            constraint_name = "unique_" + "_".join(fields)

        if not self.sql(
                " SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS "
                "WHERE table_name=%s AND constraint_type='UNIQUE' AND CONSTRAINT_NAME=%s",
                ('tab' + doctype, constraint_name)):
            self.commit()
            self.sql("""ALTER TABLE `tab%s` ADD CONSTRAINT %s UNIQUE (%s)""" % (
                doctype, constraint_name, ", ".join(fields)))

    def get_table_columns_description(self, table_name):
        """Returns list of column and its description"""
        # pylint: disable=W1401
        # No values are passed to `sql`, so the literal `%` characters in
        # the LIKE patterns below are sent as written.
        query = (
            " SELECT a.column_name AS name, "
            "CASE LOWER(a.data_type) "
            "WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')') "
            "WHEN 'timestamp without time zone' THEN 'timestamp' "
            "ELSE a.data_type END AS type, "
            "BOOL_OR(b.index) AS index, "
            "SPLIT_PART(COALESCE(a.column_default, NULL), '::', 1) AS default, "
            "BOOL_OR(b.unique) AS unique "
            "FROM information_schema.columns a "
            "LEFT JOIN (SELECT indexdef, tablename, "
            "indexdef LIKE '%UNIQUE INDEX%' AS unique, "
            "indexdef NOT LIKE '%UNIQUE INDEX%' AS index "
            "FROM pg_indexes WHERE tablename='{table_name}') b "
            "ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%') "
            "WHERE a.table_name = '{table_name}' "
            "GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length; "
        )
        return self.sql(query.format(table_name=table_name), as_dict=1)

    def get_database_list(self, target):
        """Return every database name in `pg_database`; `target` is unused."""
        return [d[0] for d in self.sql("SELECT datname FROM pg_database;")]


def modify_query(query):
    """Modifies query according to the requirements of postgres.

    * backticks are replaced with double quotes
    * `locate(a, b)` is rewritten to `strpos(b, a)`
    * an unquoted `from tabXyz` table reference is wrapped in double quotes
    """
    # replace ` with " for definitions
    query = str(query)
    query = query.replace('`', '"')
    query = replace_locate_with_strpos(query)
    # select from requires ""
    if re.search('from tab', query, flags=re.IGNORECASE):
        # \w rather than letters only, so names containing digits or
        # underscores are quoted whole instead of being cut mid-name.
        query = re.sub(r'from tab(\w*)', r'from "tab\1"', query, flags=re.IGNORECASE)

    return query


def replace_locate_with_strpos(query):
    """Rewrite `locate(needle, haystack)` as `strpos(haystack, needle)`."""
    # strpos is the locate equivalent in postgres
    if re.search(r'locate\(', query, flags=re.IGNORECASE):
        query = re.sub(r'locate\(([^,]+),([^)]+)\)', r'strpos(\2, \1)', query, flags=re.IGNORECASE)
    return query