Merge pull request #27000 from pgruener/bugfix/postgres_schema_support

fix: Fixed schema isolation/support for postgres connectivity.
This commit is contained in:
Akhil Narang 2024-07-22 13:50:26 +05:30 committed by GitHub
commit 357307ffde
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 239 additions and 10 deletions

View file

@ -2,6 +2,7 @@ import re
import psycopg2
import psycopg2.extensions
from psycopg2 import sql
from psycopg2.errorcodes import (
CLASS_INTEGRITY_CONSTRAINT_VIOLATION,
DEADLOCK_DETECTED,
@ -169,6 +170,15 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
def last_query(self):
return LazyDecode(self._cursor.query)
@property
def db_schema(self):
return frappe.conf.get("db_schema", "public").replace("'", "").replace('"', "")
def connect(self):
super().connect()
self._cursor.execute("SET search_path TO %s", (self.db_schema,))
def get_connection(self):
conn_settings = {
"dbname": self.cur_db_name,
@ -228,12 +238,34 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
for d in self.sql(
"""select table_name
from information_schema.tables
where table_catalog='{}'
where table_catalog=%s
and table_type = 'BASE TABLE'
and table_schema='{}'""".format(self.cur_db_name, frappe.conf.get("db_schema", "public"))
and table_schema=%s""",
(self.cur_db_name, self.db_schema),
)
]
def get_db_table_columns(self, table) -> list[str]:
"""Returns list of column names from given table."""
if (columns := frappe.cache.hget("table_columns", table)) is not None:
return columns
information_schema = frappe.qb.Schema("information_schema")
columns = (
frappe.qb.from_(information_schema.columns)
.select(information_schema.columns.column_name)
.where(
(information_schema.columns.table_name == table)
& (information_schema.columns.table_schema == self.db_schema)
)
.run(pluck=True)
)
frappe.cache.hset("table_columns", table, columns)
return columns
def format_date(self, date):
if not date:
return "0001-01-01"
@ -260,7 +292,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
def describe(self, doctype: str) -> list | tuple:
table_name = get_table_name(doctype)
return self.sql(
f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'"
f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}' and table_schema='{frappe.conf.get('db_schema', 'public')}'"
)
def change_column_type(
@ -349,8 +381,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
def has_index(self, table_name, index_name):
return self.sql(
f"""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}'
and indexname='{index_name}' limit 1"""
"""SELECT 1 FROM pg_indexes WHERE tablename=%s
and schemaname = %s
and indexname=%s limit 1""",
(table_name, self.db_schema, index_name),
)
def add_index(self, doctype: str, fields: list, index_name: str | None = None):
@ -360,7 +394,9 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
index_name = index_name or self.get_index_name(fields)
fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields)
self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")')
self.sql_ddl(
f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")'
)
def add_unique(self, doctype, fields, constraint_name=None):
if isinstance(fields, str):
@ -374,13 +410,24 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
FROM information_schema.TABLE_CONSTRAINTS
WHERE table_name=%s
AND constraint_type='UNIQUE'
AND constraint_schema=%s
AND CONSTRAINT_NAME=%s""",
("tab" + doctype, constraint_name),
("tab" + doctype, self.db_schema, constraint_name),
):
self.commit()
self.sql(
"""ALTER TABLE `tab{}`
ADD CONSTRAINT {} UNIQUE ({})""".format(doctype, constraint_name, ", ".join(fields))
sql.SQL(
"""ALTER TABLE {schema}.{table}
ADD CONSTRAINT {constraint} UNIQUE ({fields})"""
)
.format(
schema=sql.Identifier(self.db_schema),
table=sql.Identifier("tab" + doctype),
constraint=sql.Identifier(constraint_name),
fields=sql.SQL(", ").join(sql.Identifier(field) for field in fields),
)
.as_string(self._conn)
)
def get_table_columns_description(self, table_name):
@ -404,9 +451,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
indexdef LIKE '%UNIQUE INDEX%' AS unique,
indexdef NOT LIKE '%UNIQUE INDEX%' AS index
FROM pg_indexes
WHERE tablename='{table_name}') b
WHERE tablename='{table_name}' AND schemaname='{self.db_schema}') b
ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%')
WHERE a.table_name = '{table_name}'
AND a.table_schema = '{self.db_schema}'
GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length, a.is_nullable;
""",
as_dict=1,
@ -423,6 +471,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
.where(
(information_schema.columns.table_name == table)
& (information_schema.columns.column_name == column)
& (information_schema.columns.table_schema == self.db_schema)
)
.run(pluck=True)[0]
)

View file

@ -1081,6 +1081,186 @@ class TestSqlIterator(FrappeTestCase):
self.test_db_sql_iterator()
class ExtFrappeTestCase(FrappeTestCase):
def assertSqlException(self):
class SqlExceptionContextManager:
def __init__(self, test_case):
self.test_case = test_case
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self.test_case.fail("Expected exception but none was raised")
else:
frappe.db.rollback()
# Returning True suppresses the exception
return True
return SqlExceptionContextManager(self)
@run_only_if(db_type_is.POSTGRES)
class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase):
test_table_name = "TestSchemaTable"
def setUp(self, rollback=False) -> None:
if rollback:
frappe.db.rollback()
if frappe.db.sql(
"""SELECT 1
FROM information_schema.schemata
WHERE schema_name = 'alt_schema'
limit 1 """
):
self.cleanup()
frappe.db.sql(
f"""
CREATE SCHEMA alt_schema;
CREATE TABLE "public"."tab{self.test_table_name}" (
col_a VARCHAR,
col_b VARCHAR
);
CREATE TABLE "alt_schema"."tab{self.test_table_name}" (
col_c VARCHAR PRIMARY KEY,
col_d VARCHAR
);
CREATE TABLE "alt_schema"."tab{self.test_table_name}_2" (
col_c VARCHAR,
col_d VARCHAR
);
CREATE TABLE "alt_schema"."tabUser" (
col_c VARCHAR,
col_d VARCHAR
);
insert into "public"."tab{self.test_table_name}" (col_a, col_b) values ('a', 'b');
"""
)
def tearDown(self) -> None:
self.cleanup()
def cleanup(self) -> None:
frappe.db.sql(
f"""
DROP TABLE "public"."tab{self.test_table_name}";
DROP TABLE "alt_schema"."tab{self.test_table_name}";
DROP TABLE "alt_schema"."tab{self.test_table_name}_2";
DROP TABLE "alt_schema"."tabUser";
DROP SCHEMA "alt_schema" CASCADE;
"""
)
def test_get_tables(self) -> None:
tables = frappe.db.get_tables(cached=False)
# should have received the table {test_table_name} only once (from public schema)
count = sum([1 for table in tables if f"tab{self.test_table_name}" in table])
self.assertEqual(count, 1)
# should not have received {test_table_name}_2, as selection should only be from public schema
self.assertNotIn(f"tab{self.test_table_name}_2", tables)
def test_db_table_columns(self) -> None:
columns = frappe.db.get_table_columns(self.test_table_name)
# should have received the columns of the table from public schema
self.assertEqual(columns, ["col_a", "col_b"])
frappe.conf["db_schema"] = "alt_schema"
frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema
# should have received the columns of the table from alt_schema
columns = frappe.db.get_table_columns(self.test_table_name)
self.assertEqual(columns, ["col_c", "col_d"])
del frappe.conf["db_schema"]
frappe.cache.delete_key("table_columns")
def test_describe(self) -> None:
self.assertSequenceEqual([("col_a",), ("col_b",)], frappe.db.describe(self.test_table_name))
def test_has_index(self) -> None:
# should not find any index on the table in default public schema (as it is only in the alt_schema)
self.assertFalse(frappe.db.has_index(f"tab{self.test_table_name}", f"tab{self.test_table_name}_pkey"))
def test_add_index(self) -> None:
frappe.conf["db_schema"] = "alt_schema"
# only dummy tabUser table in alt_schema has "col_c" column
frappe.db.add_index("User", ("col_c",))
del frappe.conf["db_schema"]
frappe.cache.delete_key("table_columns")
# the index creation in the default schema should fail
with self.assertSqlException():
frappe.db.add_index(doctype="User", fields=("col_c",))
# TODO: is there some method like remove_index:
# TODO: apps/frappe/frappe/patches/v14_0/drop_unused_indexes.py # def drop_index_if_exists()
# TODO: apps/frappe/frappe/database/postgres/schema.py # def alter()
def test_add_unique(self) -> None:
# should fail to add a unique constraint on the table in default public schema with those columns which are only present in alt_schema
with self.assertSqlException():
frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"])
# but should work if the schema is configured to alt_schema
frappe.conf["db_schema"] = "alt_schema"
# should have received the columns of the table from alt_schema
frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"])
del frappe.conf["db_schema"]
def test_get_table_columns_description(self):
# should only return the columns of the table in the default public schema
columns = frappe.db.get_table_columns_description(f"tab{self.test_table_name}")
self.assertTrue(any([col for col in columns if col["name"] == "col_a"]))
self.assertTrue(any([col for col in columns if col["name"] == "col_b"]))
self.assertFalse(any([col for col in columns if col["name"] == "col_c"]))
self.assertFalse(any([col for col in columns if col["name"] == "col_d"]))
def test_get_column_type(self):
# should return the column type of the column in the default public schema
self.assertEqual(frappe.db.get_column_type(self.test_table_name, "col_a"), "character varying")
# should raise an error for the column in the alt_schema
with self.assertSqlException():
frappe.db.get_column_type(self.test_table_name, "col_c")
def test_search_path(self):
# by default the the public schema tables should be addressed by search path
rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"')
self.assertEqual(
rows,
[
(
"a",
"b",
)
],
) # there should be a single row in the public table
# when schema is changed to alt_schema, the alt_schema tables should be addressed by search path
frappe.conf["db_schema"] = "alt_schema"
frappe.db.connect()
rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"')
self.assertEqual(rows, []) # there are no records in the alt_schema table
del frappe.conf["db_schema"]
class TestDbConnectWithEnvCredentials(FrappeTestCase):
current_site = frappe.local.site