seitime-frappe/frappe/database/postgres/database.py
2019-02-06 16:09:16 +05:30

311 lines
No EOL
9.2 KiB
Python

from __future__ import unicode_literals
import re
import frappe
import psycopg2
import psycopg2.extensions
from six import string_types
from frappe.utils import cstr
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from frappe.database.database import Database
from frappe.database.postgres.schema import PostgresTable
# Have psycopg2 hand back DECIMAL column values as Python floats
# (SQL NULL is passed through as None).
DEC2FLOAT = psycopg2.extensions.new_type(
    psycopg2.extensions.DECIMAL.values,
    'DEC2FLOAT',
    lambda value, curs: None if value is None else float(value),
)
psycopg2.extensions.register_type(DEC2FLOAT)
class PostgresDatabase(Database):
    """PostgreSQL implementation of the generic ``Database`` interface.

    Queries are passed through ``modify_query`` before execution (see
    ``sql``), and driver errors are classified by their PostgreSQL SQLSTATE
    code (``pgcode``).
    """

    # psycopg2 exception classes exposed under backend-neutral names.
    ProgrammingError = psycopg2.ProgrammingError
    OperationalError = psycopg2.OperationalError
    InternalError = psycopg2.InternalError
    SQLError = psycopg2.ProgrammingError
    DataError = psycopg2.DataError
    InterfaceError = psycopg2.InterfaceError
    # PostgreSQL's regular-expression match operator.
    REGEX_CHARACTER = '~'

    def setup_type_map(self):
        """Populate ``self.type_map`` with the column type for each field type.

        Each value is a ``(column_type, size)`` pair; ``size`` looks like the
        length/precision argument and is ``None`` or ``''`` where the column
        type takes none (NOTE(review): confirm against the schema builder).
        """
        decimal = ('decimal', '18,6')
        text = ('text', '')
        varchar = ('varchar', self.VARCHAR_LEN)

        self.type_map = {
            'Currency': decimal,
            'Int': ('bigint', None),
            'Long Int': ('bigint', None),
            'Float': decimal,
            'Percent': decimal,
            'Check': ('smallint', None),
            'Small Text': text,
            'Long Text': text,
            'Code': text,
            'Text Editor': text,
            'Markdown Editor': text,
            'HTML Editor': text,
            'Date': ('date', ''),
            'Datetime': ('timestamp', None),
            'Time': ('time', '6'),
            'Text': text,
            'Data': varchar,
            'Link': varchar,
            'Dynamic Link': varchar,
            'Password': varchar,
            'Select': varchar,
            'Read Only': varchar,
            'Attach': text,
            'Attach Image': text,
            'Signature': text,
            'Color': varchar,
            'Barcode': text,
            'Geolocation': text,
        }

    def get_connection(self):
        """Open and return a new psycopg2 connection in autocommit mode.

        The connection targets ``self.host`` and uses ``self.user`` as the
        database name; no user name or password is supplied.
        """
        conn = psycopg2.connect('host={} dbname={}'.format(self.host, self.user))
        # TODO: Remove this (autocommit) and connect with explicit
        # credentials, i.e. a DSN of the form
        # 'host={} dbname={} user={} password={}' built from self.host,
        # self.user, self.user and self.password.
        conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        return conn

    def escape(self, s, percent=True):
        """Escape quotes and percent in given string.

        :param s: text (``str`` or UTF-8 encoded ``bytes``) to escape.
        :param percent: when true, double every ``%`` so the result can be
            embedded in a query that is later %-formatted.
        :return: the value rendered as a quoted SQL string literal.
        """
        if isinstance(s, bytes):
            s = s.decode('utf-8')

        if percent:
            s = s.replace("%", "%%")

        s = s.encode('utf-8')
        return str(psycopg2.extensions.QuotedString(s))

    def get_database_size(self):
        """Returns database size in MB."""
        # Parameters are passed as a 1-tuple: psycopg2 expects a sequence or
        # mapping, and a bare string would be treated as a sequence of
        # characters if it reached the driver unwrapped.
        db_size = self.sql(
            "SELECT (pg_database_size(%s) / 1024 / 1024) as database_size",
            (self.db_name,), as_dict=True)
        return db_size[0].get('database_size')

    # pylint: disable=W0221
    def sql(self, *args, **kwargs):
        """Run a query after translating it with ``modify_query``.

        The query is taken from the first positional argument or, failing
        that, from a non-empty ``query`` keyword. All other arguments are
        forwarded unchanged to the parent class' ``sql``.
        """
        if args:
            args = (modify_query(args[0]),) + args[1:]
        elif kwargs.get('query'):
            kwargs['query'] = modify_query(kwargs['query'])

        return super(PostgresDatabase, self).sql(*args, **kwargs)

    def get_tables(self):
        """Return the names of the base tables in the ``public`` schema of
        the database named by ``frappe.conf.db_name``."""
        rows = self.sql("""select table_name
            from information_schema.tables
            where table_catalog=%s
                and table_type = 'BASE TABLE'
                and table_schema='public'""", (frappe.conf.db_name,))
        return [row[0] for row in rows]

    def format_date(self, date):
        """Return ``date`` as text carrying a ``::DATE`` cast.

        * a falsy value becomes ``'0001-01-01::DATE'``;
        * a string without ``:`` gets ``::DATE`` appended, while a string
          that already contains ``:`` is returned unchanged;
        * anything else is formatted with ``strftime('%Y-%m-%d')`` and then
          suffixed with ``::DATE``.
        """
        if not date:
            return '0001-01-01::DATE'

        if isinstance(date, string_types):
            if ':' not in date:
                date = date + '::DATE'
        else:
            date = date.strftime('%Y-%m-%d') + '::DATE'

        return date

    # column type
    @staticmethod
    def is_type_number(code):
        """Return True if the column type code is psycopg2's ``NUMBER``."""
        return code == psycopg2.NUMBER

    @staticmethod
    def is_type_datetime(code):
        """Return True if the column type code is psycopg2's ``DATETIME``."""
        return code == psycopg2.DATETIME

    # exception type
    @staticmethod
    def is_deadlocked(e):
        """Return True for SQLSTATE 40P01 (deadlock_detected)."""
        return e.pgcode == '40P01'

    @staticmethod
    def is_timedout(e):
        """Return True if ``e`` is psycopg2's ``QueryCanceledError``."""
        # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError
        return isinstance(e, psycopg2.extensions.QueryCanceledError)

    @staticmethod
    def is_table_missing(e):
        """Return True for SQLSTATE 42P01 (undefined_table)."""
        return e.pgcode == '42P01'

    @staticmethod
    def is_missing_column(e):
        """Return True for SQLSTATE 42703 (undefined_column)."""
        return e.pgcode == '42703'

    @staticmethod
    def is_access_denied(e):
        """Return True for SQLSTATE 42501 (insufficient_privilege)."""
        return e.pgcode == '42501'

    @staticmethod
    def cant_drop_field_or_key(e):
        """Return True for any SQLSTATE in class 23 (integrity constraint
        violation)."""
        # pgcode is None when the backend supplied no error code; treat that
        # as "not a class 23 error" instead of raising AttributeError.
        return (e.pgcode or '').startswith('23')

    @staticmethod
    def is_duplicate_entry(e):
        """Return True for SQLSTATE 23505 (unique_violation)."""
        return e.pgcode == '23505'

    @staticmethod
    def is_primary_key_violation(e):
        """Return True for a unique violation whose message mentions a
        ``_pkey`` constraint."""
        return e.pgcode == '23505' and '_pkey' in cstr(e.args[0])

    @staticmethod
    def is_unique_key_violation(e):
        """Return True for a unique violation whose message mentions a
        ``_key`` constraint."""
        return e.pgcode == '23505' and '_key' in cstr(e.args[0])

    @staticmethod
    def is_duplicate_fieldname(e):
        """Return True for SQLSTATE 42701 (duplicate_column)."""
        return e.pgcode == '42701'

    def create_auth_table(self):
        """Create the ``__Auth`` table unless it already exists."""
        self.sql_ddl("""create table if not exists "__Auth" (
                "doctype" VARCHAR(140) NOT NULL,
                "name" VARCHAR(255) NOT NULL,
                "fieldname" VARCHAR(140) NOT NULL,
                "password" VARCHAR(255) NOT NULL,
                "encrypted" INT NOT NULL DEFAULT 0,
                PRIMARY KEY ("doctype", "name", "fieldname")
            )""")

    def create_global_search_table(self):
        """Create the ``__global_search`` table if ``get_tables`` does not
        list it yet."""
        if '__global_search' not in self.get_tables():
            self.sql('''create table "__global_search"(
                doctype varchar(100),
                name varchar({0}),
                title varchar({0}),
                content text,
                route varchar({0}),
                published int not null default 0,
                unique (doctype, name))'''.format(self.VARCHAR_LEN))

    def create_user_settings_table(self):
        """Create the ``__UserSettings`` table unless it already exists."""
        self.sql_ddl("""create table if not exists "__UserSettings" (
            "user" VARCHAR(180) NOT NULL,
            "doctype" VARCHAR(180) NOT NULL,
            "data" TEXT,
            UNIQUE ("user", "doctype")
            )""")

    def create_help_table(self):
        """Create the ``help`` table and its ``help_index`` index on ``path``.

        Both statements use ``IF NOT EXISTS`` so the method can safely be
        called when the table is already present.
        """
        self.sql('''CREATE TABLE IF NOT EXISTS "help"(
                "path" varchar(255),
                "content" text,
                "title" text,
                "intro" text,
                "full_path" text)''')
        self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''')

    def updatedb(self, doctype, meta=None):
        """
        Syncs a `DocType` to the table
        * creates if required
        * updates columns
        * updates indices

        Nothing is synced when the DocType's ``issingle`` flag is set.

        :raises Exception: if ``doctype`` has no row in ``tabDocType``.
        """
        # The name is bound as a parameter so names containing quotes cannot
        # break (or alter) the statement.
        res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
        if not res:
            raise Exception('Wrong doctype {0} in updatedb'.format(doctype))

        if not res[0][0]:
            db_table = PostgresTable(doctype, meta)
            db_table.validate()

            # The schema sync runs between a commit and a fresh begin.
            self.commit()
            db_table.sync()
            self.begin()

    @staticmethod
    def get_on_duplicate_update(key='name'):
        """Return the ``ON CONFLICT (...) DO UPDATE SET `` clause prefix.

        :param key: a column name, or a list of column names, forming the
            conflict target. Each name is double-quoted.
        """
        if isinstance(key, list):
            key = '", "'.join(key)
        return 'ON CONFLICT ("{key}") DO UPDATE SET '.format(key=key)

    def check_transaction_status(self, query):
        """Accept any query; this backend performs no transaction-status
        check."""
        pass

    def has_index(self, table_name, index_name):
        """Return the (truthy when non-empty) result of looking up
        ``index_name`` on ``table_name`` in ``pg_indexes``."""
        return self.sql(
            """SELECT 1 FROM pg_indexes WHERE tablename=%s
            and indexname=%s limit 1""",
            (table_name, index_name))

    def add_index(self, doctype, fields, index_name=None):
        """Creates an index with given fields if not already created.

        :param doctype: DocType whose ``tab<doctype>`` table gets the index.
        :param fields: list of column names to index.
        :param index_name: explicit index name; defaults to
            ``self.get_index_name(fields)``.
        """
        index_name = index_name or self.get_index_name(fields)
        table_name = 'tab' + doctype

        self.commit()
        # Identifiers cannot be bound as query parameters, hence str.format.
        self.sql("""CREATE INDEX IF NOT EXISTS "{}" ON `{}`("{}")""".format(
            index_name, table_name, '", "'.join(fields)))

    def add_unique(self, doctype, fields, constraint_name=None):
        """Add a UNIQUE constraint over ``fields`` unless a constraint with
        that name already exists on the table.

        :param doctype: DocType whose ``tab<doctype>`` table is altered.
        :param fields: a column name or a list of column names.
        :param constraint_name: explicit constraint name; defaults to
            ``unique_`` followed by the field names joined with ``_``.
        """
        if isinstance(fields, string_types):
            fields = [fields]
        if not constraint_name:
            constraint_name = "unique_" + "_".join(fields)

        if not self.sql("""
            SELECT CONSTRAINT_NAME
            FROM information_schema.TABLE_CONSTRAINTS
            WHERE table_name=%s
            AND constraint_type='UNIQUE'
            AND CONSTRAINT_NAME=%s""",
            ('tab' + doctype, constraint_name)):
            self.commit()
            self.sql("""ALTER TABLE `tab%s`
                    ADD CONSTRAINT %s UNIQUE (%s)""" % (doctype, constraint_name, ", ".join(fields)))

    def get_table_columns_description(self, table_name):
        """Returns list of column and its description.

        Each row (a dict) carries ``name``, ``type``, ``index`` (count of
        matching index definitions), ``default`` and ``unique``.
        """
        # Raw string: the SQL contains the regex '\(.*\)', whose backslashes
        # must reach PostgreSQL untouched. The statement is executed without
        # bound parameters, so its literal '%' characters need no doubling.
        return self.sql(r'''
            SELECT a.column_name AS name,
            CASE a.data_type
                WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')')
                WHEN 'timestamp without TIME zone' THEN 'timestamp'
                ELSE a.data_type
            END AS type,
            COUNT(b.indexdef) AS Index,
            COALESCE(a.column_default, NULL) AS default,
            BOOL_OR(b.unique) AS unique
            FROM information_schema.columns a
            LEFT JOIN
                (SELECT indexdef, tablename, indexdef LIKE '%UNIQUE INDEX%' AS unique
                    FROM pg_indexes
                    WHERE tablename='{table_name}') b
                ON SUBSTRING(b.indexdef, '\(.*\)') LIKE CONCAT('%', a.column_name, '%')
            WHERE a.table_name = '{table_name}'
            GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length;'''
            .format(table_name=table_name), as_dict=1)

    def get_database_list(self, target):
        """Return the names of all databases listed in ``pg_database``.

        ``target`` is accepted but not used by this implementation.
        """
        return [d[0] for d in self.sql("SELECT datname FROM pg_database;")]
def modify_query(query):
    """Modifies query according to the requirements of postgres.

    Three rewrites are applied, in order:

    1. every backtick becomes a double quote;
    2. ``locate(a, b)`` calls are turned into ``strpos(b, a)``;
    3. an unquoted ``from tabXyz`` gets its table name double-quoted
       (only the ASCII letters following ``tab`` are captured).
    """
    # Backticks are used for identifier quoting in the incoming query;
    # postgres wants double quotes for that.
    translated = replace_locate_with_strpos(query.replace('`', '"'))

    # Nothing left to do unless the query selects from a bare tab* name.
    if not re.search('from tab', translated, flags=re.IGNORECASE):
        return translated

    return re.sub('from tab([a-zA-Z]*)', r'from "tab\1"', translated, flags=re.IGNORECASE)
def replace_locate_with_strpos(query):
    """Rewrite ``locate(needle, haystack)`` calls as ``strpos(haystack, needle)``.

    ``strpos`` is the postgres counterpart of ``locate`` and takes its two
    arguments in the opposite order. Matching is case-insensitive; a query
    without any ``locate(`` is returned unchanged.
    """
    if re.search(r'locate\(', query, flags=re.IGNORECASE) is None:
        return query

    # Group 1 is everything up to the first comma, group 2 everything up to
    # the closing parenthesis; they are emitted swapped.
    return re.sub(r'locate\(([^,]+),([^)]+)\)', r'strpos(\2, \1)', query, flags=re.IGNORECASE)