fix(query): aggregate order_field when used with select group_by
This commit is contained in:
parent
cba1150676
commit
cb68c2df32
1 changed files with 34 additions and 4 deletions
|
|
@ -266,6 +266,8 @@ class Engine:
|
|||
self.field_aliases = set()
|
||||
self.db_query_compat = db_query_compat
|
||||
self.permitted_fields_cache = {} # Cache for get_permitted_fields results
|
||||
self.is_aggregate_query = False
|
||||
self._grouped_sqls = set()
|
||||
|
||||
if isinstance(table, Table):
|
||||
self.table = table
|
||||
|
|
@ -314,11 +316,12 @@ class Engine:
|
|||
self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait)
|
||||
|
||||
if group_by:
|
||||
self.is_aggregate_query = True # for postgres (group by used with order by)
|
||||
self.apply_group_by(group_by)
|
||||
|
||||
if order_by:
|
||||
if not (
|
||||
self.is_postgres and is_select and (distinct or group_by)
|
||||
self.is_postgres and is_select and distinct
|
||||
): # ignore in Postgres since order by fields need to appear in select distinct
|
||||
self.apply_order_by(order_by)
|
||||
else:
|
||||
|
|
@ -351,6 +354,10 @@ class Engine:
|
|||
for field in self.fields:
|
||||
if isinstance(field, Field | DynamicTableField) and field.alias:
|
||||
self.field_aliases.add(field.alias)
|
||||
elif self.is_postgres and getattr(
|
||||
field, "alias", None
|
||||
): # captures aggregate functions (for pg order by fix)
|
||||
self.field_aliases.add(field.alias)
|
||||
|
||||
if self.apply_permissions:
|
||||
self.fields = self.apply_field_permissions()
|
||||
|
|
@ -1120,8 +1127,25 @@ class Engine:
|
|||
# Note: Comma handling is done in parse_fields before this method is called
|
||||
return self.parse_string_field(field)
|
||||
|
||||
def _normalize_postgres_order_field(self, field):
|
||||
"""In PostgreSQL order_by fields need to either be in group_by or be aggregated
|
||||
when used with select and group_by"""
|
||||
if self.is_postgres and self.is_aggregate_query:
|
||||
current_sql = field.get_sql() if hasattr(field, "get_sql") else str(field)
|
||||
if current_sql in self._grouped_sqls:
|
||||
return field
|
||||
clean_name = current_sql.strip('"')
|
||||
if clean_name in self.field_aliases:
|
||||
return field
|
||||
if not isinstance(field, functions.AggregateFunction):
|
||||
return functions.Max(field)
|
||||
return field
|
||||
|
||||
def apply_group_by(self, group_by: str | None = None):
|
||||
parsed_group_by_fields = self._validate_group_by(group_by)
|
||||
self._grouped_sqls = {
|
||||
f.get_sql() if hasattr(f, "get_sql") else str(f) for f in parsed_group_by_fields
|
||||
}
|
||||
self.query = self.query.groupby(*parsed_group_by_fields)
|
||||
|
||||
def apply_order_by(self, order_by: str | None):
|
||||
|
|
@ -1131,7 +1155,9 @@ class Engine:
|
|||
|
||||
parsed_order_fields = self._validate_order_by(order_by)
|
||||
for order_field, order_direction in parsed_order_fields:
|
||||
self.query = self.query.orderby(order_field, order=order_direction)
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(order_field), order=order_direction
|
||||
)
|
||||
|
||||
def _apply_default_order_by(self):
|
||||
"""Apply default ordering based on configured DocType metadata"""
|
||||
|
|
@ -1150,14 +1176,18 @@ class Engine:
|
|||
order_direction = Order.desc if spec_order == "desc" else Order.asc
|
||||
else:
|
||||
order_direction = Order.asc if spec_order == "asc" else Order.desc
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(field), order=order_direction
|
||||
)
|
||||
else:
|
||||
field = self.table[sort_field]
|
||||
if self.db_query_compat:
|
||||
order_direction = Order.desc if sort_order.lower() == "desc" else Order.asc
|
||||
else:
|
||||
order_direction = Order.asc if sort_order.lower() == "asc" else Order.desc
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
self.query = self.query.orderby(
|
||||
self._normalize_postgres_order_field(field), order=order_direction
|
||||
)
|
||||
|
||||
def _parse_backtick_field_notation(self, field_name: str) -> tuple[str, str] | None:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue