Merge pull request #27000 from pgruener/bugfix/postgres_schema_support
fix: Fixed schema isolation/support for postgres connectivity.
This commit is contained in:
commit
357307ffde
2 changed files with 239 additions and 10 deletions
|
|
@ -2,6 +2,7 @@ import re
|
|||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
from psycopg2 import sql
|
||||
from psycopg2.errorcodes import (
|
||||
CLASS_INTEGRITY_CONSTRAINT_VIOLATION,
|
||||
DEADLOCK_DETECTED,
|
||||
|
|
@ -169,6 +170,15 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
def last_query(self):
|
||||
return LazyDecode(self._cursor.query)
|
||||
|
||||
@property
|
||||
def db_schema(self):
|
||||
return frappe.conf.get("db_schema", "public").replace("'", "").replace('"', "")
|
||||
|
||||
def connect(self):
|
||||
super().connect()
|
||||
|
||||
self._cursor.execute("SET search_path TO %s", (self.db_schema,))
|
||||
|
||||
def get_connection(self):
|
||||
conn_settings = {
|
||||
"dbname": self.cur_db_name,
|
||||
|
|
@ -228,12 +238,34 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
for d in self.sql(
|
||||
"""select table_name
|
||||
from information_schema.tables
|
||||
where table_catalog='{}'
|
||||
where table_catalog=%s
|
||||
and table_type = 'BASE TABLE'
|
||||
and table_schema='{}'""".format(self.cur_db_name, frappe.conf.get("db_schema", "public"))
|
||||
and table_schema=%s""",
|
||||
(self.cur_db_name, self.db_schema),
|
||||
)
|
||||
]
|
||||
|
||||
def get_db_table_columns(self, table) -> list[str]:
|
||||
"""Returns list of column names from given table."""
|
||||
if (columns := frappe.cache.hget("table_columns", table)) is not None:
|
||||
return columns
|
||||
|
||||
information_schema = frappe.qb.Schema("information_schema")
|
||||
|
||||
columns = (
|
||||
frappe.qb.from_(information_schema.columns)
|
||||
.select(information_schema.columns.column_name)
|
||||
.where(
|
||||
(information_schema.columns.table_name == table)
|
||||
& (information_schema.columns.table_schema == self.db_schema)
|
||||
)
|
||||
.run(pluck=True)
|
||||
)
|
||||
|
||||
frappe.cache.hset("table_columns", table, columns)
|
||||
|
||||
return columns
|
||||
|
||||
def format_date(self, date):
|
||||
if not date:
|
||||
return "0001-01-01"
|
||||
|
|
@ -260,7 +292,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
def describe(self, doctype: str) -> list | tuple:
|
||||
table_name = get_table_name(doctype)
|
||||
return self.sql(
|
||||
f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'"
|
||||
f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}' and table_schema='{frappe.conf.get('db_schema', 'public')}'"
|
||||
)
|
||||
|
||||
def change_column_type(
|
||||
|
|
@ -349,8 +381,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
|
||||
def has_index(self, table_name, index_name):
|
||||
return self.sql(
|
||||
f"""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}'
|
||||
and indexname='{index_name}' limit 1"""
|
||||
"""SELECT 1 FROM pg_indexes WHERE tablename=%s
|
||||
and schemaname = %s
|
||||
and indexname=%s limit 1""",
|
||||
(table_name, self.db_schema, index_name),
|
||||
)
|
||||
|
||||
def add_index(self, doctype: str, fields: list, index_name: str | None = None):
|
||||
|
|
@ -360,7 +394,9 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
index_name = index_name or self.get_index_name(fields)
|
||||
fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)
|
||||
|
||||
self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')
|
||||
self.sql_ddl(
|
||||
f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")'
|
||||
)
|
||||
|
||||
def add_unique(self, doctype, fields, constraint_name=None):
|
||||
if isinstance(fields, str):
|
||||
|
|
@ -374,13 +410,24 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
FROM information_schema.TABLE_CONSTRAINTS
|
||||
WHERE table_name=%s
|
||||
AND constraint_type='UNIQUE'
|
||||
AND constraint_schema=%s
|
||||
AND CONSTRAINT_NAME=%s""",
|
||||
("tab" + doctype, constraint_name),
|
||||
("tab" + doctype, self.db_schema, constraint_name),
|
||||
):
|
||||
self.commit()
|
||||
|
||||
self.sql(
|
||||
"""ALTER TABLE `tab{}`
|
||||
ADD CONSTRAINT {} UNIQUE ({})""".format(doctype, constraint_name, ", ".join(fields))
|
||||
sql.SQL(
|
||||
"""ALTER TABLE {schema}.{table}
|
||||
ADD CONSTRAINT {constraint} UNIQUE ({fields})"""
|
||||
)
|
||||
.format(
|
||||
schema=sql.Identifier(self.db_schema),
|
||||
table=sql.Identifier("tab" + doctype),
|
||||
constraint=sql.Identifier(constraint_name),
|
||||
fields=sql.SQL(", ").join(sql.Identifier(field) for field in fields),
|
||||
)
|
||||
.as_string(self._conn)
|
||||
)
|
||||
|
||||
def get_table_columns_description(self, table_name):
|
||||
|
|
@ -404,9 +451,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
indexdef LIKE '%UNIQUE INDEX%' AS unique,
|
||||
indexdef NOT LIKE '%UNIQUE INDEX%' AS index
|
||||
FROM pg_indexes
|
||||
WHERE tablename='{table_name}') b
|
||||
WHERE tablename='{table_name}' AND schemaname='{self.db_schema}') b
|
||||
ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%')
|
||||
WHERE a.table_name = '{table_name}'
|
||||
AND a.table_schema = '{self.db_schema}'
|
||||
GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length, a.is_nullable;
|
||||
""",
|
||||
as_dict=1,
|
||||
|
|
@ -423,6 +471,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
.where(
|
||||
(information_schema.columns.table_name == table)
|
||||
& (information_schema.columns.column_name == column)
|
||||
& (information_schema.columns.table_schema == self.db_schema)
|
||||
)
|
||||
.run(pluck=True)[0]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1081,6 +1081,186 @@ class TestSqlIterator(FrappeTestCase):
|
|||
self.test_db_sql_iterator()
|
||||
|
||||
|
||||
class ExtFrappeTestCase(FrappeTestCase):
|
||||
def assertSqlException(self):
|
||||
class SqlExceptionContextManager:
|
||||
def __init__(self, test_case):
|
||||
self.test_case = test_case
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type is None:
|
||||
self.test_case.fail("Expected exception but none was raised")
|
||||
else:
|
||||
frappe.db.rollback()
|
||||
# Returning True suppresses the exception
|
||||
return True
|
||||
|
||||
return SqlExceptionContextManager(self)
|
||||
|
||||
|
||||
@run_only_if(db_type_is.POSTGRES)
|
||||
class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase):
|
||||
test_table_name = "TestSchemaTable"
|
||||
|
||||
def setUp(self, rollback=False) -> None:
|
||||
if rollback:
|
||||
frappe.db.rollback()
|
||||
|
||||
if frappe.db.sql(
|
||||
"""SELECT 1
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name = 'alt_schema'
|
||||
limit 1 """
|
||||
):
|
||||
self.cleanup()
|
||||
|
||||
frappe.db.sql(
|
||||
f"""
|
||||
CREATE SCHEMA alt_schema;
|
||||
|
||||
CREATE TABLE "public"."tab{self.test_table_name}" (
|
||||
col_a VARCHAR,
|
||||
col_b VARCHAR
|
||||
);
|
||||
|
||||
CREATE TABLE "alt_schema"."tab{self.test_table_name}" (
|
||||
col_c VARCHAR PRIMARY KEY,
|
||||
col_d VARCHAR
|
||||
);
|
||||
|
||||
CREATE TABLE "alt_schema"."tab{self.test_table_name}_2" (
|
||||
col_c VARCHAR,
|
||||
col_d VARCHAR
|
||||
);
|
||||
|
||||
CREATE TABLE "alt_schema"."tabUser" (
|
||||
col_c VARCHAR,
|
||||
col_d VARCHAR
|
||||
);
|
||||
|
||||
insert into "public"."tab{self.test_table_name}" (col_a, col_b) values ('a', 'b');
|
||||
"""
|
||||
)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
frappe.db.sql(
|
||||
f"""
|
||||
DROP TABLE "public"."tab{self.test_table_name}";
|
||||
DROP TABLE "alt_schema"."tab{self.test_table_name}";
|
||||
DROP TABLE "alt_schema"."tab{self.test_table_name}_2";
|
||||
DROP TABLE "alt_schema"."tabUser";
|
||||
DROP SCHEMA "alt_schema" CASCADE;
|
||||
"""
|
||||
)
|
||||
|
||||
def test_get_tables(self) -> None:
|
||||
tables = frappe.db.get_tables(cached=False)
|
||||
|
||||
# should have received the table {test_table_name} only once (from public schema)
|
||||
count = sum([1 for table in tables if f"tab{self.test_table_name}" in table])
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
# should not have received {test_table_name}_2, as selection should only be from public schema
|
||||
self.assertNotIn(f"tab{self.test_table_name}_2", tables)
|
||||
|
||||
def test_db_table_columns(self) -> None:
|
||||
columns = frappe.db.get_table_columns(self.test_table_name)
|
||||
|
||||
# should have received the columns of the table from public schema
|
||||
self.assertEqual(columns, ["col_a", "col_b"])
|
||||
|
||||
frappe.conf["db_schema"] = "alt_schema"
|
||||
frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema
|
||||
|
||||
# should have received the columns of the table from alt_schema
|
||||
columns = frappe.db.get_table_columns(self.test_table_name)
|
||||
self.assertEqual(columns, ["col_c", "col_d"])
|
||||
|
||||
del frappe.conf["db_schema"]
|
||||
frappe.cache.delete_key("table_columns")
|
||||
|
||||
def test_describe(self) -> None:
|
||||
self.assertSequenceEqual([("col_a",), ("col_b",)], frappe.db.describe(self.test_table_name))
|
||||
|
||||
def test_has_index(self) -> None:
|
||||
# should not find any index on the table in default public schema (as it is only in the alt_schema)
|
||||
self.assertFalse(frappe.db.has_index(f"tab{self.test_table_name}", f"tab{self.test_table_name}_pkey"))
|
||||
|
||||
def test_add_index(self) -> None:
|
||||
frappe.conf["db_schema"] = "alt_schema"
|
||||
|
||||
# only dummy tabUser table in alt_schema has "col_c" column
|
||||
frappe.db.add_index("User", ("col_c",))
|
||||
|
||||
del frappe.conf["db_schema"]
|
||||
frappe.cache.delete_key("table_columns")
|
||||
|
||||
# the index creation in the default schema should fail
|
||||
with self.assertSqlException():
|
||||
frappe.db.add_index(doctype="User", fields=("col_c",))
|
||||
|
||||
# TODO: is there some method like remove_index:
|
||||
# TODO: apps/frappe/frappe/patches/v14_0/drop_unused_indexes.py # def drop_index_if_exists()
|
||||
# TODO: apps/frappe/frappe/database/postgres/schema.py # def alter()
|
||||
|
||||
def test_add_unique(self) -> None:
|
||||
# should fail to add a unique constraint on the table in default public schema with those columns which are only present in alt_schema
|
||||
with self.assertSqlException():
|
||||
frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"])
|
||||
|
||||
# but should work if the schema is configured to alt_schema
|
||||
frappe.conf["db_schema"] = "alt_schema"
|
||||
|
||||
# should have received the columns of the table from alt_schema
|
||||
frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"])
|
||||
|
||||
del frappe.conf["db_schema"]
|
||||
|
||||
def test_get_table_columns_description(self):
|
||||
# should only return the columns of the table in the default public schema
|
||||
columns = frappe.db.get_table_columns_description(f"tab{self.test_table_name}")
|
||||
|
||||
self.assertTrue(any([col for col in columns if col["name"] == "col_a"]))
|
||||
self.assertTrue(any([col for col in columns if col["name"] == "col_b"]))
|
||||
self.assertFalse(any([col for col in columns if col["name"] == "col_c"]))
|
||||
self.assertFalse(any([col for col in columns if col["name"] == "col_d"]))
|
||||
|
||||
def test_get_column_type(self):
|
||||
# should return the column type of the column in the default public schema
|
||||
self.assertEqual(frappe.db.get_column_type(self.test_table_name, "col_a"), "character varying")
|
||||
|
||||
# should raise an error for the column in the alt_schema
|
||||
with self.assertSqlException():
|
||||
frappe.db.get_column_type(self.test_table_name, "col_c")
|
||||
|
||||
def test_search_path(self):
|
||||
# by default the the public schema tables should be addressed by search path
|
||||
rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"')
|
||||
self.assertEqual(
|
||||
rows,
|
||||
[
|
||||
(
|
||||
"a",
|
||||
"b",
|
||||
)
|
||||
],
|
||||
) # there should be a single row in the public table
|
||||
|
||||
# when schema is changed to alt_schema, the alt_schema tables should be addressed by search path
|
||||
frappe.conf["db_schema"] = "alt_schema"
|
||||
frappe.db.connect()
|
||||
rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"')
|
||||
self.assertEqual(rows, []) # there are no records in the alt_schema table
|
||||
|
||||
del frappe.conf["db_schema"]
|
||||
|
||||
|
||||
class TestDbConnectWithEnvCredentials(FrappeTestCase):
|
||||
current_site = frappe.local.site
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue