Merge pull request #35940 from AarDG10/fix-orderby-pg
This commit is contained in:
commit
0c211aa4a0
2 changed files with 138 additions and 6 deletions
|
|
@ -262,6 +262,8 @@ class Engine:
|
|||
self.field_aliases = set()
|
||||
self.db_query_compat = db_query_compat
|
||||
self.permitted_fields_cache = {} # Cache for get_permitted_fields results
|
||||
self.is_aggregate_query = False
|
||||
self._grouped_queries = set()
|
||||
|
||||
if isinstance(table, Table):
|
||||
self.table = table
|
||||
|
|
@ -308,19 +310,24 @@ class Engine:
|
|||
if for_update:
|
||||
self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait)
|
||||
|
||||
if any(isinstance(f, functions.AggregateFunction) for f in getattr(self, "fields", [])):
|
||||
# check if any field in select is aggregated (done to prevent breaking queries in postgres due to order by rule)
|
||||
self.is_aggregate_query = True
|
||||
|
||||
if group_by:
|
||||
self.is_aggregate_query = True # for postgres (group by used with order by)
|
||||
self.apply_group_by(group_by)
|
||||
|
||||
if order_by:
|
||||
if not (
|
||||
self.is_postgres and is_select and (distinct or group_by)
|
||||
self.is_postgres and is_select and distinct
|
||||
): # ignore in Postgres since order by fields need to appear in select distinct
|
||||
self.apply_order_by(order_by)
|
||||
else:
|
||||
warnings.warn(
|
||||
(
|
||||
"ORDER BY fields have been ignored because PostgreSQL requires them to "
|
||||
"appear in the SELECT list when using DISTINCT or GROUP BY."
|
||||
"appear in the SELECT list when using with DISTINCT"
|
||||
),
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
|
|
@ -340,7 +347,7 @@ class Engine:
|
|||
|
||||
# Track field aliases for use in group_by/order_by
|
||||
for field in self.fields:
|
||||
if isinstance(field, Field | DynamicTableField) and field.alias:
|
||||
if isinstance(field, Field | DynamicTableField | AggregateFunction) and field.alias:
|
||||
self.field_aliases.add(field.alias)
|
||||
|
||||
if self.apply_permissions:
|
||||
|
|
@ -1190,8 +1197,24 @@ class Engine:
|
|||
# Note: Comma handling is done in parse_fields before this method is called
|
||||
return self.parse_string_field(field)
|
||||
|
||||
def _normalize_postgres_order_field(self, field):
|
||||
"""In PostgreSQL order_by fields need to either be in group_by or be aggregated
|
||||
when used with select and group_by"""
|
||||
current_sql = field.get_sql() if hasattr(field, "get_sql") else str(field)
|
||||
if current_sql in self._grouped_queries:
|
||||
return field
|
||||
clean_name = current_sql.strip('"')
|
||||
if clean_name in self.field_aliases:
|
||||
return field
|
||||
if not isinstance(field, functions.AggregateFunction):
|
||||
return functions.Max(field)
|
||||
return field
|
||||
|
||||
def apply_group_by(self, group_by: str | None = None):
|
||||
parsed_group_by_fields = self._validate_group_by(group_by)
|
||||
self._grouped_queries = {
|
||||
f.get_sql() if hasattr(f, "get_sql") else str(f) for f in parsed_group_by_fields
|
||||
}
|
||||
self.query = self.query.groupby(*parsed_group_by_fields)
|
||||
|
||||
def apply_order_by(self, order_by: str | None):
|
||||
|
|
@ -1201,7 +1224,12 @@ class Engine:
|
|||
|
||||
parsed_order_fields = self._validate_order_by(order_by)
|
||||
for order_field, order_direction in parsed_order_fields:
|
||||
self.query = self.query.orderby(order_field, order=order_direction)
|
||||
if self.is_postgres and self.is_aggregate_query:
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(order_field), order=order_direction
|
||||
)
|
||||
else:
|
||||
self.query = self.query.orderby(order_field, order=order_direction)
|
||||
|
||||
def _apply_default_order_by(self):
|
||||
"""Apply default ordering based on configured DocType metadata"""
|
||||
|
|
@ -1220,14 +1248,24 @@ class Engine:
|
|||
order_direction = Order.desc if spec_order == "desc" else Order.asc
|
||||
else:
|
||||
order_direction = Order.asc if spec_order == "asc" else Order.desc
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
if self.is_postgres and self.is_aggregate_query:
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(field), order=order_direction
|
||||
)
|
||||
else:
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
else:
|
||||
field = self.table[sort_field]
|
||||
if self.db_query_compat:
|
||||
order_direction = Order.desc if sort_order.lower() == "desc" else Order.asc
|
||||
else:
|
||||
order_direction = Order.asc if sort_order.lower() == "asc" else Order.desc
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
if self.is_postgres and self.is_aggregate_query:
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(field), order=order_direction
|
||||
)
|
||||
else:
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
|
||||
def _parse_backtick_field_notation(self, field_name: str) -> tuple[str, str] | None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2345,6 +2345,100 @@ class TestQuery(IntegrationTestCase):
|
|||
# the filter should still apply and return no results
|
||||
self.assertEqual(len(result), 0, "Filter should not be bypassed by shared doc OR condition")
|
||||
|
||||
@run_only_if(db_type_is.POSTGRES)
|
||||
def test_order_by_group_by_postgres(self):
|
||||
"""PostgreSQL specific test that tests if order_by fields are correctly handled when used with group_by"""
|
||||
# test order by fields already in group by (no aggregate needed)
|
||||
query = frappe.qb.get_query(
|
||||
"User",
|
||||
fields=["creation as created_date", {"COUNT": "*"}],
|
||||
group_by="created_date",
|
||||
order_by="created_date",
|
||||
).get_sql()
|
||||
|
||||
self.assertQueryEqual(
|
||||
query,
|
||||
'SELECT "creation" "created_date",COUNT(*) FROM "tabUser" GROUP BY "created_date" ORDER BY "created_date" DESC',
|
||||
)
|
||||
|
||||
# test order by fields not in group by (aggregate needed)
|
||||
query = frappe.qb.get_query(
|
||||
"User",
|
||||
fields=["creation as created_date", {"COUNT": "*"}],
|
||||
group_by="created_date",
|
||||
order_by="name",
|
||||
).get_sql()
|
||||
|
||||
self.assertQueryEqual(
|
||||
query,
|
||||
'SELECT "creation" "created_date",COUNT(*) FROM "tabUser" GROUP BY "created_date" ORDER BY MAX("name") DESC',
|
||||
)
|
||||
|
||||
query = frappe.qb.get_query(
|
||||
"User",
|
||||
fields=["user_type as type", "enabled as status", {"COUNT": "*"}],
|
||||
group_by="type, status",
|
||||
order_by="status asc",
|
||||
).get_sql()
|
||||
|
||||
self.assertQueryEqual(
|
||||
query,
|
||||
'SELECT "user_type" "type","enabled" "status",COUNT(*) FROM "tabUser" GROUP BY "type","status" ORDER BY "status" ASC',
|
||||
)
|
||||
|
||||
# test no double aggregation rule
|
||||
query = frappe.qb.get_query(
|
||||
"User",
|
||||
fields=["creation", {"COUNT": "*", "as": "total"}],
|
||||
group_by="creation",
|
||||
order_by="total desc",
|
||||
).get_sql()
|
||||
|
||||
self.assertQueryEqual(
|
||||
query,
|
||||
'SELECT "creation",COUNT(*) "total" FROM "tabUser" GROUP BY "creation" ORDER BY "total" DESC',
|
||||
)
|
||||
|
||||
# test multiple order_by fields not in group_by
|
||||
query = frappe.qb.get_query(
|
||||
"User",
|
||||
fields=["user_type", {"COUNT": "*"}],
|
||||
group_by="user_type",
|
||||
order_by="creation desc, modified asc",
|
||||
).get_sql()
|
||||
|
||||
self.assertIn('MAX("creation") DESC', query)
|
||||
self.assertIn('MAX("modified") ASC', query)
|
||||
|
||||
# for queries that have aggregate fields selected but not grouped (these queries are redundant but exist in some parts of codebase)
|
||||
query = frappe.qb.get_query(
|
||||
"User", fields=[{"COUNT": "*", "as": "result"}], order_by="creation desc"
|
||||
).get_sql()
|
||||
|
||||
self.assertQueryEqual(query, 'SELECT COUNT(*) "result" FROM "tabUser" ORDER BY MAX("creation") DESC')
|
||||
|
||||
# test in case user uses `original_col` name instead of alias
|
||||
query = frappe.qb.get_query(
|
||||
"User", fields=["name as user_name"], group_by="user_name", order_by="user_name"
|
||||
)
|
||||
a = query.run()
|
||||
|
||||
query = frappe.qb.get_query("User", fields=["name as user_name"], group_by="name", order_by="name")
|
||||
b = query.run()
|
||||
|
||||
query = frappe.qb.get_query(
|
||||
"User", fields=["name as user_name"], group_by="name", order_by="user_name"
|
||||
)
|
||||
c = query.run()
|
||||
|
||||
query = frappe.qb.get_query(
|
||||
"User", fields=["name as user_name"], group_by="user_name", order_by="name"
|
||||
)
|
||||
d = query.run()
|
||||
|
||||
for val in [b, c, d]:
|
||||
self.assertEqual(a, val, "Query result mismatch detected.")
|
||||
|
||||
@run_only_if(db_type_is.POSTGRES)
|
||||
def test_ifnull_fallback_postgres(self):
|
||||
"""Test ifnull fallback in postgres"""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue