seitime-frappe/frappe/database/postgres/database.py
2022-01-04 11:25:08 +05:30

337 lines
10 KiB
Python

import re
from typing import List, Tuple, Union
import psycopg2
import psycopg2.extensions
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from psycopg2.errorcodes import STRING_DATA_RIGHT_TRUNCATION
import frappe
from frappe.database.database import Database
from frappe.database.postgres.schema import PostgresTable
from frappe.utils import cstr, get_table_name
# Register a psycopg2 typecaster for the DECIMAL type OIDs so those column
# values are handed back as plain Python floats; SQL NULL is passed through
# as None instead of being fed to float().
DEC2FLOAT = psycopg2.extensions.new_type(
    psycopg2.extensions.DECIMAL.values,
    'DEC2FLOAT',
    lambda value, curs: None if value is None else float(value),
)
psycopg2.extensions.register_type(DEC2FLOAT)
class PostgresDatabase(Database):
    """PostgreSQL backend built on psycopg2.

    Every query passed to :meth:`sql` is first rewritten by ``modify_query``
    (backticks to double quotes, ``locate`` to ``strpos``) before it is handed
    to the base ``Database.sql`` implementation.
    """

    # psycopg2 exception classes exposed under the names the rest of the
    # code base catches.
    ProgrammingError = psycopg2.ProgrammingError
    TableMissingError = psycopg2.ProgrammingError
    OperationalError = psycopg2.OperationalError
    InternalError = psycopg2.InternalError
    SQLError = psycopg2.ProgrammingError
    DataError = psycopg2.DataError
    InterfaceError = psycopg2.InterfaceError
    # PostgreSQL's regular-expression match operator.
    REGEX_CHARACTER = '~'

    def setup_type_map(self):
        """Set ``db_type`` and the fieldtype -> (column type, size) mapping."""
        self.db_type = 'postgres'
        self.type_map = {
            'Currency': ('decimal', '21,9'),
            'Int': ('bigint', None),
            'Long Int': ('bigint', None),
            'Float': ('decimal', '21,9'),
            'Percent': ('decimal', '21,9'),
            'Check': ('smallint', None),
            'Small Text': ('text', ''),
            'Long Text': ('text', ''),
            'Code': ('text', ''),
            'Text Editor': ('text', ''),
            'Markdown Editor': ('text', ''),
            'HTML Editor': ('text', ''),
            'Date': ('date', ''),
            'Datetime': ('timestamp', None),
            'Time': ('time', '6'),
            'Text': ('text', ''),
            'Data': ('varchar', self.VARCHAR_LEN),
            'Link': ('varchar', self.VARCHAR_LEN),
            'Dynamic Link': ('varchar', self.VARCHAR_LEN),
            'Password': ('text', ''),
            'Select': ('varchar', self.VARCHAR_LEN),
            'Rating': ('decimal', '3,2'),
            'Read Only': ('varchar', self.VARCHAR_LEN),
            'Attach': ('text', ''),
            'Attach Image': ('text', ''),
            'Signature': ('text', ''),
            'Color': ('varchar', self.VARCHAR_LEN),
            'Barcode': ('text', ''),
            'Geolocation': ('text', ''),
            'Duration': ('decimal', '21,9'),
            'Icon': ('varchar', self.VARCHAR_LEN),
        }

    def get_connection(self):
        """Open and return a psycopg2 connection in autocommit mode."""
        # NOTE(review): ``dbname`` is filled from ``self.user`` -- presumably
        # the database and the role share one name; confirm against the
        # base class before changing.
        conn = psycopg2.connect("host='{}' dbname='{}' user='{}' password='{}' port={}".format(
            self.host, self.user, self.user, self.password, self.port
        ))
        conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)  # TODO: Remove this
        return conn

    def escape(self, s, percent=True):
        """Escape quotes and percent in given string.

        :param s: ``str`` or UTF-8 encoded ``bytes`` to quote.
        :param percent: when true, double every ``%`` so the result can be
            embedded in a query that later goes through %-parameter binding.
        :return: ``str()`` of the ``psycopg2.extensions.QuotedString`` wrapper.
        """
        if isinstance(s, bytes):
            s = s.decode('utf-8')
        if percent:
            s = s.replace("%", "%%")
        s = s.encode('utf-8')
        return str(psycopg2.extensions.QuotedString(s))

    def get_database_size(self):
        """Returns database size in MB"""
        db_size = self.sql(
            "SELECT (pg_database_size(%s) / 1024 / 1024) as database_size",
            self.db_name, as_dict=True)
        return db_size[0].get('database_size')

    # pylint: disable=W0221
    def sql(self, *args, **kwargs):
        """Run a query after adapting its text with ``modify_query``.

        The query may be given as the first positional argument or as the
        ``query`` keyword; everything else is forwarded to the base class
        unchanged.
        """
        if args:
            args = (modify_query(args[0]), *args[1:])
        elif kwargs.get('query'):
            kwargs['query'] = modify_query(kwargs['query'])
        return super().sql(*args, **kwargs)

    def get_tables(self):
        """Return the names of all base tables in the configured schema."""
        return [d[0] for d in self.sql("""select table_name
            from information_schema.tables
            where table_catalog='{0}'
                and table_type = 'BASE TABLE'
                and table_schema='{1}'""".format(
            frappe.conf.db_name, frappe.conf.get("db_schema", "public")))]

    def format_date(self, date):
        """Return ``date`` as ``YYYY-MM-DD``; falsy input maps to ``0001-01-01``.

        Strings are returned untouched; anything else must offer ``strftime``.
        """
        if not date:
            return '0001-01-01'
        if not isinstance(date, str):
            date = date.strftime('%Y-%m-%d')
        return date

    # column type
    @staticmethod
    def is_type_number(code):
        """True when a cursor type code compares equal to ``psycopg2.NUMBER``."""
        return code == psycopg2.NUMBER

    @staticmethod
    def is_type_datetime(code):
        """True when a cursor type code compares equal to ``psycopg2.DATETIME``."""
        return code == psycopg2.DATETIME

    # exception type
    #
    # ``pgcode`` is read with getattr() throughout: it is absent on
    # non-psycopg2 exceptions and may be None on psycopg2 errors that did not
    # originate from the server, and these predicates must answer False in
    # that case rather than raise.
    @staticmethod
    def is_deadlocked(e):
        """True for SQLSTATE 40P01 (deadlock_detected)."""
        return getattr(e, 'pgcode', None) == '40P01'

    @staticmethod
    def is_timedout(e):
        """True when the error is a ``QueryCanceledError``."""
        # http://initd.org/psycopg/docs/extensions.html?highlight=datatype#psycopg2.extensions.QueryCanceledError
        return isinstance(e, psycopg2.extensions.QueryCanceledError)

    @staticmethod
    def is_syntax_error(e):
        """True when the error is a ``psycopg2.errors.SyntaxError``."""
        return isinstance(e, psycopg2.errors.SyntaxError)

    @staticmethod
    def is_table_missing(e):
        """True for SQLSTATE 42P01 (undefined_table)."""
        return getattr(e, 'pgcode', None) == '42P01'

    @staticmethod
    def is_missing_column(e):
        """True for SQLSTATE 42703 (undefined_column)."""
        return getattr(e, 'pgcode', None) == '42703'

    @staticmethod
    def is_access_denied(e):
        """True for SQLSTATE 42501 (insufficient_privilege)."""
        return getattr(e, 'pgcode', None) == '42501'

    @staticmethod
    def cant_drop_field_or_key(e):
        """True for any SQLSTATE in class 23 (integrity constraint violation)."""
        return (getattr(e, 'pgcode', None) or '').startswith('23')

    @staticmethod
    def is_duplicate_entry(e):
        """True for SQLSTATE 23505 (unique_violation)."""
        return getattr(e, 'pgcode', None) == '23505'

    @staticmethod
    def is_primary_key_violation(e):
        """True for a unique violation whose message mentions a ``_pkey`` name."""
        return getattr(e, 'pgcode', None) == '23505' and '_pkey' in cstr(e.args[0])

    @staticmethod
    def is_unique_key_violation(e):
        """True for a unique violation whose message mentions a ``_key`` name."""
        return getattr(e, 'pgcode', None) == '23505' and '_key' in cstr(e.args[0])

    @staticmethod
    def is_duplicate_fieldname(e):
        """True for SQLSTATE 42701 (duplicate_column)."""
        return getattr(e, 'pgcode', None) == '42701'

    @staticmethod
    def is_data_too_long(e):
        """True when the SQLSTATE equals ``STRING_DATA_RIGHT_TRUNCATION``."""
        return getattr(e, 'pgcode', None) == STRING_DATA_RIGHT_TRUNCATION

    def rename_table(self, old_name: str, new_name: str) -> Union[List, Tuple]:
        """Rename the table of doctype ``old_name`` to that of ``new_name``."""
        old_name = get_table_name(old_name)
        new_name = get_table_name(new_name)
        return self.sql(f"ALTER TABLE `{old_name}` RENAME TO `{new_name}`")

    def describe(self, doctype: str) -> Union[List, Tuple]:
        """Return the column names of the doctype's table."""
        table_name = get_table_name(doctype)
        # Bound parameter: the name is compared as a value, never spliced
        # into the statement text.
        return self.sql(
            "SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = %s",
            (table_name,))

    def change_column_type(self, doctype: str, column: str, type: str, nullable: bool = False) -> Union[List, Tuple]:
        """Alter a column's type and set or drop its NOT NULL constraint.

        :param doctype: doctype whose table is altered.
        :param column: column name.
        :param type: target SQL type, inserted verbatim into the statement.
        :param nullable: ``True`` drops NOT NULL, ``False`` sets it.
        """
        table_name = get_table_name(doctype)
        null_constraint = "SET NOT NULL" if not nullable else "DROP NOT NULL"
        return self.sql(f"""ALTER TABLE "{table_name}"
            ALTER COLUMN "{column}" TYPE {type},
            ALTER COLUMN "{column}" {null_constraint}""")

    def create_auth_table(self):
        """Create the ``__Auth`` table unless it already exists."""
        self.sql_ddl("""create table if not exists "__Auth" (
            "doctype" VARCHAR(140) NOT NULL,
            "name" VARCHAR(255) NOT NULL,
            "fieldname" VARCHAR(140) NOT NULL,
            "password" TEXT NOT NULL,
            "encrypted" INT NOT NULL DEFAULT 0,
            PRIMARY KEY ("doctype", "name", "fieldname")
        )""")

    def create_global_search_table(self):
        """Create ``__global_search`` if ``get_tables()`` does not list it."""
        if '__global_search' not in self.get_tables():
            self.sql('''create table "__global_search"(
                doctype varchar(100),
                name varchar({0}),
                title varchar({0}),
                content text,
                route varchar({0}),
                published int not null default 0,
                unique (doctype, name))'''.format(self.VARCHAR_LEN))

    def create_user_settings_table(self):
        """Create the ``__UserSettings`` table unless it already exists."""
        self.sql_ddl("""create table if not exists "__UserSettings" (
            "user" VARCHAR(180) NOT NULL,
            "doctype" VARCHAR(180) NOT NULL,
            "data" TEXT,
            UNIQUE ("user", "doctype")
        )""")

    def create_help_table(self):
        """Create the ``help`` table and an index on its ``path`` column."""
        self.sql('''CREATE TABLE "help"(
            "path" varchar(255),
            "content" text,
            "title" text,
            "intro" text,
            "full_path" text)''')
        self.sql('''CREATE INDEX IF NOT EXISTS "help_index" ON "help" ("path")''')

    def updatedb(self, doctype, meta=None):
        """
        Syncs a `DocType` to the table
        * creates if required
        * updates columns
        * updates indices

        Raises ``Exception`` when ``doctype`` has no row in ``tabDocType``.
        Doctypes flagged ``issingle`` are skipped.
        """
        # Bound parameter so a doctype name containing a quote cannot break
        # (or alter) the statement.
        res = self.sql("select issingle from `tabDocType` where name=%s", (doctype,))
        if not res:
            raise Exception('Wrong doctype {0} in updatedb'.format(doctype))

        if not res[0][0]:
            db_table = PostgresTable(doctype, meta)
            db_table.validate()

            self.commit()
            db_table.sync()
            self.begin()

    @staticmethod
    def get_on_duplicate_update(key='name'):
        """Return the ``ON CONFLICT (...) DO UPDATE SET`` prefix for ``key``.

        ``key`` is a column name or a list of column names.
        """
        if isinstance(key, list):
            key = '", "'.join(key)
        return 'ON CONFLICT ("{key}") DO UPDATE SET '.format(
            key=key
        )

    def check_transaction_status(self, query):
        """No-op override: this backend performs no per-query check here."""
        pass

    def has_index(self, table_name, index_name):
        """Return a non-empty result when ``index_name`` exists on ``table_name``."""
        return self.sql(
            """SELECT 1 FROM pg_indexes WHERE tablename=%s
            and indexname=%s limit 1""",
            (table_name, index_name))

    def add_index(self, doctype: str, fields: List, index_name: str = None):
        """Creates an index with given fields if not already created.
        Index name will be `fieldname1_fieldname2_index`"""
        table_name = get_table_name(doctype)
        index_name = index_name or self.get_index_name(fields)
        # Drop any parenthesised suffix such as a "(10)" prefix length from
        # each field before quoting it.
        fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)

        self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')

    def add_unique(self, doctype, fields, constraint_name=None):
        """Add a UNIQUE constraint over ``fields`` unless one with that name exists.

        :param fields: a column name or a list of column names.
        :param constraint_name: defaults to ``unique_<field1>_<field2>...``.
        """
        if isinstance(fields, str):
            fields = [fields]
        if not constraint_name:
            constraint_name = "unique_" + "_".join(fields)

        existing = self.sql("""
            SELECT CONSTRAINT_NAME
            FROM information_schema.TABLE_CONSTRAINTS
            WHERE table_name=%s
            AND constraint_type='UNIQUE'
            AND CONSTRAINT_NAME=%s""",
            ('tab' + doctype, constraint_name))
        if not existing:
            self.commit()
            columns = ", ".join(fields)
            self.sql(f"""ALTER TABLE `tab{doctype}`
                ADD CONSTRAINT {constraint_name} UNIQUE ({columns})""")

    def get_table_columns_description(self, table_name):
        """Returns list of column and its description"""
        # Raw string: the SQL itself contains the regex '\(.*\)', whose
        # backslashes must reach PostgreSQL unchanged.
        return self.sql(r'''
            SELECT a.column_name AS name,
            CASE LOWER(a.data_type)
                WHEN 'character varying' THEN CONCAT('varchar(', a.character_maximum_length ,')')
                WHEN 'timestamp without time zone' THEN 'timestamp'
                ELSE a.data_type
            END AS type,
            COUNT(b.indexdef) AS Index,
            SPLIT_PART(COALESCE(a.column_default, NULL), '::', 1) AS default,
            BOOL_OR(b.unique) AS unique
            FROM information_schema.columns a
            LEFT JOIN
                (SELECT indexdef, tablename, indexdef LIKE '%UNIQUE INDEX%' AS unique
                    FROM pg_indexes
                    WHERE tablename='{table_name}') b
                ON SUBSTRING(b.indexdef, '\(.*\)') LIKE CONCAT('%', a.column_name, '%')
            WHERE a.table_name = '{table_name}'
            GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length;'''
            .format(table_name=table_name), as_dict=1)

    def get_database_list(self, target):
        """Return every database name in ``pg_database``; ``target`` is unused."""
        return [d[0] for d in self.sql("SELECT datname FROM pg_database;")]
