From a12d855147e39a3667b369972b941bd9d7322a79 Mon Sep 17 00:00:00 2001 From: AarDG10 Date: Mon, 10 Nov 2025 17:46:24 +0530 Subject: [PATCH] test(postgres): fix mariadb specific identifiers for postgres queries using normalize_sql --- .../user_invitation/test_user_invitation.py | 13 ++- frappe/tests/classes/unit_test_case.py | 2 + frappe/tests/test_query.py | 99 +++++++------------ 3 files changed, 46 insertions(+), 68 deletions(-) diff --git a/frappe/core/doctype/user_invitation/test_user_invitation.py b/frappe/core/doctype/user_invitation/test_user_invitation.py index 7345cc300d..634c739d8e 100644 --- a/frappe/core/doctype/user_invitation/test_user_invitation.py +++ b/frappe/core/doctype/user_invitation/test_user_invitation.py @@ -13,7 +13,6 @@ from frappe.core.api.user_invitation import ( ) from frappe.core.doctype.user_invitation.user_invitation import mark_expired_invitations from frappe.tests import IntegrationTestCase -from frappe.tests.test_query import convert_identifier_quotes emails = [ "test_user_invite1@example.com", @@ -56,18 +55,18 @@ class IntegrationTestUserInvitation(IntegrationTestCase): @classmethod def delete_all_user_roles(cls): - frappe.db.sql(convert_identifier_quotes("DELETE FROM `tabUser Role`")) + query = "DELETE FROM `tabUser Role`" + frappe.db.sql(cls.normalize_sql(query)) @classmethod def delete_all_invitations(cls): - frappe.db.sql(convert_identifier_quotes("DELETE FROM `tabUser Invitation`")) + query = "DELETE FROM `tabUser Invitation`" + frappe.db.sql(cls.normalize_sql(query)) @classmethod def delete_invitation(cls, name: str): - frappe.db.sql( - convert_identifier_quotes("DELETE FROM `tabUser Invitation` WHERE name = %s"), - name, - ) + query = "DELETE FROM `tabUser Invitation` WHERE name = %s" + frappe.db.sql(cls.normalize_sql(query), name) def setUp(self): super().setUp() diff --git a/frappe/tests/classes/unit_test_case.py b/frappe/tests/classes/unit_test_case.py index 45b8e56963..65aa4e616d 100644 --- a/frappe/tests/classes/unit_test_case.py +++ b/frappe/tests/classes/unit_test_case.py @@ -105,6 +105,8 @@ class UnitTestCase(unittest.TestCase, BaseTestCase): """Formats SQL consistently so simple string comparisons can work on them.""" import sqlparse + if frappe.db.db_type == "postgres": + query = query.replace("`", '"') return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 2cf933b7a6..d99bc68b1d 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -63,20 +63,12 @@ def create_tree_docs(): d.insert() -def convert_identifier_quotes(query): - """util to replace mariadb query idenitfifiers with postgres ones""" - return query.replace("`", '"') if frappe.db.db_type == "postgres" else query - - class TestQuery(IntegrationTestCase): def setUp(self): setup_for_tests() def test_multiple_tables_in_filters(self): - expected_query = convert_identifier_quotes( - "SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", ["*"], @@ -85,7 +77,7 @@ class TestQuery(IntegrationTestCase): ["DocType", "parent", "=", "something"], ], ).get_sql(), - expected_query, + "SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", ) def test_string_fields(self): @@ -359,88 +351,72 @@ class TestQuery(IntegrationTestCase): ) def test_filters(self): - expected_query = convert_identifier_quotes( - "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name`='frappe'" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", fields=["name"], filters={"module.app_name": "frappe"}, ).get_sql(), - expected_query, + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name`='frappe'", ) - expected_query = convert_identifier_quotes( - "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name` LIKE 'frap%'" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", fields=["name"], filters={"module.app_name": ("like", "frap%")}, ).get_sql(), - expected_query, + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name` LIKE 'frap%'", ) - expected_query = convert_identifier_quotes( - "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabDocPerm` ON `tabDocPerm`.`parent`=`tabDocType`.`name` AND `tabDocPerm`.`parenttype`='DocType' AND `tabDocPerm`.`parentfield`='permissions' WHERE `tabDocPerm`.`role`='System Manager'" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", fields=["name"], filters={"permissions.role": "System Manager"}, ).get_sql(), - expected_query, + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabDocPerm` ON `tabDocPerm`.`parent`=`tabDocType`.`name` AND `tabDocPerm`.`parenttype`='DocType' AND `tabDocPerm`.`parentfield`='permissions' WHERE `tabDocPerm`.`role`='System Manager'", ) - expected_query = convert_identifier_quotes("SELECT `module` FROM `tabDocType` WHERE `name`=''") - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", fields=["module"], filters="", ).get_sql(), - expected_query, + "SELECT `module` FROM `tabDocType` WHERE `name`=''", ) - expected_query = convert_identifier_quotes( - "SELECT `name` FROM `tabDocType` WHERE `name` IN ('ToDo','Note')" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", filters=["ToDo", "Note"], ).get_sql(), - expected_query, + "SELECT `name` FROM `tabDocType` WHERE `name` IN ('ToDo','Note')", ) - expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType` WHERE `name` IN ('')") - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", filters={"name": ("in", [])}, ).get_sql(), - expected_query, + "SELECT `name` FROM `tabDocType` WHERE `name` IN ('')", ) - expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType` WHERE `name` IN (1,2,3)") - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", filters=[1, 2, 3], ).get_sql(), - expected_query, + "SELECT `name` FROM `tabDocType` WHERE `name` IN (1,2,3)", ) - expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType`") - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", filters=[], ).get_sql(), - expected_query, + "SELECT `name` FROM `tabDocType`", ) def test_nested_filters(self): @@ -1520,7 +1496,7 @@ class TestQuery(IntegrationTestCase): # Test simple function without alias query = frappe.qb.get_query("User", fields=["user_type", {"COUNT": "name"}], group_by="user_type") sql = query.get_sql() - self.assertIn(convert_identifier_quotes("COUNT(`name`)"), sql) + self.assertIn(self.normalize_sql("COUNT(`name`)"), sql) self.assertIn("GROUP BY", sql) # Test function with alias @@ -1528,52 +1504,54 @@ class TestQuery(IntegrationTestCase): "User", fields=[{"COUNT": "name", "as": "total_users"}], group_by="user_type" ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("COUNT(`name`) `total_users`"), sql) + self.assertIn(self.normalize_sql("COUNT(`name`) `total_users`"), sql) # Test SUM function with alias query = frappe.qb.get_query( "User", fields=[{"SUM": "enabled", "as": "total_enabled"}], group_by="user_type" ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("SUM(`enabled`) `total_enabled`"), sql) + self.assertIn(self.normalize_sql("SUM(`enabled`) `total_enabled`"), sql) # Test MAX function query = frappe.qb.get_query( "User", fields=[{"MAX": "creation", "as": "latest_user"}], group_by="user_type" ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("MAX(`creation`) `latest_user`"), sql) + self.assertIn(self.normalize_sql("MAX(`creation`) `latest_user`"), sql) # Test MIN function query = frappe.qb.get_query( "User", fields=[{"MIN": "creation", "as": "earliest_user"}], group_by="user_type" ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("MIN(`creation`) `earliest_user`"), sql) + self.assertIn(self.normalize_sql("MIN(`creation`) `earliest_user`"), sql) # Test AVG function query = frappe.qb.get_query( "User", fields=[{"AVG": "enabled", "as": "avg_enabled"}], group_by="user_type" ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("AVG(`enabled`) `avg_enabled`"), sql) + self.assertIn(self.normalize_sql("AVG(`enabled`) `avg_enabled`"), sql) # Test ABS function query = frappe.qb.get_query("User", fields=[{"ABS": "enabled", "as": "abs_enabled"}]) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("ABS(`enabled`) `abs_enabled`"), sql) + self.assertIn(self.normalize_sql("ABS(`enabled`) `abs_enabled`"), sql) # Test IFNULL function with two parameters query = frappe.qb.get_query( "User", fields=[{"IFNULL": ["first_name", "'Unknown'"], "as": "safe_name"}] ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("IFNULL(`first_name`,'Unknown') `safe_name`"), sql) + self.assertIn( + self.normalize_sql("IFNULL(`first_name`,'Unknown') `safe_name`"), self.normalize_sql(sql) + ) # Test TIMESTAMP function query = frappe.qb.get_query("User", fields=[{"TIMESTAMP": "creation", "as": "ts"}]) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("TIMESTAMP(`creation`) `ts`"), sql) + self.assertIn(self.normalize_sql("TIMESTAMP(`creation`) `ts`"), self.normalize_sql(sql)) # Test mixed regular fields and function fields query = frappe.qb.get_query( @@ -1586,21 +1564,23 @@ class TestQuery(IntegrationTestCase): group_by="user_type", ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("`user_type`"), sql) - self.assertIn(convert_identifier_quotes("COUNT(`name`) `total_users`"), sql) - self.assertIn(convert_identifier_quotes("MAX(`creation`) `latest_creation`"), sql) + self.assertIn(self.normalize_sql("`user_type`"), sql) + self.assertIn(self.normalize_sql("COUNT(`name`) `total_users`"), sql) + self.assertIn(self.normalize_sql("MAX(`creation`) `latest_creation`"), sql) # Test NOW function with no arguments query = frappe.qb.get_query("User", fields=[{"NOW": None, "as": "current_time"}]) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("NOW() `current_time`"), sql) + self.assertIn(self.normalize_sql("NOW() `current_time`"), sql) # Test CONCAT function (which is supported) query = frappe.qb.get_query( "User", fields=[{"CONCAT": ["first_name", "last_name"], "as": "full_name"}] ) sql = query.get_sql() - self.assertIn(convert_identifier_quotes("CONCAT(`first_name`,`last_name`) `full_name`"), sql) + self.assertIn( + self.normalize_sql("CONCAT(`first_name`,`last_name`) `full_name`"), self.normalize_sql(sql) + ) # Test unsupported function validation with self.assertRaises(frappe.ValidationError) as cm: @@ -1618,10 +1598,7 @@ class TestQuery(IntegrationTestCase): self.assertIn("Unsupported function or invalid field name: DROP", str(cm.exception)) def test_not_equal_condition_on_none(self): - expected_query = convert_identifier_quotes( - "SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` IS NULL AND `tabDocType`.`parent` IS NOT NULL" - ) - self.assertEqual( + self.assertQueryEqual( frappe.qb.get_query( "DocType", ["*"], @@ -1630,7 +1607,7 @@ class TestQuery(IntegrationTestCase): ["DocType", "parent", "!=", None], ], ).get_sql(), - expected_query, + "SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` IS NULL AND `tabDocType`.`parent` IS NOT NULL", )