From afd95691e9448f24279abd51ea4c3f38cd24b4ae Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Thu, 4 Jul 2024 19:10:34 +0200 Subject: [PATCH 1/5] fix: Fixed schema isolation/support for postgres connectivity. --- frappe/database/postgres/database.py | 56 +++++++-- frappe/tests/test_db.py | 168 +++++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 10 deletions(-) diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index 7bae004986..e7a15c72ff 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -20,6 +20,7 @@ from psycopg2.errors import ( SyntaxError, ) from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ +from psycopg2 import sql import frappe from frappe.database.database import Database @@ -169,6 +170,15 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def last_query(self): return LazyDecode(self._cursor.query) + @property + def db_schema(self): + return re.sub(r'["\']', '', frappe.conf.get("db_schema", "public")) + + def connect(self): + super().connect() + + self._cursor.execute("SET search_path TO %s", (self.db_schema,)) + def get_connection(self): conn_settings = { "dbname": self.cur_db_name, @@ -228,12 +238,30 @@ class PostgresDatabase(PostgresExceptionUtil, Database): for d in self.sql( """select table_name from information_schema.tables - where table_catalog='{}' + where table_catalog=%s and table_type = 'BASE TABLE' - and table_schema='{}'""".format(self.cur_db_name, frappe.conf.get("db_schema", "public")) + and table_schema=%s""", (self.cur_db_name, self.db_schema) ) ] + def get_db_table_columns(self, table) -> list[str]: + """Returns list of column names from given table.""" + columns = frappe.cache.hget("table_columns", table) + if columns is None: + information_schema = frappe.qb.Schema("information_schema") + + columns = ( + frappe.qb.from_(information_schema.columns) + .select(information_schema.columns.column_name) + .where((information_schema.columns.table_name == table) & (information_schema.columns.table_schema == self.db_schema)) + .run(pluck=True) + ) + + if columns: + frappe.cache.hset("table_columns", table, columns) + + return columns + def format_date(self, date): if not date: return "0001-01-01" @@ -260,7 +288,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def describe(self, doctype: str) -> list | tuple: table_name = get_table_name(doctype) return self.sql( - f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}'" + f"SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_NAME = '{table_name}' and table_schema='{frappe.conf.get('db_schema', 'public')}'" ) def change_column_type( @@ -349,8 +377,9 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def has_index(self, table_name, index_name): return self.sql( - f"""SELECT 1 FROM pg_indexes WHERE tablename='{table_name}' - and indexname='{index_name}' limit 1""" + """SELECT 1 FROM pg_indexes WHERE tablename=%s + and schemaname = %s + and indexname=%s limit 1""", (table_name, self.db_schema, index_name) ) def add_index(self, doctype: str, fields: list, index_name: str | None = None): @@ -360,7 +389,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): index_name = index_name or self.get_index_name(fields) fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields) - self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON `{table_name}` ("{fields_str}")') + self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")') def add_unique(self, doctype, fields, constraint_name=None): if isinstance(fields, str): @@ -374,13 +403,18 @@ class PostgresDatabase(PostgresExceptionUtil, Database): FROM information_schema.TABLE_CONSTRAINTS WHERE table_name=%s AND constraint_type='UNIQUE' + AND constraint_schema=%s AND CONSTRAINT_NAME=%s""", - ("tab" + doctype, constraint_name), + ("tab" + doctype, self.db_schema, constraint_name), ): self.commit() + self.sql( - """ALTER TABLE `tab{}` - ADD CONSTRAINT {} UNIQUE ({})""".format(doctype, constraint_name, ", ".join(fields)) + sql.SQL("""ALTER TABLE {schema}.{table} + ADD CONSTRAINT {constraint} UNIQUE ({fields})""").format(schema=sql.Identifier(self.db_schema), + table=sql.Identifier('tab' + doctype), + constraint=sql.Identifier(constraint_name), + fields=sql.SQL(', ').join(map(sql.Identifier, fields))).as_string(self._conn) ) def get_table_columns_description(self, table_name): @@ -404,9 +438,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database): indexdef LIKE '%UNIQUE INDEX%' AS unique, indexdef NOT LIKE '%UNIQUE INDEX%' AS index FROM pg_indexes - WHERE tablename='{table_name}') b + WHERE tablename='{table_name}' AND schemaname='{self.db_schema}') b ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%') WHERE a.table_name = '{table_name}' + AND a.table_schema = '{self.db_schema}' GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length, a.is_nullable; """, as_dict=1, @@ -423,6 +458,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): .where( (information_schema.columns.table_name == table) & (information_schema.columns.column_name == column) + & (information_schema.columns.table_schema == self.db_schema) ) .run(pluck=True)[0] ) diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 715522d868..77470b153b 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1079,3 +1079,171 @@ class TestSqlIterator(FrappeTestCase): def test_unbuffered_cursor(self): with frappe.db.unbuffered_cursor(): self.test_db_sql_iterator() + +class ExtFrappeTestCase(FrappeTestCase): + def assertSqlException(self): + class SqlExceptionContextManager: + def __init__(self, test_case): + self.test_case = test_case + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.test_case.fail("Expected exception but none was raised") + else: + frappe.db.rollback() + # Returning True suppresses the exception + return True + + return SqlExceptionContextManager(self) + +@run_only_if(db_type_is.POSTGRES) +class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): + test_table_name = "TestSchemaTable" + + def setUp(self, rollback=False) -> None: + if rollback: + frappe.db.rollback() + + + if frappe.db.sql("""SELECT 1 + FROM information_schema.schemata + WHERE schema_name = 'alt_schema' + limit 1 """): + self.cleanup() + + frappe.db.sql( + f""" + CREATE SCHEMA alt_schema; + + CREATE TABLE "public"."tab{self.test_table_name}" ( + col_a VARCHAR, + col_b VARCHAR + ); + + CREATE TABLE "alt_schema"."tab{self.test_table_name}" ( + col_c VARCHAR PRIMARY KEY, + col_d VARCHAR + ); + + CREATE TABLE "alt_schema"."tab{self.test_table_name}_2" ( + col_c VARCHAR, + col_d VARCHAR + ); + + CREATE TABLE "alt_schema"."tabUser" ( + col_c VARCHAR, + col_d VARCHAR + ); + + insert into "public"."tab{self.test_table_name}" (col_a, col_b) values ('a', 'b'); + """ + ) + + def tearDown(self) -> None: + self.cleanup() + + def cleanup(self) -> None: + frappe.db.sql(f""" + DROP TABLE "public"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}_2"; + DROP TABLE "alt_schema"."tabUser"; + DROP SCHEMA "alt_schema" CASCADE; + """) + # DROP USER u_alt_schema; + + def test_get_tables(self) -> None: + tables = frappe.db.get_tables(cached=False) + + # should have received the table {test_table_name} only once (from public schema) + count = sum([1 for table in tables if f"tab{self.test_table_name}" in table]) + self.assertEqual(count, 1) + + # should not have received {test_table_name}_2, as selection should only be from public schema + self.assertNotIn(f"tab{self.test_table_name}_2", tables) + + def test_db_table_columns(self) -> None: + columns = frappe.db.get_table_columns(self.test_table_name) + + # should have received the columns of the table from public schema + self.assertEqual(columns, ["col_a", "col_b"]) + + frappe.conf["db_schema"] = "alt_schema" + frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema + + # should have received the columns of the table from alt_schema + columns = frappe.db.get_table_columns(self.test_table_name) + self.assertEqual(columns, ["col_c", "col_d"]) + + del frappe.conf["db_schema"] + frappe.cache.delete_key("table_columns") + + def test_describe(self) -> None: + self.assertSequenceEqual([("col_a",), ("col_b",)], frappe.db.describe(self.test_table_name)) + + def test_has_index(self) -> None: + # should not find any index on the table in default public schema (as it is only in the alt_schema) + self.assertFalse(frappe.db.has_index(f"tab{self.test_table_name}", f"tab{self.test_table_name}_pkey")) + + def test_add_index(self) -> None: + frappe.conf["db_schema"] = "alt_schema" + + # only dummy tabUser table in alt_schema has "col_c" column + frappe.db.add_index("User", ("col_c",)) + + del frappe.conf["db_schema"] + frappe.cache.delete_key("table_columns") + + # the index creation in the default schema should fail + with self.assertSqlException(): + frappe.db.add_index(doctype="User", fields=("col_c",)) + + # TODO: is there some method like remove_index: + # TODO: apps/frappe/frappe/patches/v14_0/drop_unused_indexes.py # def drop_index_if_exists() + # TODO: apps/frappe/frappe/database/postgres/schema.py # def alter() + + def test_add_unique(self) -> None: + # should fail to add a unique constraint on the table in default public schema with those columns which are only present in alt_schema + with self.assertSqlException(): + frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"]) + + # but should work if the schema is configured to alt_schema + frappe.conf["db_schema"] = "alt_schema" + + # should have received the columns of the table from alt_schema + frappe.db.add_unique(f"{self.test_table_name}", ["col_c", "col_d"]) + + del frappe.conf["db_schema"] + + def test_get_table_columns_description(self): + # should only return the columns of the table in the default public schema + columns = frappe.db.get_table_columns_description(f'tab{self.test_table_name}') + + self.assertTrue(any([col for col in columns if col["name"] == "col_a"])) + self.assertTrue(any([col for col in columns if col["name"] == "col_b"])) + self.assertFalse(any([col for col in columns if col["name"] == "col_c"])) + self.assertFalse(any([col for col in columns if col["name"] == "col_d"])) + + def test_get_column_type(self): + # should return the column type of the column in the default public schema + self.assertEqual(frappe.db.get_column_type(self.test_table_name, "col_a"), "character varying") + + # should raise an error for the column in the alt_schema + with self.assertSqlException(): + frappe.db.get_column_type(self.test_table_name, "col_c") + + def test_search_path(self): + # by default the the public schema tables should be addressed by search path + rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') + self.assertEqual(rows, [("a", "b",)]) # there should be a single row in the public table + + # when schema is changed to alt_schema, the alt_schema tables should be addressed by search path + frappe.conf["db_schema"] = "alt_schema" + frappe.db.connect() + rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') + self.assertEqual(rows, []) # there are no records in the alt_schema table + + del frappe.conf["db_schema"] From d3591c71701c30a85156f5771706aef8755623eb Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Thu, 4 Jul 2024 23:05:44 +0200 Subject: [PATCH 2/5] fix: Added missing ruff adjustments --- frappe/database/postgres/database.py | 35 +++++++++++++++-------- frappe/tests/test_db.py | 42 ++++++++++++++++++---------- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index e7a15c72ff..c3a69473ee 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -2,6 +2,7 @@ import re import psycopg2 import psycopg2.extensions +from psycopg2 import sql from psycopg2.errorcodes import ( CLASS_INTEGRITY_CONSTRAINT_VIOLATION, DEADLOCK_DETECTED, @@ -20,7 +21,6 @@ from psycopg2.errors import ( SyntaxError, ) from psycopg2.extensions import ISOLATION_LEVEL_REPEATABLE_READ -from psycopg2 import sql import frappe from frappe.database.database import Database @@ -172,7 +172,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): @property def db_schema(self): - return re.sub(r'["\']', '', frappe.conf.get("db_schema", "public")) + return re.sub(r'["\']', "", frappe.conf.get("db_schema", "public")) def connect(self): super().connect() @@ -240,7 +240,8 @@ class PostgresDatabase(PostgresExceptionUtil, Database): from information_schema.tables where table_catalog=%s and table_type = 'BASE TABLE' - and table_schema=%s""", (self.cur_db_name, self.db_schema) + and table_schema=%s""", + (self.cur_db_name, self.db_schema), ) ] @@ -253,7 +254,10 @@ class PostgresDatabase(PostgresExceptionUtil, Database): columns = ( frappe.qb.from_(information_schema.columns) .select(information_schema.columns.column_name) - .where((information_schema.columns.table_name == table) & (information_schema.columns.table_schema == self.db_schema)) + .where( + (information_schema.columns.table_name == table) + & (information_schema.columns.table_schema == self.db_schema) + ) .run(pluck=True) ) @@ -379,7 +383,8 @@ class PostgresDatabase(PostgresExceptionUtil, Database): return self.sql( """SELECT 1 FROM pg_indexes WHERE tablename=%s and schemaname = %s - and indexname=%s limit 1""", (table_name, self.db_schema, index_name) + and indexname=%s limit 1""", + (table_name, self.db_schema, index_name), ) def add_index(self, doctype: str, fields: list, index_name: str | None = None): @@ -389,7 +394,9 @@ class PostgresDatabase(PostgresExceptionUtil, Database): index_name = index_name or self.get_index_name(fields) fields_str = '", "'.join(re.sub(r"\(.*\)", "", field) for field in fields) - self.sql_ddl(f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")') + self.sql_ddl( + f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{self.db_schema}"."{table_name}" ("{fields_str}")' + ) def add_unique(self, doctype, fields, constraint_name=None): if isinstance(fields, str): @@ -410,11 +417,17 @@ class PostgresDatabase(PostgresExceptionUtil, Database): self.commit() self.sql( - sql.SQL("""ALTER TABLE {schema}.{table} - ADD CONSTRAINT {constraint} UNIQUE ({fields})""").format(schema=sql.Identifier(self.db_schema), - table=sql.Identifier('tab' + doctype), - constraint=sql.Identifier(constraint_name), - fields=sql.SQL(', ').join(map(sql.Identifier, fields))).as_string(self._conn) + sql.SQL( + """ALTER TABLE {schema}.{table} + ADD CONSTRAINT {constraint} UNIQUE ({fields})""" + ) + .format( + schema=sql.Identifier(self.db_schema), + table=sql.Identifier("tab" + doctype), + constraint=sql.Identifier(constraint_name), + fields=sql.SQL(", ").join(map(sql.Identifier, fields)), + ) + .as_string(self._conn) ) def get_table_columns_description(self, table_name): diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 77470b153b..2d939f7a72 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1080,6 +1080,7 @@ class TestSqlIterator(FrappeTestCase): with frappe.db.unbuffered_cursor(): self.test_db_sql_iterator() + class ExtFrappeTestCase(FrappeTestCase): def assertSqlException(self): class SqlExceptionContextManager: @@ -1099,6 +1100,7 @@ class ExtFrappeTestCase(FrappeTestCase): return SqlExceptionContextManager(self) + @run_only_if(db_type_is.POSTGRES) class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): test_table_name = "TestSchemaTable" @@ -1107,11 +1109,12 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): if rollback: frappe.db.rollback() - - if frappe.db.sql("""SELECT 1 + if frappe.db.sql( + """SELECT 1 FROM information_schema.schemata WHERE schema_name = 'alt_schema' - limit 1 """): + limit 1 """ + ): self.cleanup() frappe.db.sql( @@ -1146,14 +1149,15 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): self.cleanup() def cleanup(self) -> None: - frappe.db.sql(f""" - DROP TABLE "public"."tab{self.test_table_name}"; - DROP TABLE "alt_schema"."tab{self.test_table_name}"; - DROP TABLE "alt_schema"."tab{self.test_table_name}_2"; - DROP TABLE "alt_schema"."tabUser"; - DROP SCHEMA "alt_schema" CASCADE; - """) - # DROP USER u_alt_schema; + frappe.db.sql( + f""" + DROP TABLE "public"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}"; + DROP TABLE "alt_schema"."tab{self.test_table_name}_2"; + DROP TABLE "alt_schema"."tabUser"; + DROP SCHEMA "alt_schema" CASCADE; + """ + ) def test_get_tables(self) -> None: tables = frappe.db.get_tables(cached=False) @@ -1172,7 +1176,7 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): self.assertEqual(columns, ["col_a", "col_b"]) frappe.conf["db_schema"] = "alt_schema" - frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema + frappe.cache.delete_key("table_columns") # remove table columns cache for next try from alt_schema # should have received the columns of the table from alt_schema columns = frappe.db.get_table_columns(self.test_table_name) @@ -1220,7 +1224,7 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): def test_get_table_columns_description(self): # should only return the columns of the table in the default public schema - columns = frappe.db.get_table_columns_description(f'tab{self.test_table_name}') + columns = frappe.db.get_table_columns_description(f"tab{self.test_table_name}") self.assertTrue(any([col for col in columns if col["name"] == "col_a"])) self.assertTrue(any([col for col in columns if col["name"] == "col_b"])) @@ -1238,12 +1242,20 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): def test_search_path(self): # by default the the public schema tables should be addressed by search path rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') - self.assertEqual(rows, [("a", "b",)]) # there should be a single row in the public table + self.assertEqual( + rows, + [ + ( + "a", + "b", + ) + ], + ) # there should be a single row in the public table # when schema is changed to alt_schema, the alt_schema tables should be addressed by search path frappe.conf["db_schema"] = "alt_schema" frappe.db.connect() rows = frappe.db.sql(f'select * from "tab{self.test_table_name}"') - self.assertEqual(rows, []) # there are no records in the alt_schema table + self.assertEqual(rows, []) # there are no records in the alt_schema table del frappe.conf["db_schema"] From 2e0ca7d3a7f25e094664f63ba6078293f78343fd Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Thu, 4 Jul 2024 23:25:46 +0200 Subject: [PATCH 3/5] fix: simplified functional call to iterator (semgrep) --- frappe/database/postgres/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index c3a69473ee..caba53e242 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -425,7 +425,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): schema=sql.Identifier(self.db_schema), table=sql.Identifier("tab" + doctype), constraint=sql.Identifier(constraint_name), - fields=sql.SQL(", ").join(map(sql.Identifier, fields)), + fields=sql.SQL(", ").join(sql.Identifier(field) for field in fields), ) .as_string(self._conn) ) From ffcd6d1ff5569abc82db4453b8e75353412f6953 Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Tue, 16 Jul 2024 22:52:31 +0200 Subject: [PATCH 4/5] fix: Added missing newlines after merge (for pre-commit checks) --- frappe/tests/test_db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index d4ccb78d65..63c0dba7da 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1080,6 +1080,7 @@ class TestSqlIterator(FrappeTestCase): with frappe.db.unbuffered_cursor(): self.test_db_sql_iterator() + class ExtFrappeTestCase(FrappeTestCase): def assertSqlException(self): class SqlExceptionContextManager: @@ -1259,6 +1260,7 @@ class TestPostgresSchemaQueryIndependence(ExtFrappeTestCase): del frappe.conf["db_schema"] + class TestDbConnectWithEnvCredentials(FrappeTestCase): current_site = frappe.local.site From e2a1c506feb36f4919eabaca41824537c94ca6c3 Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Sat, 20 Jul 2024 13:23:03 +0200 Subject: [PATCH 5/5] fix: Added adjustments --- frappe/database/postgres/database.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index caba53e242..56ae71a9ad 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -172,7 +172,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): @property def db_schema(self): - return re.sub(r'["\']', "", frappe.conf.get("db_schema", "public")) + return frappe.conf.get("db_schema", "public").replace("'", "").replace('"', "") def connect(self): super().connect() @@ -247,22 +247,22 @@ class PostgresDatabase(PostgresExceptionUtil, Database): def get_db_table_columns(self, table) -> list[str]: """Returns list of column names from given table.""" - columns = frappe.cache.hget("table_columns", table) - if columns is None: - information_schema = frappe.qb.Schema("information_schema") + if (columns := frappe.cache.hget("table_columns", table)) is not None: + return columns - columns = ( - frappe.qb.from_(information_schema.columns) - .select(information_schema.columns.column_name) - .where( - (information_schema.columns.table_name == table) - & (information_schema.columns.table_schema == self.db_schema) - ) - .run(pluck=True) + information_schema = frappe.qb.Schema("information_schema") + + columns = ( + frappe.qb.from_(information_schema.columns) + .select(information_schema.columns.column_name) + .where( + (information_schema.columns.table_name == table) + & (information_schema.columns.table_schema == self.db_schema) ) + .run(pluck=True) + ) - if columns: - frappe.cache.hset("table_columns", table, columns) + frappe.cache.hset("table_columns", table, columns) return columns