From 75377aaaf5512b6736f502e3a37dd52bc31dd263 Mon Sep 17 00:00:00 2001 From: David Arnold Date: Thu, 5 Dec 2024 00:18:53 +0100 Subject: [PATCH] refactor(typing): type filters (#28218) * chore(typing): type filters * chore(typing): type filters for get_list et al * fix: dashboard chart filter expression * test: fix case with new-style right hand object to equality check * chore: place new typed filter under typing verification * chore: remove debug print statment * chore: inverse logic of type guard * fix: add float to filter value types * chore: clarify value naming --- frappe/__init__.py | 10 +- .../dashboard_chart/dashboard_chart.py | 10 +- frappe/model/db_query.py | 70 ++--- frappe/tests/test_db_query.py | 6 +- frappe/types/__init__.py | 1 + frappe/types/filter.py | 279 ++++++++++++++++++ frappe/utils/data.py | 52 ++-- pyproject.toml | 2 + 8 files changed, 343 insertions(+), 87 deletions(-) create mode 100644 frappe/types/filter.py diff --git a/frappe/__init__.py b/frappe/__init__.py index afc601812a..017688e4a7 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -53,7 +53,7 @@ from frappe.utils.data import cint, cstr, sbool # Local application imports from .exceptions import * -from .types.frappedict import _dict +from .types import Filters, FilterSignature, FilterTuple, _dict from .utils.jinja import ( get_email_from_template, get_jenv, @@ -1383,7 +1383,13 @@ def get_doc(*args: Any, **kwargs: Any) -> "Document": return doc -def get_last_doc(doctype, filters=None, order_by="creation desc", *, for_update=False): +def get_last_doc( + doctype, + filters: FilterSignature | None = None, + order_by="creation desc", + *, + for_update=False, +): """Get last created document of this type.""" d = get_all(doctype, filters=filters, limit_page_length=1, order_by=order_by, pluck="name") if d: diff --git a/frappe/desk/doctype/dashboard_chart/dashboard_chart.py b/frappe/desk/doctype/dashboard_chart/dashboard_chart.py index 6be19e09f0..786fb8200c 100644 --- a/frappe/desk/doctype/dashboard_chart/dashboard_chart.py +++ b/frappe/desk/doctype/dashboard_chart/dashboard_chart.py @@ -125,7 +125,7 @@ def get( filters = [] # don't include cancelled documents - filters.append([chart.document_type, "docstatus", "<", 2, False]) + filters.append([chart.document_type, "docstatus", "<", 2]) if chart.chart_type == "Group By": chart_config = get_group_by_chart_config(chart, filters) @@ -196,8 +196,8 @@ def get_chart_config(chart, filters, timespan, timegrain, from_date, to_date): from_date = from_date.strftime("%Y-%m-%d") to_date = to_date - filters.append([doctype, datefield, ">=", from_date, False]) - filters.append([doctype, datefield, "<=", to_date, False]) + filters.append([doctype, datefield, ">=", from_date]) + filters.append([doctype, datefield, "<=", to_date]) data = frappe.get_list( doctype, @@ -231,8 +231,8 @@ def get_heatmap_chart_config(chart, filters, heatmap_year): year_start_date = datetime.date(year, 1, 1).strftime("%Y-%m-%d") next_year_start_date = datetime.date(year + 1, 1, 1).strftime("%Y-%m-%d") - filters.append([doctype, datefield, ">", f"{year_start_date}", False]) - filters.append([doctype, datefield, "<", f"{next_year_start_date}", False]) + filters.append([doctype, datefield, ">", f"{year_start_date}"]) + filters.append([doctype, datefield, "<", f"{next_year_start_date}"]) if frappe.db.db_type == "mariadb": timestamp_field = f"unix_timestamp({datefield})" diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index c616d391d5..549d81f54a 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -7,7 +7,7 @@ import datetime import json import re from collections import Counter -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import frappe import frappe.defaults @@ -21,6 +21,7 @@ from frappe.model.meta import get_table_columns from frappe.model.utils import is_virtual_doctype from frappe.model.utils.user_settings import get_user_settings, update_user_settings from frappe.query_builder.utils import Column +from frappe.types import Filters, FilterSignature, FilterTuple from frappe.utils import ( cint, cstr, @@ -28,7 +29,6 @@ from frappe.utils import ( get_filter, get_time, get_timespan_date_range, - make_filter_tuple, ) from frappe.utils.data import DateTimeLikeObject, get_datetime, getdate, sbool @@ -79,8 +79,8 @@ class DatabaseQuery: def execute( self, fields=None, - filters=None, - or_filters=None, + filters: FilterSignature | str | None = None, + or_filters: FilterSignature | None = None, docstatus=None, group_by=None, order_by=DefaultOrderBy, @@ -137,8 +137,18 @@ class DatabaseQuery: if as_list and not isinstance(self.fields, (Sequence | str)) and len(self.fields) > 1: frappe.throw(_("Fields must be a list or tuple when as_list is enabled")) - self.filters = filters or [] - self.or_filters = or_filters or [] + self.filters: Filters + self.or_filters: Filters + for k, _filters in { + "filters": filters or Filters(), + "or_filters": or_filters or Filters(), + }.items(): + if isinstance(_filters, str): + _filters = json.loads(_filters) + if not isinstance(_filters, Filters): + _filters = Filters(_filters, doctype=self.doctype) + setattr(self, k, _filters) + self.docstatus = docstatus or [] self.group_by = group_by self.order_by = order_by @@ -362,16 +372,6 @@ class DatabaseQuery: field = f"{field} as {alias}" self.fields[self.fields.index(original_field)] = field - for filter_name in ["filters", "or_filters"]: - filters = getattr(self, filter_name) - if isinstance(filters, str): - filters = json.loads(filters) - - if isinstance(filters, dict): - fdict = filters - filters = [make_filter_tuple(self.doctype, key, value) for key, value in fdict.items()] - setattr(self, filter_name, filters) - def sanitize_fields(self): """ regex : ^.*[,();].* @@ -554,27 +554,11 @@ class DatabaseQuery: def set_optional_columns(self): """Removes optional columns like `_user_tags`, `_comments` etc. if not in table""" - # remove from fields - to_remove = [] - for fld in self.fields: - to_remove.extend(fld for f in optional_fields if f in fld and f not in self.columns) - for fld in to_remove: - del self.fields[self.fields.index(fld)] - # remove from filters - to_remove = [] - for each in self.filters: - if isinstance(each, str): - each = [each] - - to_remove.extend( - each for element in each if element in optional_fields and element not in self.columns - ) - for each in to_remove: - if isinstance(self.filters, dict): - del self.filters[each] - else: - self.filters.remove(each) + self.fields[:] = [f for f in self.fields if f not in optional_fields or f in self.columns] + self.filters[:] = [ + f for f in self.filters if f.fieldname not in optional_fields or f.fieldname in self.columns + ] def build_conditions(self): self.conditions = [] @@ -588,19 +572,13 @@ class DatabaseQuery: if match_conditions: self.conditions.append(f"({match_conditions})") - def build_filter_conditions(self, filters, conditions: list, ignore_permissions=None): + def build_filter_conditions(self, filters: Filters, conditions: list, ignore_permissions=None): """build conditions from user filters""" if ignore_permissions is not None: self.flags.ignore_permissions = ignore_permissions - if isinstance(filters, dict): - filters = [filters] - for f in filters: - if isinstance(f, str): - conditions.append(f) - else: - conditions.append(self.prepare_filter_condition(f)) + conditions.append(self.prepare_filter_condition(f)) def remove_field(self, idx: int): if self.as_list: @@ -701,7 +679,7 @@ class DatabaseQuery: self.fields[i + j : i + j + 1] = permitted_fields j = j + len(permitted_fields) - 1 - def prepare_filter_condition(self, f): + def prepare_filter_condition(self, ft: FilterTuple) -> str: """Return a filter condition in the format: ifnull(`tabDocType`.`fieldname`, fallback) operator "value" @@ -712,7 +690,7 @@ class DatabaseQuery: from frappe.boot import get_additional_filters_from_hooks additional_filters_config = get_additional_filters_from_hooks() - f = get_filter(self.doctype, f, additional_filters_config) + f: FilterTuple = get_filter(self.doctype, ft, additional_filters_config) tname = "`tab" + f.doctype + "`" if tname not in self.tables: diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index df3cc1e27a..8904e51a38 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -1082,8 +1082,12 @@ class TestDBQuery(IntegrationTestCase): class VirtualDocType: @staticmethod def get_list(args=None, limit_page_length=0, doctype=None): + from frappe.types.filter import FilterTuple + # Backward compatibility - self.assertEqual(args["filters"], [["Virtual DocType", "name", "=", "test"]]) + self.assertEqual( + args["filters"], [FilterTuple(doctype="Virtual DocType", fieldname="name", value="test")] + ) self.assertEqual(limit_page_length, 1) self.assertEqual(doctype, "Virtual DocType") diff --git a/frappe/types/__init__.py b/frappe/types/__init__.py index de1873b02a..bb33bcc053 100644 --- a/frappe/types/__init__.py +++ b/frappe/types/__init__.py @@ -1,2 +1,3 @@ from .docref import DocRef +from .filter import Filters, FilterSignature, FilterTuple from .frappedict import _dict diff --git a/frappe/types/filter.py b/frappe/types/filter.py new file mode 100644 index 0000000000..6d6cf1b30b --- /dev/null +++ b/frappe/types/filter.py @@ -0,0 +1,279 @@ +import textwrap +from collections import defaultdict +from collections.abc import Generator, Iterable, Mapping, Sequence +from datetime import date, datetime +from itertools import groupby +from operator import attrgetter +from typing import Any, NamedTuple, TypeAlias, TypeGuard, TypeVar, cast + +from pypika import Column +from typing_extensions import Self, override + +from .docref import DocRef + +Doct: TypeAlias = str +Fld: TypeAlias = str +Op: TypeAlias = str +DateTime: TypeAlias = datetime | date +_Value: TypeAlias = str | int | float | None | DateTime | Column +_InputValue: TypeAlias = _Value | DocRef | bool +Value: TypeAlias = _Value | Sequence[_Value] +InputValue: TypeAlias = _InputValue | Sequence[_InputValue] + + +FilterTupleSpec: TypeAlias = ( + tuple[Fld, InputValue] | tuple[Fld, Op, InputValue] | tuple[Doct, Fld, Op, InputValue] +) +FilterMappingSpec: TypeAlias = Mapping[Fld, _InputValue | tuple[Op, InputValue]] + + +class Sentinel: + def __bool__(self) -> bool: + return False + + @override + def __str__(self) -> str: + return "UNSPECIFIED" + + +UNSPECIFIED = Sentinel() + +T = TypeVar("T") + + +def is_unspecified(value: T | Sentinel) -> TypeGuard[Sentinel]: + return value is UNSPECIFIED + + +class _FilterTuple(NamedTuple): + doctype: Doct + fieldname: Fld + operator: Op + value: Value + + +def _type_narrow(v: _InputValue) -> _Value: + if isinstance(v, bool): # beware: bool derives int in _Value + return int(v) + elif isinstance(v, _Value): + return v + elif isinstance(v, DocRef): # type: ignore[redundant-expr] + return v.__value__() + else: + raise ValueError( + f"Value must be one of types: {', '.join(str(t.__name__) for t in _InputValue.__args__)}; found {type(v)}" + ) + + +class FilterTuple(_FilterTuple): + """A named tuple representing a filter condition.""" + + def __new__( + cls, + s: FilterTupleSpec | None = None, + /, + *, + doctype: Doct | Sentinel = UNSPECIFIED, + fieldname: Fld | Sentinel = UNSPECIFIED, + operator: Op = "=", + value: InputValue | Sentinel = UNSPECIFIED, + ) -> Self: + """ + Create a new FilterTuple instance. + Args: + s: A sequence representing the filter tuple. + doctype: The document type. + fieldname: The field name. + operator: The comparison operator. + value: The value to compare against. + Returns: + A new FilterTuple instance. + """ + try: + if isinstance(s, Sequence): + if len(s) == 2: + fieldname, value = s + elif len(s) == 3: + fieldname, operator, value = s + elif len(s) == 4: # type: ignore[redundant-expr] + doctype, fieldname, operator, value = s + else: + raise ValueError(f"Invalid sequence length: {len(s)}. Expected 2, 3, or 4 elements.") + if is_unspecified(doctype) or doctype is None: + raise ValueError("doctype is required") + if is_unspecified(fieldname) or fieldname is None: + raise ValueError("fieldname is required") + if is_unspecified(value): + raise ValueError("value is required; can be None") + + # soundness + if operator in ("in", "not in") and isinstance(value, str): + value = value.split(",") + + _value: Value + if isinstance(value, _InputValue): + _value = _type_narrow(value) + else: + _value = tuple(_type_narrow(v) for v in value) + + return super().__new__( + cls, + doctype=doctype, + fieldname=fieldname, + operator=operator, + value=_value, + ) + + except Exception as e: + error_context = ( + f"Error creating FilterTuple:\n" + f"Input: {s}, doctype={doctype}, fieldname={fieldname}, operator={operator}, value={value}\n" + f"Error: {e!s}\n" + f"Usage: FilterTuple( (fieldname, value), doctype=dt )\n" + f" FilterTuple( (fieldname, operator, value), doctype=dt )\n" + f" FilterTuple( (doctype, fieldname, operator, value) )\n" + f" FilterTuple( doctype=doctype, fieldname=fieldname, operator=operator, value=value )" + ) + raise ValueError(error_context) from e + + @override + def __str__(self) -> str: + value_repr = f"'{self.value}'" if isinstance(self.value, str) else repr(self.value) + return f"<{self.doctype}>.{self.fieldname} {self.operator} {value_repr}" + + +class Filters(list[FilterTuple]): + """A sequence of filter tuples representing multiple filter conditions.""" + + def __init__( + self, + /, + *s: FilterTuple + | FilterTupleSpec + | FilterMappingSpec + | Sequence[FilterTuple | FilterTupleSpec | FilterMappingSpec], + doctype: Doct | Sentinel = UNSPECIFIED, + ) -> None: + """ + Create a new Filters instance. + + Args: + s: A sequence of FilterTuple or FilterTupleSpec, or a FilterMappingSpec. + doctype: The document type for the filters. + + Returns: + A new Filters instance. + """ + super().__init__() + try: + # only one argument + if len(s) == 1: + # and that is an empty sequence + if len(s[0]) == 0: + return + # compat: unpack if first argument was Sequence of Sequences + if ( + not isinstance(s[0], FilterTuple) + and isinstance(s[0], Sequence) + and not isinstance(s[0][0], str) # it's a FilterTupleSpec + and isinstance(s[0][0], Sequence | Mapping) + ): + self.extend( + cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s[0]), doctype + ) + else: + self.extend(cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s), doctype) + else: + self.extend(cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s), doctype) + except Exception as e: + error_lines = str(e).split("\n") + indented_error = error_lines[0] + "\n" + textwrap.indent("\n".join(error_lines[1:]), " ") + error_context = ( + f"\nError creating Filters:\n" + f"Input: {s}, doctype={doctype}\n" + f"Usage: Filters( FilterTuple(...), ... )\n" + f" Filters( (fieldnam, value), ... doctype=dt )\n" + f" Filters( (fieldname, operator, value), ... doctype=dt )\n" + f" Filters( (doctype, fieldname, operator, value), ... )\n" + f" Filters( {{'fieldname': value, ...}}, ... doctype=dt )\n" + f" Filters( {{'fieldname': (operator, value), ...}}, ... doctype=dt )\n\n" + f"Cause:\n{indented_error}" + ) + raise ValueError(error_context) from e + + if self: # only optimize non-empty; avoid infinit recursion + self.optimize() + + @override + def extend( + self, + values: Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], + doctype: Doct | Sentinel = UNSPECIFIED, + ) -> None: + for item in values: + self.append(item, doctype) + + @override + def append( + self, value: FilterTuple | FilterTupleSpec | FilterMappingSpec, doctype: Doct | Sentinel = UNSPECIFIED + ) -> None: + if isinstance(value, FilterTuple): + super().append(value) + elif isinstance(value, Mapping): + if is_unspecified(doctype) or doctype is None: + raise ValueError("When initiated with a mapping, doctype keyword argument is required") + self._init_from_mapping(value, doctype) + elif isinstance(value, Sequence): # type: ignore[redundant-expr] + super().append(FilterTuple(value, doctype=doctype)) + else: + raise TypeError(f"Expected FilterTruple, Mapping or Sequence, got {type(value).__name__}") + + def _init_from_mapping(self, s: FilterMappingSpec, doctype: Doct) -> None: + for k, v in s.items(): + if isinstance(v, _InputValue): + self.append(FilterTuple(doctype=doctype, fieldname=k, value=v)) + elif isinstance(v, Sequence): # type: ignore[redundant-expr] + self.append(FilterTuple(doctype=doctype, fieldname=k, operator=v[0], value=v[1])) + else: + raise ValueError(f"Invalid value for key '{k}': expected value or (operator, value[s]) tuple") + + def optimize(self) -> None: + """Optimize the filters by grouping '=' operators into 'in' operators where possible.""" + + def group_key(f: FilterTuple) -> tuple[str, str, bool]: + return (f.doctype, f.fieldname, f.operator == "=") + + optimized = Filters() + for (doctype, fieldname, collatable), filters in groupby(sorted(self, key=group_key), key=group_key): + if not collatable: + optimized.extend(filters) + else: + + def _values() -> Generator[_Value, None, None]: + for f in filters: + # f.value is already narrowed to Val when we optimize over fully initialized FilterTuple + yield cast(_Value, f.value) # = operator only is allowed to have _Value + + values = tuple(_values()) + + _op = "in" if len(values) > 1 else "=" + optimized.append( + FilterTuple( + doctype=doctype, + fieldname=fieldname, + operator=_op, + value=values if _op == "in" else values[0], + ) + ) + self[:] = optimized + + @override + def __str__(self) -> str: + if not self: + return "Filters()" + + filters_str = "\n".join(f" {filter}" for filter in self) + return f"Filters(\n{filters_str}\n)" + + +FilterSignature: TypeAlias = Filters | FilterTuple | FilterMappingSpec | FilterTupleSpec diff --git a/frappe/utils/data.py b/frappe/utils/data.py index c83f9bbc1b..c5d62029ad 100644 --- a/frappe/utils/data.py +++ b/frappe/utils/data.py @@ -25,6 +25,7 @@ from dateutil.relativedelta import relativedelta import frappe from frappe.desk.utils import slug from frappe.locale import get_date_format, get_first_day_of_the_week, get_number_format, get_time_format +from frappe.types.filter import Filters, FilterSignature, FilterTuple from frappe.utils.deprecations import deprecated from frappe.utils.number_format import NUMBER_FORMAT_MAP, NumberFormat @@ -52,6 +53,8 @@ TimespanOptions = Literal[ if typing.TYPE_CHECKING: + from collections.abc import Mapping + from PIL.ImageFile import ImageFile as PILImageFile T = TypeVar("T") @@ -1984,19 +1987,14 @@ operator_map = { } -def evaluate_filters(doc, filters: dict | list | tuple): +def evaluate_filters(doc: "Mapping", filters: FilterSignature): """Return True if doc matches filters.""" - if isinstance(filters, dict): - for key, value in filters.items(): - f = get_filter(None, {key: value}) - if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype): - return False - - elif isinstance(filters, list | tuple): - for d in filters: - f = get_filter(None, d) - if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype): - return False + if not isinstance(filters, Filters): + filters = Filters(filters, doctype=doc.get("doctype")) + for d in filters: + f = get_filter(None, d) + if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype): + return False return True @@ -2011,7 +2009,7 @@ def compare(val1: Any, condition: str, val2: Any, fieldtype: str | None = None): return False -def get_filter(doctype: str, f: dict | list | tuple, filters_config=None) -> "frappe._dict": +def get_filter(doctype: str, filters: FilterSignature, filters_config=None) -> "frappe._dict": """Return a `_dict` like: { @@ -2025,30 +2023,18 @@ def get_filter(doctype: str, f: dict | list | tuple, filters_config=None) -> "fr from frappe.database.utils import NestedSetHierarchy from frappe.model import child_table_fields, default_fields, optional_fields - if isinstance(f, dict): - key, value = next(iter(f.items())) - f = make_filter_tuple(doctype, key, value) + ft: FilterTuple + if isinstance(filters, FilterTuple): + ft = filters + elif not isinstance(filters, Filters): + ft = Filters(filters, doctype=doctype)[0] + else: + ft = filters[0] - if not isinstance(f, list | tuple): - frappe.throw(frappe._("Filter must be a tuple or list (in a list)")) - - if len(f) == 3: - f = (doctype, f[0], f[1], f[2]) - elif len(f) > 4: - f = f[0:4] - elif len(f) != 4: - frappe.throw( - frappe._("Filter must have 4 values (doctype, fieldname, operator, value): {0}").format(str(f)) - ) - - f = frappe._dict(doctype=f[0], fieldname=f[1], operator=f[2], value=f[3]) + f = frappe._dict(doctype=ft[0], fieldname=ft[1], operator=ft[2], value=ft[3]) sanitize_column(f.fieldname) - if not f.operator: - # if operator is missing - f.operator = "=" - valid_operators = ( "=", "!=", diff --git a/pyproject.toml b/pyproject.toml index dee710b240..0156e4a64d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ dev = [ "types-six", "types-vobject", "types-zxcvbn", + "pypika-stubs", # contributed ] test = [ "unittest-xml-reporting~=3.2.0", @@ -224,6 +225,7 @@ files = [ "frappe/types/DF.py", "frappe/types/docref.py", "frappe/types/frappedict.py", + "frappe/types/filter.py", ] exclude = [ # permanent excludes