refactor(typing): type filters (#28218)

* chore(typing): type filters

* chore(typing): type filters for get_list et al

* fix: dashboard chart filter expression

* test: fix case with new-style right hand object to equality check

* chore: place new typed filter under typing verification

* chore: remove debug print statment

* chore: inverse logic of type guard

* fix: add float to filter value types

* chore: clarify value naming
This commit is contained in:
David Arnold 2024-12-05 00:18:53 +01:00 committed by GitHub
parent d3cbd2d4be
commit 75377aaaf5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 343 additions and 87 deletions

View file

@ -53,7 +53,7 @@ from frappe.utils.data import cint, cstr, sbool
# Local application imports
from .exceptions import *
from .types.frappedict import _dict
from .types import Filters, FilterSignature, FilterTuple, _dict
from .utils.jinja import (
get_email_from_template,
get_jenv,
@ -1383,7 +1383,13 @@ def get_doc(*args: Any, **kwargs: Any) -> "Document":
return doc
def get_last_doc(doctype, filters=None, order_by="creation desc", *, for_update=False):
def get_last_doc(
doctype,
filters: FilterSignature | None = None,
order_by="creation desc",
*,
for_update=False,
):
"""Get last created document of this type."""
d = get_all(doctype, filters=filters, limit_page_length=1, order_by=order_by, pluck="name")
if d:

View file

@ -125,7 +125,7 @@ def get(
filters = []
# don't include cancelled documents
filters.append([chart.document_type, "docstatus", "<", 2, False])
filters.append([chart.document_type, "docstatus", "<", 2])
if chart.chart_type == "Group By":
chart_config = get_group_by_chart_config(chart, filters)
@ -196,8 +196,8 @@ def get_chart_config(chart, filters, timespan, timegrain, from_date, to_date):
from_date = from_date.strftime("%Y-%m-%d")
to_date = to_date
filters.append([doctype, datefield, ">=", from_date, False])
filters.append([doctype, datefield, "<=", to_date, False])
filters.append([doctype, datefield, ">=", from_date])
filters.append([doctype, datefield, "<=", to_date])
data = frappe.get_list(
doctype,
@ -231,8 +231,8 @@ def get_heatmap_chart_config(chart, filters, heatmap_year):
year_start_date = datetime.date(year, 1, 1).strftime("%Y-%m-%d")
next_year_start_date = datetime.date(year + 1, 1, 1).strftime("%Y-%m-%d")
filters.append([doctype, datefield, ">", f"{year_start_date}", False])
filters.append([doctype, datefield, "<", f"{next_year_start_date}", False])
filters.append([doctype, datefield, ">", f"{year_start_date}"])
filters.append([doctype, datefield, "<", f"{next_year_start_date}"])
if frappe.db.db_type == "mariadb":
timestamp_field = f"unix_timestamp({datefield})"

View file

@ -7,7 +7,7 @@ import datetime
import json
import re
from collections import Counter
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
import frappe
import frappe.defaults
@ -21,6 +21,7 @@ from frappe.model.meta import get_table_columns
from frappe.model.utils import is_virtual_doctype
from frappe.model.utils.user_settings import get_user_settings, update_user_settings
from frappe.query_builder.utils import Column
from frappe.types import Filters, FilterSignature, FilterTuple
from frappe.utils import (
cint,
cstr,
@ -28,7 +29,6 @@ from frappe.utils import (
get_filter,
get_time,
get_timespan_date_range,
make_filter_tuple,
)
from frappe.utils.data import DateTimeLikeObject, get_datetime, getdate, sbool
@ -79,8 +79,8 @@ class DatabaseQuery:
def execute(
self,
fields=None,
filters=None,
or_filters=None,
filters: FilterSignature | str | None = None,
or_filters: FilterSignature | None = None,
docstatus=None,
group_by=None,
order_by=DefaultOrderBy,
@ -137,8 +137,18 @@ class DatabaseQuery:
if as_list and not isinstance(self.fields, (Sequence | str)) and len(self.fields) > 1:
frappe.throw(_("Fields must be a list or tuple when as_list is enabled"))
self.filters = filters or []
self.or_filters = or_filters or []
self.filters: Filters
self.or_filters: Filters
for k, _filters in {
"filters": filters or Filters(),
"or_filters": or_filters or Filters(),
}.items():
if isinstance(_filters, str):
_filters = json.loads(_filters)
if not isinstance(_filters, Filters):
_filters = Filters(_filters, doctype=self.doctype)
setattr(self, k, _filters)
self.docstatus = docstatus or []
self.group_by = group_by
self.order_by = order_by
@ -362,16 +372,6 @@ class DatabaseQuery:
field = f"{field} as {alias}"
self.fields[self.fields.index(original_field)] = field
for filter_name in ["filters", "or_filters"]:
filters = getattr(self, filter_name)
if isinstance(filters, str):
filters = json.loads(filters)
if isinstance(filters, dict):
fdict = filters
filters = [make_filter_tuple(self.doctype, key, value) for key, value in fdict.items()]
setattr(self, filter_name, filters)
def sanitize_fields(self):
"""
regex : ^.*[,();].*
@ -554,27 +554,11 @@ class DatabaseQuery:
def set_optional_columns(self):
"""Removes optional columns like `_user_tags`, `_comments` etc. if not in table"""
# remove from fields
to_remove = []
for fld in self.fields:
to_remove.extend(fld for f in optional_fields if f in fld and f not in self.columns)
for fld in to_remove:
del self.fields[self.fields.index(fld)]
# remove from filters
to_remove = []
for each in self.filters:
if isinstance(each, str):
each = [each]
to_remove.extend(
each for element in each if element in optional_fields and element not in self.columns
)
for each in to_remove:
if isinstance(self.filters, dict):
del self.filters[each]
else:
self.filters.remove(each)
self.fields[:] = [f for f in self.fields if f not in optional_fields or f in self.columns]
self.filters[:] = [
f for f in self.filters if f.fieldname not in optional_fields or f.fieldname in self.columns
]
def build_conditions(self):
self.conditions = []
@ -588,19 +572,13 @@ class DatabaseQuery:
if match_conditions:
self.conditions.append(f"({match_conditions})")
def build_filter_conditions(self, filters, conditions: list, ignore_permissions=None):
def build_filter_conditions(self, filters: Filters, conditions: list, ignore_permissions=None):
"""build conditions from user filters"""
if ignore_permissions is not None:
self.flags.ignore_permissions = ignore_permissions
if isinstance(filters, dict):
filters = [filters]
for f in filters:
if isinstance(f, str):
conditions.append(f)
else:
conditions.append(self.prepare_filter_condition(f))
conditions.append(self.prepare_filter_condition(f))
def remove_field(self, idx: int):
if self.as_list:
@ -701,7 +679,7 @@ class DatabaseQuery:
self.fields[i + j : i + j + 1] = permitted_fields
j = j + len(permitted_fields) - 1
def prepare_filter_condition(self, f):
def prepare_filter_condition(self, ft: FilterTuple) -> str:
"""Return a filter condition in the format:
ifnull(`tabDocType`.`fieldname`, fallback) operator "value"
@ -712,7 +690,7 @@ class DatabaseQuery:
from frappe.boot import get_additional_filters_from_hooks
additional_filters_config = get_additional_filters_from_hooks()
f = get_filter(self.doctype, f, additional_filters_config)
f: FilterTuple = get_filter(self.doctype, ft, additional_filters_config)
tname = "`tab" + f.doctype + "`"
if tname not in self.tables:

View file

@ -1082,8 +1082,12 @@ class TestDBQuery(IntegrationTestCase):
class VirtualDocType:
@staticmethod
def get_list(args=None, limit_page_length=0, doctype=None):
from frappe.types.filter import FilterTuple
# Backward compatibility
self.assertEqual(args["filters"], [["Virtual DocType", "name", "=", "test"]])
self.assertEqual(
args["filters"], [FilterTuple(doctype="Virtual DocType", fieldname="name", value="test")]
)
self.assertEqual(limit_page_length, 1)
self.assertEqual(doctype, "Virtual DocType")

View file

@ -1,2 +1,3 @@
from .docref import DocRef
from .filter import Filters, FilterSignature, FilterTuple
from .frappedict import _dict

279
frappe/types/filter.py Normal file
View file

@ -0,0 +1,279 @@
import textwrap
from collections import defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence
from datetime import date, datetime
from itertools import groupby
from operator import attrgetter
from typing import Any, NamedTuple, TypeAlias, TypeGuard, TypeVar, cast
from pypika import Column
from typing_extensions import Self, override
from .docref import DocRef
Doct: TypeAlias = str
Fld: TypeAlias = str
Op: TypeAlias = str
DateTime: TypeAlias = datetime | date
_Value: TypeAlias = str | int | float | None | DateTime | Column
_InputValue: TypeAlias = _Value | DocRef | bool
Value: TypeAlias = _Value | Sequence[_Value]
InputValue: TypeAlias = _InputValue | Sequence[_InputValue]
FilterTupleSpec: TypeAlias = (
tuple[Fld, InputValue] | tuple[Fld, Op, InputValue] | tuple[Doct, Fld, Op, InputValue]
)
FilterMappingSpec: TypeAlias = Mapping[Fld, _InputValue | tuple[Op, InputValue]]
class Sentinel:
def __bool__(self) -> bool:
return False
@override
def __str__(self) -> str:
return "UNSPECIFIED"
UNSPECIFIED = Sentinel()
T = TypeVar("T")
def is_unspecified(value: T | Sentinel) -> TypeGuard[Sentinel]:
return value is UNSPECIFIED
class _FilterTuple(NamedTuple):
doctype: Doct
fieldname: Fld
operator: Op
value: Value
def _type_narrow(v: _InputValue) -> _Value:
if isinstance(v, bool): # beware: bool derives int in _Value
return int(v)
elif isinstance(v, _Value):
return v
elif isinstance(v, DocRef): # type: ignore[redundant-expr]
return v.__value__()
else:
raise ValueError(
f"Value must be one of types: {', '.join(str(t.__name__) for t in _InputValue.__args__)}; found {type(v)}"
)
class FilterTuple(_FilterTuple):
"""A named tuple representing a filter condition."""
def __new__(
cls,
s: FilterTupleSpec | None = None,
/,
*,
doctype: Doct | Sentinel = UNSPECIFIED,
fieldname: Fld | Sentinel = UNSPECIFIED,
operator: Op = "=",
value: InputValue | Sentinel = UNSPECIFIED,
) -> Self:
"""
Create a new FilterTuple instance.
Args:
s: A sequence representing the filter tuple.
doctype: The document type.
fieldname: The field name.
operator: The comparison operator.
value: The value to compare against.
Returns:
A new FilterTuple instance.
"""
try:
if isinstance(s, Sequence):
if len(s) == 2:
fieldname, value = s
elif len(s) == 3:
fieldname, operator, value = s
elif len(s) == 4: # type: ignore[redundant-expr]
doctype, fieldname, operator, value = s
else:
raise ValueError(f"Invalid sequence length: {len(s)}. Expected 2, 3, or 4 elements.")
if is_unspecified(doctype) or doctype is None:
raise ValueError("doctype is required")
if is_unspecified(fieldname) or fieldname is None:
raise ValueError("fieldname is required")
if is_unspecified(value):
raise ValueError("value is required; can be None")
# soundness
if operator in ("in", "not in") and isinstance(value, str):
value = value.split(",")
_value: Value
if isinstance(value, _InputValue):
_value = _type_narrow(value)
else:
_value = tuple(_type_narrow(v) for v in value)
return super().__new__(
cls,
doctype=doctype,
fieldname=fieldname,
operator=operator,
value=_value,
)
except Exception as e:
error_context = (
f"Error creating FilterTuple:\n"
f"Input: {s}, doctype={doctype}, fieldname={fieldname}, operator={operator}, value={value}\n"
f"Error: {e!s}\n"
f"Usage: FilterTuple( (fieldname, value), doctype=dt )\n"
f" FilterTuple( (fieldname, operator, value), doctype=dt )\n"
f" FilterTuple( (doctype, fieldname, operator, value) )\n"
f" FilterTuple( doctype=doctype, fieldname=fieldname, operator=operator, value=value )"
)
raise ValueError(error_context) from e
@override
def __str__(self) -> str:
value_repr = f"'{self.value}'" if isinstance(self.value, str) else repr(self.value)
return f"<{self.doctype}>.{self.fieldname} {self.operator} {value_repr}"
class Filters(list[FilterTuple]):
"""A sequence of filter tuples representing multiple filter conditions."""
def __init__(
self,
/,
*s: FilterTuple
| FilterTupleSpec
| FilterMappingSpec
| Sequence[FilterTuple | FilterTupleSpec | FilterMappingSpec],
doctype: Doct | Sentinel = UNSPECIFIED,
) -> None:
"""
Create a new Filters instance.
Args:
s: A sequence of FilterTuple or FilterTupleSpec, or a FilterMappingSpec.
doctype: The document type for the filters.
Returns:
A new Filters instance.
"""
super().__init__()
try:
# only one argument
if len(s) == 1:
# and that is an empty sequence
if len(s[0]) == 0:
return
# compat: unpack if first argument was Sequence of Sequences
if (
not isinstance(s[0], FilterTuple)
and isinstance(s[0], Sequence)
and not isinstance(s[0][0], str) # it's a FilterTupleSpec
and isinstance(s[0][0], Sequence | Mapping)
):
self.extend(
cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s[0]), doctype
)
else:
self.extend(cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s), doctype)
else:
self.extend(cast(Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec], s), doctype)
except Exception as e:
error_lines = str(e).split("\n")
indented_error = error_lines[0] + "\n" + textwrap.indent("\n".join(error_lines[1:]), " ")
error_context = (
f"\nError creating Filters:\n"
f"Input: {s}, doctype={doctype}\n"
f"Usage: Filters( FilterTuple(...), ... )\n"
f" Filters( (fieldnam, value), ... doctype=dt )\n"
f" Filters( (fieldname, operator, value), ... doctype=dt )\n"
f" Filters( (doctype, fieldname, operator, value), ... )\n"
f" Filters( {{'fieldname': value, ...}}, ... doctype=dt )\n"
f" Filters( {{'fieldname': (operator, value), ...}}, ... doctype=dt )\n\n"
f"Cause:\n{indented_error}"
)
raise ValueError(error_context) from e
if self: # only optimize non-empty; avoid infinit recursion
self.optimize()
@override
def extend(
self,
values: Iterable[FilterTuple | FilterTupleSpec | FilterMappingSpec],
doctype: Doct | Sentinel = UNSPECIFIED,
) -> None:
for item in values:
self.append(item, doctype)
@override
def append(
self, value: FilterTuple | FilterTupleSpec | FilterMappingSpec, doctype: Doct | Sentinel = UNSPECIFIED
) -> None:
if isinstance(value, FilterTuple):
super().append(value)
elif isinstance(value, Mapping):
if is_unspecified(doctype) or doctype is None:
raise ValueError("When initiated with a mapping, doctype keyword argument is required")
self._init_from_mapping(value, doctype)
elif isinstance(value, Sequence): # type: ignore[redundant-expr]
super().append(FilterTuple(value, doctype=doctype))
else:
raise TypeError(f"Expected FilterTruple, Mapping or Sequence, got {type(value).__name__}")
def _init_from_mapping(self, s: FilterMappingSpec, doctype: Doct) -> None:
for k, v in s.items():
if isinstance(v, _InputValue):
self.append(FilterTuple(doctype=doctype, fieldname=k, value=v))
elif isinstance(v, Sequence): # type: ignore[redundant-expr]
self.append(FilterTuple(doctype=doctype, fieldname=k, operator=v[0], value=v[1]))
else:
raise ValueError(f"Invalid value for key '{k}': expected value or (operator, value[s]) tuple")
def optimize(self) -> None:
"""Optimize the filters by grouping '=' operators into 'in' operators where possible."""
def group_key(f: FilterTuple) -> tuple[str, str, bool]:
return (f.doctype, f.fieldname, f.operator == "=")
optimized = Filters()
for (doctype, fieldname, collatable), filters in groupby(sorted(self, key=group_key), key=group_key):
if not collatable:
optimized.extend(filters)
else:
def _values() -> Generator[_Value, None, None]:
for f in filters:
# f.value is already narrowed to Val when we optimize over fully initialized FilterTuple
yield cast(_Value, f.value) # = operator only is allowed to have _Value
values = tuple(_values())
_op = "in" if len(values) > 1 else "="
optimized.append(
FilterTuple(
doctype=doctype,
fieldname=fieldname,
operator=_op,
value=values if _op == "in" else values[0],
)
)
self[:] = optimized
@override
def __str__(self) -> str:
if not self:
return "Filters()"
filters_str = "\n".join(f" {filter}" for filter in self)
return f"Filters(\n{filters_str}\n)"
FilterSignature: TypeAlias = Filters | FilterTuple | FilterMappingSpec | FilterTupleSpec

View file

@ -25,6 +25,7 @@ from dateutil.relativedelta import relativedelta
import frappe
from frappe.desk.utils import slug
from frappe.locale import get_date_format, get_first_day_of_the_week, get_number_format, get_time_format
from frappe.types.filter import Filters, FilterSignature, FilterTuple
from frappe.utils.deprecations import deprecated
from frappe.utils.number_format import NUMBER_FORMAT_MAP, NumberFormat
@ -52,6 +53,8 @@ TimespanOptions = Literal[
if typing.TYPE_CHECKING:
from collections.abc import Mapping
from PIL.ImageFile import ImageFile as PILImageFile
T = TypeVar("T")
@ -1984,19 +1987,14 @@ operator_map = {
}
def evaluate_filters(doc, filters: dict | list | tuple):
def evaluate_filters(doc: "Mapping", filters: FilterSignature):
"""Return True if doc matches filters."""
if isinstance(filters, dict):
for key, value in filters.items():
f = get_filter(None, {key: value})
if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype):
return False
elif isinstance(filters, list | tuple):
for d in filters:
f = get_filter(None, d)
if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype):
return False
if not isinstance(filters, Filters):
filters = Filters(filters, doctype=doc.get("doctype"))
for d in filters:
f = get_filter(None, d)
if not compare(doc.get(f.fieldname), f.operator, f.value, f.fieldtype):
return False
return True
@ -2011,7 +2009,7 @@ def compare(val1: Any, condition: str, val2: Any, fieldtype: str | None = None):
return False
def get_filter(doctype: str, f: dict | list | tuple, filters_config=None) -> "frappe._dict":
def get_filter(doctype: str, filters: FilterSignature, filters_config=None) -> "frappe._dict":
"""Return a `_dict` like:
{
@ -2025,30 +2023,18 @@ def get_filter(doctype: str, f: dict | list | tuple, filters_config=None) -> "fr
from frappe.database.utils import NestedSetHierarchy
from frappe.model import child_table_fields, default_fields, optional_fields
if isinstance(f, dict):
key, value = next(iter(f.items()))
f = make_filter_tuple(doctype, key, value)
ft: FilterTuple
if isinstance(filters, FilterTuple):
ft = filters
elif not isinstance(filters, Filters):
ft = Filters(filters, doctype=doctype)[0]
else:
ft = filters[0]
if not isinstance(f, list | tuple):
frappe.throw(frappe._("Filter must be a tuple or list (in a list)"))
if len(f) == 3:
f = (doctype, f[0], f[1], f[2])
elif len(f) > 4:
f = f[0:4]
elif len(f) != 4:
frappe.throw(
frappe._("Filter must have 4 values (doctype, fieldname, operator, value): {0}").format(str(f))
)
f = frappe._dict(doctype=f[0], fieldname=f[1], operator=f[2], value=f[3])
f = frappe._dict(doctype=ft[0], fieldname=ft[1], operator=ft[2], value=ft[3])
sanitize_column(f.fieldname)
if not f.operator:
# if operator is missing
f.operator = "="
valid_operators = (
"=",
"!=",

View file

@ -124,6 +124,7 @@ dev = [
"types-six",
"types-vobject",
"types-zxcvbn",
"pypika-stubs", # contributed
]
test = [
"unittest-xml-reporting~=3.2.0",
@ -224,6 +225,7 @@ files = [
"frappe/types/DF.py",
"frappe/types/docref.py",
"frappe/types/frappedict.py",
"frappe/types/filter.py",
]
exclude = [
# permanent excludes