test(postgres): fix mariadb specific identifiers for postgres queries using normalize_sql

This commit is contained in:
AarDG10 2025-11-10 17:46:24 +05:30
parent a991bad767
commit a12d855147
3 changed files with 46 additions and 68 deletions

View file

@ -13,7 +13,6 @@ from frappe.core.api.user_invitation import (
)
from frappe.core.doctype.user_invitation.user_invitation import mark_expired_invitations
from frappe.tests import IntegrationTestCase
from frappe.tests.test_query import convert_identifier_quotes
emails = [
"test_user_invite1@example.com",
@ -56,18 +55,18 @@ class IntegrationTestUserInvitation(IntegrationTestCase):
@classmethod
def delete_all_user_roles(cls):
frappe.db.sql(convert_identifier_quotes("DELETE FROM `tabUser Role`"))
query = "DELETE FROM `tabUser Role`"
frappe.db.sql(cls.normalize_sql(query))
@classmethod
def delete_all_invitations(cls):
frappe.db.sql(convert_identifier_quotes("DELETE FROM `tabUser Invitation`"))
query = "DELETE FROM `tabUser Invitation`"
frappe.db.sql(cls.normalize_sql(query))
@classmethod
def delete_invitation(cls, name: str):
frappe.db.sql(
convert_identifier_quotes("DELETE FROM `tabUser Invitation` WHERE name = %s"),
name,
)
query = "DELETE FROM `tabUser Invitation` WHERE name = %s"
frappe.db.sql(cls.normalize_sql(query), name)
def setUp(self):
super().setUp()

View file

@ -105,6 +105,8 @@ class UnitTestCase(unittest.TestCase, BaseTestCase):
"""Formats SQL consistently so simple string comparisons can work on them."""
import sqlparse
if frappe.db.db_type == "postgres":
query = query.replace("`", '"')
return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True)

View file

