from __future__ import unicode_literals import re import frappe import psycopg2 import psycopg2.extensions from six import string_types from frappe.utils import cstr from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT from frappe.database.database import Database from frappe.database.postgres.schema import PostgresTable # cast decimals as floats DEC2FLOAT = psycopg2.extensions.new_type( psycopg2.extensions.DECIMAL.values, 'DEC2FLOAT', lambda value, curs: float(value) if value is not None else None) psycopg2.extensions.register_type(DEC2FLOAT) class PostgresDatabase(Database): ProgrammingError = psycopg2.ProgrammingError TableMissingError = psycopg2.ProgrammingError OperationalError = psycopg2.OperationalError InternalError = psycopg2.InternalError SQLError = psycopg2.ProgrammingError DataError = psycopg2.DataError InterfaceError = psycopg2.InterfaceError REGEX_CHARACTER = '~' def setup_type_map(self): self.db_type = 'postgres' self.type_map = { 'Currency': ('decimal', '18,6'), 'Int': ('bigint', None), 'Long Int': ('bigint', None), 'Float': ('decimal', '18,6'), 'Percent': ('decimal', '18,6'), 'Check': ('smallint', None), 'Small Text': ('text', ''), 'Long Text': ('text', ''), 'Code': ('text', ''), 'Text Editor': ('text', ''), 'Markdown Editor': ('text', ''), 'HTML Editor': ('text', ''), 'Date': ('date', ''), 'Datetime': ('timestamp', None), 'Time': ('time', '6'), 'Text': ('text', ''), 'Data': ('varchar', self.VARCHAR_LEN), 'Link': ('varchar', self.VARCHAR_LEN), 'Dynamic Link': ('varchar', self.VARCHAR_LEN), 'Password': ('varchar', self.VARCHAR_LEN), 'Select': ('varchar', self.VARCHAR_LEN), 'Rating': ('smallint', None), 'Read Only': ('varchar', self.VARCHAR_LEN), 'Attach': ('text', ''), 'Attach Image': ('text', ''), 'Signature': ('text', ''), 'Color': ('varchar', self.VARCHAR_LEN), 'Barcode': ('text', ''), 'Geolocation': ('text', '') } def get_connection(self): # warnings.filterwarnings('ignore', category=psycopg2.Warning) conn = psycopg2.connect("host='{}' dbname='{}' user='{}' password='{}' port={}".format( self.host, self.user, self.user, self.password, self.port )) conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) # TODO: Remove this return conn def escape(self, s, percent=True): """Excape quotes and percent in given string.""" if isinstance(s, bytes): s = s.decode('utf-8') if percent: s = s.replace("%", "%%") s = s.encode('utf-8') return str(psycopg2.extensions.QuotedString(s)) def get_database_size(self): ''''Returns database size in MB''' db_size = self.sql("SELECT (pg_database_size(%s) / 1024 / 1024) as database_size", self.db_name, as_dict=True) return db_size[0].get('database_size') # pylint: disable=W0221 def sql(self, *args, **kwargs): if len(args): # since tuple is immutable args = list(args) args[0] = modify_query(args[0]) args = tuple(args) elif kwargs.get('query'): kwargs['query'] = modify_query(kwargs.get('query')) return super(PostgresDatabase, self).sql(*args, **kwargs) def get_tables(self): return [d[0] for d in self.sql("""select table_name from information_schema.tables where table_catalog='{0}' and table_type = 'BASE TABLE' and table_schema='{1}'""".format(frappe.conf.db_name, frappe.conf.get("db_schema", "public")))] def format_date(self, date): if not date: return '0001-01-01' if not isinstance(date, frappe.string_types): date = date.strftime('%Y-%m-%d') return date # column type @staticmethod def is_type_number(code): return code == psycopg2.NUMBER @staticmethod def is_type_datetime(code): return code == psycopg2.DATETIME # exception type @staticmethod def is_deadlocked(e): return e.pgcode == '40P01' @staticmethod def is_timedout(e): # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError return isinstance(e, psycopg2.extensions.QueryCanceledError) @staticmethod def is_table_missing(e): return e.pgcode == '42P01' @staticmethod def is_missing_column(e): return e.pgcode == '42703' @staticmethod def is_access_denied(e): return e.pgcode == '42501' @staticmethod def cant_drop_field_or_key(e): return e.pgcode.startswith('23') @staticmethod def is_duplicate_entry(e): return e.pgcode == '23505' @staticmethod def is_primary_key_violation(e): return e.pgcode == '23505' and '_pkey' in cstr(e.args[0]) @staticmethod def is_unique_key_violation(e): return e.pgcode == '23505' and '_key' in cstr(e.args[0]) @staticmethod def is_duplicate_fieldname(e): return e.pgcode == '42701' @staticmethod def is_data_too_long(e): return e.pgcode == '22001' def create_auth_table(self): self.sql_ddl("""create table if not exists "__Auth" ( "doctype" VARCHAR(140) NOT NULL, "name" VARCHAR(255) NOT NULL, "fieldname" VARCHAR(140) NOT NULL, "password" VARCHAR(255) NOT NULL, "encrypted" INT NOT NULL DEFAULT 0, PRIMARY KEY ("doctype", "name", "fieldname") )""") def create_global_search_table(self): if not '__global_search' in self.get_tables(): self.sql('''create table "__global_search"( doctype varchar(100), name varchar({0}), title varchar({0}), content text, route varchar({0}), published int not null default 0, unique (doctype, name))'''.format(self.VARCHAR_LEN)) def create_user_settings_table(self): self.sql_ddl("""create table if not exists "__UserSettings" ( "user" VARCHAR(180) NOT NULL, "doctype" VARCHAR(180) NOT NULL, "data" TEXT, UNIQUE ("user", "doctype") )""") def create_help_table(self): self.sql('''CREATE TABLE "help"( "path" varchar(255), "content" text, "title" text, "intro" text, "full_path" text)''') self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''') def updatedb(self, doctype, meta=None): """ Syncs a `DocType` to the table * creates if required * updates columns * updates indices """ res = self.sql("select issingle from `tabDocType` where name='{}'".format(doctype)) if not res: raise Exception('Wrong doctype {0} in updatedb'.format(doctype)) if not res[0][0]: db_table = PostgresTable(doctype, meta) db_table.validate() self.commit() db_table.sync() self.begin() @staticmethod def get_on_duplicate_update(key='name'): if isinstance(key, list): key = '", "'.join(key) return 'ON CONFLICT ("{key}") DO UPDATE SET '.format( key=key ) def check_transaction_status(self, query): pass def has_index(self, table_name, index_name): return self.sql("""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}' and indexname='{index_name}' limit 1""".format(table_name=table_name, index_name=index_name)) def add_index(self, doctype, fields, index_name=None): """Creates an index with given fields if not already created. Index name will be `fieldname1_fieldname2_index`""" index_name = index_name or self.get_index_name(fields) table_name = 'tab' + doctype self.commit() self.sql("""CREATE INDEX IF NOT EXISTS "{}" ON `{}`("{}")""".format(index_name, table_name, '", "'.join(fields))) def add_unique(self, doctype, fields, constraint_name=None): if isinstance(fields, string_types): fields = [fields] if not constraint_name: constraint_name = "unique_" + "_".join(fields) if not self.sql(""" SELECT CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS WHERE table_name=%s AND constraint_type='UNIQUE' AND CONSTRAINT_NAME=%s""", ('tab' + doctype, constraint_name)): self.commit() self.sql("""ALTER TABLE `tab%s` ADD CONSTRAINT %s UNIQUE (%s)""" % (doctype, constraint_name, ", ".join(fields))) def get_table_columns_description(self, table_name): """Returns list of column and its description""" # pylint: disable=W1401 return self.sql(''' SELECT a.column_name AS name, CASE a.data_type WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')') WHEN 'timestamp without TIME zone' THEN 'timestamp' ELSE a.data_type END AS type, COUNT(b.indexdef) AS Index, COALESCE(a.column_default, NULL) AS default, BOOL_OR(b.unique) AS unique FROM information_schema.columns a LEFT JOIN (SELECT indexdef, tablename, indexdef LIKE '%UNIQUE INDEX%' AS unique FROM pg_indexes WHERE tablename='{table_name}') b ON SUBSTRING(b.indexdef, '\(.*\)') LIKE CONCAT('%', a.column_name, '%') WHERE a.table_name = '{table_name}' GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length;''' .format(table_name=table_name), as_dict=1) def get_database_list(self, target): return [d[0] for d in self.sql("SELECT datname FROM pg_database;")] def modify_query(query): """"Modifies query according to the requirements of postgres""" # replace ` with " for definitions query = query.replace('`', '"') query = replace_locate_with_strpos(query) # select from requires "" if re.search('from tab', query, flags=re.IGNORECASE): query = re.sub('from tab([a-zA-Z]*)', r'from "tab\1"', query, flags=re.IGNORECASE) return query def replace_locate_with_strpos(query): # strpos is the locate equivalent in postgres if re.search(r'locate\(', query, flags=re.IGNORECASE): query = re.sub(r'locate\(([^,]+),([^)]+)\)', r'strpos(\2, \1)', query, flags=re.IGNORECASE) return query