diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index 7bae004986..56ae71a9ad 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -2,6 +2,7 @@ import re import psycopg2 import psycopg2.extensions +from psycopg2 import sql from psycopg2.errorcodes import ( CLASS_INTEGRITY_CONSTRAINT_VIOLATION, DEADLOCK_DETECTED, @@ -169,6 +170,15 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def last_query(self): return LazyDecode(self._cursor.query) + @property + def db_schema(self): + return frappe.conf.get("db_schema", "public").replace("'", "").replace('"', "") + + def connect(self): + super().connect() + + self._cursor.execute("SET search_path TO %s", (self.db_schema,)) + def get_connection(self): conn_settings = { "dbname": self.cur_db_name, @@ -228,12 +238,34 @@ class PostgresDatabase(PostgresExceptionUtil, Database): for d in self.sql( """select table_name from information_schema.tables - where table_catalog='{}' + where table_catalog=%s and table_type = 'BASE TABLE' - and table_schema='{}'""".format(self.cur_db_name, frappe.conf.get("db_schema", "public")) + and table_schema=%s""", + (self.cur_db_name, self.db_schema), ) ] + def get_db_table_columns(self, table) -> list[str]: + """Returns list of column names from given table.""" + if (columns := frappe.cache.hget("table_columns", table)) is not None: + return columns + + information_schema = frappe.qb.Schema("information_schema") + + columns = ( + frappe.qb.from_(information_schema.columns) + .select(information_schema.columns.column_name) + .where( + (information_schema.columns.table_name == table) + & (information_schema.columns.table_schema == self.db_schema) + ) + .run(pluck=True) + ) + + frappe.cache.hset("table_columns", table, columns) + + return columns + def format_date(self, date): if not date: return "0001-01-01" @@ -260,7 +292,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def describe(self, doctype: str) -> list | tuple: table_name = get_table_name(doctype) return self.sql( - f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'" + f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}' and table_schema='{frappe.conf.get('db_schema', 'public')}'" ) def change_column_type( @@ -349,8 +381,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def has_index(self, table_name, index_name): return self.sql( - f"""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}' - and indexname='{index_name}' limit 1""" + """SELECT 1 FROM pg_indexes WHERE tablename=%s + and schemaname = %s + and indexname=%s limit 1""", + (table_name, self.db_schema, index_name), ) def add_index(self, doctype: str, fields: list, index_name: str | None = None): @@ -360,7 +394,9 @@ class PostgresDatabase(PostgresExceptionUtil, Database): index_name = index_name or self.get_index_name(fields) fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields) - self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")') + self.sql_ddl( + f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")' + ) def add_unique(self, doctype, fields, constraint_name=None): if isinstance(fields, str): @@ -374,13 +410,24 @@ class PostgresDatabase(PostgresExceptionUtil, Database): FROM information_schema.TABLE_CONSTRAINTS WHERE table_name=%s AND constraint_type='UNIQUE' + AND constraint_schema=%s AND CONSTRAINT_NAME=%s""", - ("tab" + doctype, constraint_name), + ("tab" + doctype, self.db_schema, constraint_name), ): self.commit() + self.sql( - """ALTER TABLE `tab{}` - ADD CONSTRAINT {} UNIQUE ({})""".format(doctype, constraint_name, ", ".join(fields)) + sql.SQL( + """ALTER TABLE {schema}.{table} + ADD CONSTRAINT {constraint} UNIQUE ({fields})""" + ) + .format( + schema=sql.Identifier(self.db_schema), + table=sql.Identifier("tab" + doctype), + constraint=sql.Identifier(constraint_name), + fields=sql.SQL(", ").join(sql.Identifier(field) for field in fields), + ) + .as_string(self._conn) ) def get_table_columns_description(self, table_name): @@ -404,9 +451,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database): indexdef LIKE '%UNIQUE INDEX%' AS unique, indexdef NOT LIKE '%UNIQUE INDEX%' AS index FROM pg_indexes - WHERE tablename='{table_name}') b + WHERE tablename='{table_name}' AND schemaname='{self.db_schema}') b ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%') WHERE a.table_name = '{table_name}' + AND a.table_schema = '{self.db_schema}' GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length, a.is_nullable; """, as_dict=1, @@ -423,6 +471,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): .where( (information_schema.columns.table_name == table) & (information_schema.columns.column_name == column) + & (information_schema.columns.table_schema == self.db_schema) ) .run(pluck=True)[0] ) diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index ee57d48e94..63c0dba7da 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1081,6 +1081,186 @@ class TestSqlIterator(FrappeTestCase): self.test_db_sql_iterator() +class ExtFrappeTestCase(FrappeTestCase): + def assertSqlException(self): + class SqlExceptionContextManager: + def __init__(self, test_case): + self.test_case = test_case + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.test_case.fail("Expected exception but none was raised") + else: + frappe.db.rollback() + # Returning True suppresses the exception + return True + + return SqlExceptionContextManager(self) + + +@run_only_if(db_type_is.POSTGRES) +class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): + test_table_name = "TestSchemaTable" + + def setUp(self, rollback=False) -> None: + if rollback: + frappe.db.rollback() + + if frappe.db.sql( + """SELECT 1 + FROM information_schema.schemata + WHERE schema_name = 'alt_schema' + limit 1 """ + ): + self.cleanup() + + frappe.db.sql( + f""" + CREATE SCHEMA alt_schema; + + CREATE TABLE "public"."tab{self.test_table_name}" ( + col_a VARCHAR, + col_b VARCHAR + ); + + CREATE TABLE "alt_schema"."tab{self.test_table_name}" ( + col_c VARCHAR PRIMARY KEY, + col_d VARCHAR + ); + + CREATE TABLE "alt_schema"."tab{self.test_table_name}_2" ( + col_c VARCHAR, + col_d VARCHAR + ); + + CREATE TABLE "alt_schema"."tabUser" ( + col_c VARCHAR, + col_d VARCHAR + ); + + insert into "public"."tab{self.test_table_name}" (col_a, col_b) values ('a', 'b'); + """ + ) + + def tearDown(self) -> None: + self.cleanup() + + def cleanup(self) -> None: + frappe.db.sql( + f""" + DROP TABLE "public"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}_2"; + DROP TABLE "alt_schema"."tabUser"; + DROP SCHEMA "alt_schema" CASCADE; + """ + ) + + def test_get_tables(self) -> None: + tables = frappe.db.get_tables(cached=False) + + # should have received the table {test_table_name} only once (from public schema) + count = sum([1 for table in tables if f"tab{self.test_table_name}" in table]) + self.assertEqual(count, 1) + + # should not have received {test_table_name}_2, as selection should only be from public schema + self.assertNotIn(f"tab{self.test_table_name}_2", tables) + + def test_db_table_columns(self) -> None: + columns = frappe.db.get_table_columns(self.test_table_name) + + # should have received the columns of the table from public schema + self.assertEqual(columns, ["col_a", "col_b"]) + + frappe.conf["db_schema"] = "alt_schema" + frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema + + # should have received the columns of the table from alt_schema + columns = frappe.db.get_table_columns(self.test_table_name) + self.assertEqual(columns, ["col_c", "col_d"]) + + del frappe.conf["db_schema"] + frappe.cache.delete_key("table_columns") + + def test_describe(self) -> None: + self.assertSequenceEqual([("col_a",), ("col_b",)], frappe.db.describe(self.test_table_name)) + + def test_has_index(self) -> None: + # should not find any index on the table in default public schema (as it is only in the alt_schema) + self.assertFalse(frappe.db.has_index(f"tab{self.test_table_name}", f"tab{self.test_table_name}_pkey")) + + def test_add_index(self) -> None: + frappe.conf["db_schema"] = "alt_schema" + + # only dummy tabUser table in alt_schema has "col_c" column + frappe.db.add_index("User", ("col_c",)) + + del frappe.conf["db_schema"] + frappe.cache.delete_key("table_columns") + + # the index creation in the default schema should fail + with self.assertSqlException(): + frappe.db.add_index(doctype="User", fields=("col_c",)) + + # TODO: is there some method like remove_index: + # TODO: apps/frappe/frappe/patches/v14_0/drop_unused_indexes.py # def drop_index_if_exists() + # TODO: apps/frappe/frappe/database/postgres/schema.py # def alter() + + def test_add_unique(self) -> None: + # should fail to add a unique constraint on the table in default public schema with those columns which are only present in alt_schema + with self.assertSqlException(): + frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"]) + + # but should work if the schema is configured to alt_schema + frappe.conf["db_schema"] = "alt_schema" + + # should have received the columns of the table from alt_schema + frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"]) + + del frappe.conf["db_schema"] + + def test_get_table_columns_description(self): + # should only return the columns of the table in the default public schema + columns = frappe.db.get_table_columns_description(f"tab{self.test_table_name}") + + self.assertTrue(any([col for col in columns if col["name"] == "col_a"])) + self.assertTrue(any([col for col in columns if col["name"] == "col_b"])) + self.assertFalse(any([col for col in columns if col["name"] == "col_c"])) + self.assertFalse(any([col for col in columns if col["name"] == "col_d"])) + + def test_get_column_type(self): + # should return the column type of the column in the default public schema + self.assertEqual(frappe.db.get_column_type(self.test_table_name, "col_a"), "character varying") + + # should raise an error for the column in the alt_schema + with self.assertSqlException(): + frappe.db.get_column_type(self.test_table_name, "col_c") + + def test_search_path(self): + # by default the the public schema tables should be addressed by search path + rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') + self.assertEqual( + rows, + [ + ( + "a", + "b", + ) + ], + ) # there should be a single row in the public table + + # when schema is changed to alt_schema, the alt_schema tables should be addressed by search path + frappe.conf["db_schema"] = "alt_schema" + frappe.db.connect() + rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') + self.assertEqual(rows, []) # there are no records in the alt_schema table + + del frappe.conf["db_schema"] + + class TestDbConnectWithEnvCredentials(FrappeTestCase): current_site = frappe.local.site