Merge pull request #35940 from AarDG10/fix-orderby-pg

This commit is contained in:
Suraj Shetty 2026-02-22 13:09:43 +05:30 committed by GitHub
commit 0c211aa4a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 138 additions and 6 deletions

View file

@ -262,6 +262,8 @@ class Engine:
self.field_aliases = set() self.field_aliases = set()
self.db_query_compat = db_query_compat self.db_query_compat = db_query_compat
self.permitted_fields_cache = {} # Cache for get_permitted_fields results self.permitted_fields_cache = {} # Cache for get_permitted_fields results
self.is_aggregate_query = False
self._grouped_queries = set()
if isinstance(table, Table): if isinstance(table, Table):
self.table = table self.table = table
@ -308,19 +310,24 @@ class Engine:
if for_update: if for_update:
self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait) self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait)
if any(isinstance(f, functions.AggregateFunction) for f in getattr(self, "fields", [])):
# check if any field in select is aggregated (done to prevent breaking queries in postgres due to order by rule)
self.is_aggregate_query = True
if group_by: if group_by:
self.is_aggregate_query = True # for postgres (group by used with order by)
self.apply_group_by(group_by) self.apply_group_by(group_by)
if order_by: if order_by:
if not ( if not (
self.is_postgres and is_select and (distinct or group_by) self.is_postgres and is_select and distinct
): # ignore in Postgres since order by fields need to appear in select distinct ): # ignore in Postgres since order by fields need to appear in select distinct
self.apply_order_by(order_by) self.apply_order_by(order_by)
else: else:
warnings.warn( warnings.warn(
( (
"ORDER BY fields have been ignored because PostgreSQL requires them to " "ORDER BY fields have been ignored because PostgreSQL requires them to "
"appear in the SELECT list when using DISTINCT or GROUP BY." "appear in the SELECT list when using with DISTINCT"
), ),
UserWarning, UserWarning,
stacklevel=2, stacklevel=2,
@ -340,7 +347,7 @@ class Engine:
# Track field aliases for use in group_by/order_by # Track field aliases for use in group_by/order_by
for field in self.fields: for field in self.fields:
if isinstance(field, Field | DynamicTableField) and field.alias: if isinstance(field, Field | DynamicTableField | AggregateFunction) and field.alias:
self.field_aliases.add(field.alias) self.field_aliases.add(field.alias)
if self.apply_permissions: if self.apply_permissions:
@ -1190,8 +1197,24 @@ class Engine:
# Note: Comma handling is done in parse_fields before this method is called # Note: Comma handling is done in parse_fields before this method is called
return self.parse_string_field(field) return self.parse_string_field(field)
def _normalize_postgres_order_field(self, field):
"""In PostgreSQL order_by fields need to either be in group_by or be aggregated
when used with select and group_by"""
current_sql = field.get_sql() if hasattr(field, "get_sql") else str(field)
if current_sql in self._grouped_queries:
return field
clean_name = current_sql.strip('"')
if clean_name in self.field_aliases:
return field
if not isinstance(field, functions.AggregateFunction):
return functions.Max(field)
return field
def apply_group_by(self, group_by: str | None = None): def apply_group_by(self, group_by: str | None = None):
parsed_group_by_fields = self._validate_group_by(group_by) parsed_group_by_fields = self._validate_group_by(group_by)
self._grouped_queries = {
f.get_sql() if hasattr(f, "get_sql") else str(f) for f in parsed_group_by_fields
}
self.query = self.query.groupby(*parsed_group_by_fields) self.query = self.query.groupby(*parsed_group_by_fields)
def apply_order_by(self, order_by: str | None): def apply_order_by(self, order_by: str | None):
@ -1201,7 +1224,12 @@ class Engine:
parsed_order_fields = self._validate_order_by(order_by) parsed_order_fields = self._validate_order_by(order_by)
for order_field, order_direction in parsed_order_fields: for order_field, order_direction in parsed_order_fields:
self.query = self.query.orderby(order_field, order=order_direction) if self.is_postgres and self.is_aggregate_query:
self.query = self.query.orderby(
self._normalize_postgres_order_field(order_field), order=order_direction
)
else:
self.query = self.query.orderby(order_field, order=order_direction)
def _apply_default_order_by(self): def _apply_default_order_by(self):
"""Apply default ordering based on configured DocType metadata""" """Apply default ordering based on configured DocType metadata"""
@ -1220,14 +1248,24 @@ class Engine:
order_direction = Order.desc if spec_order == "desc" else Order.asc order_direction = Order.desc if spec_order == "desc" else Order.asc
else: else:
order_direction = Order.asc if spec_order == "asc" else Order.desc order_direction = Order.asc if spec_order == "asc" else Order.desc
self.query = self.query.orderby(field, order=order_direction) if self.is_postgres and self.is_aggregate_query:
self.query = self.query.orderby(
self._normalize_postgres_order_field(field), order=order_direction
)
else:
self.query = self.query.orderby(field, order=order_direction)
else: else:
field = self.table[sort_field] field = self.table[sort_field]
if self.db_query_compat: if self.db_query_compat:
order_direction = Order.desc if sort_order.lower() == "desc" else Order.asc order_direction = Order.desc if sort_order.lower() == "desc" else Order.asc
else: else:
order_direction = Order.asc if sort_order.lower() == "asc" else Order.desc order_direction = Order.asc if sort_order.lower() == "asc" else Order.desc
self.query = self.query.orderby(field, order=order_direction) if self.is_postgres and self.is_aggregate_query:
self.query = self.query.orderby(
self._normalize_postgres_order_field(field), order=order_direction
)
else:
self.query = self.query.orderby(field, order=order_direction)
def _parse_backtick_field_notation(self, field_name: str) -> tuple[str, str] | None: def _parse_backtick_field_notation(self, field_name: str) -> tuple[str, str] | None:
""" """

View file

@ -2345,6 +2345,100 @@ class TestQuery(IntegrationTestCase):
# the filter should still apply and return no results # the filter should still apply and return no results
self.assertEqual(len(result), 0, "Filter should not be bypassed by shared doc OR condition") self.assertEqual(len(result), 0, "Filter should not be bypassed by shared doc OR condition")
@run_only_if(db_type_is.POSTGRES)
def test_order_by_group_by_postgres(self):
"""PostgreSQL specific test that tests if order_by fields are correctly handled when used with group_by"""
# test order by fields already in group by (no aggregate needed)
query = frappe.qb.get_query(
"User",
fields=["creation as created_date", {"COUNT": "*"}],
group_by="created_date",
order_by="created_date",
).get_sql()
self.assertQueryEqual(
query,
'SELECT "creation" "created_date",COUNT(*) FROM "tabUser" GROUP BY "created_date" ORDER BY "created_date" DESC',
)
# test order by fields not in group by (aggregate needed)
query = frappe.qb.get_query(
"User",
fields=["creation as created_date", {"COUNT": "*"}],
group_by="created_date",
order_by="name",
).get_sql()
self.assertQueryEqual(
query,
'SELECT "creation" "created_date",COUNT(*) FROM "tabUser" GROUP BY "created_date" ORDER BY MAX("name") DESC',
)
query = frappe.qb.get_query(
"User",
fields=["user_type as type", "enabled as status", {"COUNT": "*"}],
group_by="type, status",
order_by="status asc",
).get_sql()
self.assertQueryEqual(
query,
'SELECT "user_type" "type","enabled" "status",COUNT(*) FROM "tabUser" GROUP BY "type","status" ORDER BY "status" ASC',
)
# test no double aggregation rule
query = frappe.qb.get_query(
"User",
fields=["creation", {"COUNT": "*", "as": "total"}],
group_by="creation",
order_by="total desc",
).get_sql()
self.assertQueryEqual(
query,
'SELECT "creation",COUNT(*) "total" FROM "tabUser" GROUP BY "creation" ORDER BY "total" DESC',
)
# test multiple order_by fields not in group_by
query = frappe.qb.get_query(
"User",
fields=["user_type", {"COUNT": "*"}],
group_by="user_type",
order_by="creation desc, modified asc",
).get_sql()
self.assertIn('MAX("creation") DESC', query)
self.assertIn('MAX("modified") ASC', query)
# for queries that have aggregate fields selected but not grouped (these queries are redundant but exist in some parts of codebase)
query = frappe.qb.get_query(
"User", fields=[{"COUNT": "*", "as": "result"}], order_by="creation desc"
).get_sql()
self.assertQueryEqual(query, 'SELECT COUNT(*) "result" FROM "tabUser" ORDER BY MAX("creation") DESC')
# test in case user uses `original_col` name instead of alias
query = frappe.qb.get_query(
"User", fields=["name as user_name"], group_by="user_name", order_by="user_name"
)
a = query.run()
query = frappe.qb.get_query("User", fields=["name as user_name"], group_by="name", order_by="name")
b = query.run()
query = frappe.qb.get_query(
"User", fields=["name as user_name"], group_by="name", order_by="user_name"
)
c = query.run()
query = frappe.qb.get_query(
"User", fields=["name as user_name"], group_by="user_name", order_by="name"
)
d = query.run()
for val in [b, c, d]:
self.assertEqual(a, val, "Query result mismatch detected.")
@run_only_if(db_type_is.POSTGRES) @run_only_if(db_type_is.POSTGRES)
def test_ifnull_fallback_postgres(self): def test_ifnull_fallback_postgres(self):
"""Test ifnull fallback in postgres""" """Test ifnull fallback in postgres"""