def modify_query(query):
    """Modifies query according to the requirements of postgres.

    Three rewrites are applied, in order:

    * every backtick becomes a double quote (identifier quoting);
    * ``locate(a, b)`` calls become ``strpos(b, a)``;
    * an unquoted ``from tabName`` gets its table name double-quoted.

    :param query: the query; it is coerced with ``str()`` first.
    :return: the rewritten query string.
    """
    # replace ` with " for definitions
    query = str(query).replace('`', '"')
    query = replace_locate_with_strpos(query)
    # select from requires "": PostgreSQL folds unquoted identifiers to lower
    # case, so a bare mixed-case tabName would not resolve. The name may
    # contain digits, underscores and hyphens, hence [\w-] rather than
    # letters only -- otherwise the closing quote lands mid-name.
    return re.sub(r'from tab([\w-]*)', r'from "tab\1"', query, flags=re.IGNORECASE)
def replace_locate_with_strpos(query):
    """Rewrite ``locate(needle, haystack)`` calls as ``strpos(haystack, needle)``.

    strpos is the locate equivalent in postgres, with the two arguments in
    the opposite order. Matching is case-insensitive; a query without a
    complete ``locate(x, y)`` call is returned unchanged.
    """
    locate_call = re.compile(r'locate\(([^,]+),([^)]+)\)', flags=re.IGNORECASE)
    return locate_call.sub(r'strpos(\2, \1)', query)