@ -63,20 +63,12 @@ def create_tree_docs():
d.insert()
def convert_identifier_quotes(query):
"""util to replace mariadb query idenitfifiers with postgres ones"""
return query.replace("`", '"') if frappe.db.db_type == "postgres" else query
class TestQuery(IntegrationTestCase):
def setUp(self):
setup_for_tests()
def test_multiple_tables_in_filters(self):
expected_query = convert_identifier_quotes(
"SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
["*"],
@ -85,7 +77,7 @@ class TestQuery(IntegrationTestCase):
["DocType", "parent", "=", "something"],
],
).get_sql(),
expected_query,
"SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'",
)
def test_string_fields(self):
@ -359,88 +351,72 @@ class TestQuery(IntegrationTestCase):
)
def test_filters(self):
expected_query = convert_identifier_quotes(
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name`='frappe'"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
fields=["name"],
filters={"module.app_name": "frappe"},
).get_sql(),
expected_query,
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name`='frappe'",
)
expected_query = convert_identifier_quotes(
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name` LIKE 'frap%'"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
fields=["name"],
filters={"module.app_name": ("like", "frap%")},
).get_sql(),
expected_query,
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name` LIKE 'frap%'",
)
expected_query = convert_identifier_quotes(
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabDocPerm` ON `tabDocPerm`.`parent`=`tabDocType`.`name` AND `tabDocPerm`.`parenttype`='DocType' AND `tabDocPerm`.`parentfield`='permissions' WHERE `tabDocPerm`.`role`='System Manager'"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
fields=["name"],
filters={"permissions.role": "System Manager"},
).get_sql(),
expected_query,
"SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabDocPerm` ON `tabDocPerm`.`parent`=`tabDocType`.`name` AND `tabDocPerm`.`parenttype`='DocType' AND `tabDocPerm`.`parentfield`='permissions' WHERE `tabDocPerm`.`role`='System Manager'",
)
expected_query = convert_identifier_quotes("SELECT `module` FROM `tabDocType` WHERE `name`=''")
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
fields=["module"],
filters="",
).get_sql(),
expected_query,
"SELECT `module` FROM `tabDocType` WHERE `name`=''",
)
expected_query = convert_identifier_quotes(
"SELECT `name` FROM `tabDocType` WHERE `name` IN ('ToDo','Note')"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
filters=["ToDo", "Note"],
).get_sql(),
expected_query,
"SELECT `name` FROM `tabDocType` WHERE `name` IN ('ToDo','Note')",
)
expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType` WHERE `name` IN ('')")
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
filters={"name": ("in", [])},
).get_sql(),
expected_query,
"SELECT `name` FROM `tabDocType` WHERE `name` IN ('')",
)
expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType` WHERE `name` IN (1,2,3)")
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
filters=[1, 2, 3],
).get_sql(),
expected_query,
"SELECT `name` FROM `tabDocType` WHERE `name` IN (1,2,3)",
)
expected_query = convert_identifier_quotes("SELECT `name` FROM `tabDocType`")
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
filters=[],
).get_sql(),
expected_query,
"SELECT `name` FROM `tabDocType`",
)
def test_nested_filters(self):
@ -1520,7 +1496,7 @@ class TestQuery(IntegrationTestCase):
# Test simple function without alias
query = frappe.qb.get_query("User", fields=["user_type", {"COUNT": "name"}], group_by="user_type")
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("COUNT(`name`)"), sql)
self.assertIn(self.normalize_sql("COUNT(`name`)"), sql)
self.assertIn("GROUP BY", sql)
# Test function with alias
@ -1528,52 +1504,54 @@ class TestQuery(IntegrationTestCase):
"User", fields=[{"COUNT": "name", "as": "total_users"}], group_by="user_type"
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("COUNT(`name`) `total_users`"), sql)
self.assertIn(self.normalize_sql("COUNT(`name`) `total_users`"), sql)
# Test SUM function with alias
query = frappe.qb.get_query(
"User", fields=[{"SUM": "enabled", "as": "total_enabled"}], group_by="user_type"
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("SUM(`enabled`) `total_enabled`"), sql)
self.assertIn(self.normalize_sql("SUM(`enabled`) `total_enabled`"), sql)
# Test MAX function
query = frappe.qb.get_query(
"User", fields=[{"MAX": "creation", "as": "latest_user"}], group_by="user_type"
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("MAX(`creation`) `latest_user`"), sql)
self.assertIn(self.normalize_sql("MAX(`creation`) `latest_user`"), sql)
# Test MIN function
query = frappe.qb.get_query(
"User", fields=[{"MIN": "creation", "as": "earliest_user"}], group_by="user_type"
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("MIN(`creation`) `earliest_user`"), sql)
self.assertIn(self.normalize_sql("MIN(`creation`) `earliest_user`"), sql)
# Test AVG function
query = frappe.qb.get_query(
"User", fields=[{"AVG": "enabled", "as": "avg_enabled"}], group_by="user_type"
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("AVG(`enabled`) `avg_enabled`"), sql)
self.assertIn(self.normalize_sql("AVG(`enabled`) `avg_enabled`"), sql)
# Test ABS function
query = frappe.qb.get_query("User", fields=[{"ABS": "enabled", "as": "abs_enabled"}])
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("ABS(`enabled`) `abs_enabled`"), sql)
self.assertIn(self.normalize_sql("ABS(`enabled`) `abs_enabled`"), sql)
# Test IFNULL function with two parameters
query = frappe.qb.get_query(
"User", fields=[{"IFNULL": ["first_name", "'Unknown'"], "as": "safe_name"}]
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("IFNULL(`first_name`,'Unknown') `safe_name`"), sql)
self.assertIn(
self.normalize_sql("IFNULL(`first_name`,'Unknown') `safe_name`"), self.normalize_sql(sql)
)
# Test TIMESTAMP function
query = frappe.qb.get_query("User", fields=[{"TIMESTAMP": "creation", "as": "ts"}])
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("TIMESTAMP(`creation`) `ts`"), sql)
self.assertIn(self.normalize_sql("TIMESTAMP(`creation`) `ts`"), self.normalize_sql(sql))
# Test mixed regular fields and function fields
query = frappe.qb.get_query(
@ -1586,21 +1564,23 @@ class TestQuery(IntegrationTestCase):
group_by="user_type",
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("`user_type`"), sql)
self.assertIn(convert_identifier_quotes("COUNT(`name`) `total_users`"), sql)
self.assertIn(convert_identifier_quotes("MAX(`creation`) `latest_creation`"), sql)
self.assertIn(self.normalize_sql("`user_type`"), sql)
self.assertIn(self.normalize_sql("COUNT(`name`) `total_users`"), sql)
self.assertIn(self.normalize_sql("MAX(`creation`) `latest_creation`"), sql)
# Test NOW function with no arguments
query = frappe.qb.get_query("User", fields=[{"NOW": None, "as": "current_time"}])
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("NOW() `current_time`"), sql)
self.assertIn(self.normalize_sql("NOW() `current_time`"), sql)
# Test CONCAT function (which is supported)
query = frappe.qb.get_query(
"User", fields=[{"CONCAT": ["first_name", "last_name"], "as": "full_name"}]
)
sql = query.get_sql()
self.assertIn(convert_identifier_quotes("CONCAT(`first_name`,`last_name`) `full_name`"), sql)
self.assertIn(
self.normalize_sql("CONCAT(`first_name`,`last_name`) `full_name`"), self.normalize_sql(sql)
)
# Test unsupported function validation
with self.assertRaises(frappe.ValidationError) as cm:
@ -1618,10 +1598,7 @@ class TestQuery(IntegrationTestCase):
self.assertIn("Unsupported function or invalid field name: DROP", str(cm.exception))
def test_not_equal_condition_on_none(self):
expected_query = convert_identifier_quotes(
"SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` IS NULL AND `tabDocType`.`parent` IS NOT NULL"
)
self.assertEqual(
self.assertQueryEqual(
frappe.qb.get_query(
"DocType",
["*"],
@ -1630,7 +1607,7 @@ class TestQuery(IntegrationTestCase):
["DocType", "parent", "!=", None],
],
).get_sql(),
expected_query,
"SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' AND `tabDocField`.`parentfield`='fields' WHERE `tabDocField`.`name` IS NULL AND `tabDocType`.`parent` IS NOT NULL",
)