diff --git a/frappe/database/query.py b/frappe/database/query.py index ad7b3f83b4..6297e297a4 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,9 +1,11 @@ import operator import re +from functools import cached_property from typing import Any, Dict, List, Tuple, Union import frappe from frappe import _ +from frappe.boot import get_additional_filters_from_hooks from frappe.model.db_query import get_timespan_date_range from frappe.query_builder import Criterion, Field, Order, Table @@ -138,6 +140,7 @@ def change_orderby(order: str): return order[0], Order.desc +# default operators OPERATOR_MAP: Dict[str, "function"] = { "+": operator.add, "=": operator.eq, @@ -157,12 +160,35 @@ OPERATOR_MAP: Dict[str, "function"] = { "between": func_between, "is": func_is, "timespan": func_timespan, + # TODO: Add support for nested set + # TODO: Add support for custom operators (WIP) - via filters_config hooks } class Query: tables: dict = {} + @cached_property + def OPERATOR_MAP(self): + # default operators + all_operators = OPERATOR_MAP.copy() + + # update with site-specific custom operators + additional_filters_config = get_additional_filters_from_hooks() + + if additional_filters_config: + from frappe.utils.commands import warn + + warn("'filters_config' hook is not completely implemented yet in frappe.db.query engine") + + for operator, function in additional_filters_config.items(): + if callable(function): + all_operators.update({operator.casefold(): function}) + elif isinstance(function, dict): + all_operators[operator.casefold()] = frappe.get_attr(function.get("get_field"))()["operator"] + + return all_operators + def get_condition(self, table: Union[str, Table], **kwargs) -> frappe.qb: """Get initial table object @@ -243,14 +269,14 @@ class Query: if isinstance(filters, list): for f in filters: if not isinstance(f, (list, tuple)): - _operator = OPERATOR_MAP[filters[1].casefold()] + _operator = self.OPERATOR_MAP[filters[1].casefold()] if not isinstance(filters[0], str): conditions = make_function(filters[0], filters[2]) break conditions = conditions.where(_operator(Field(filters[0]), filters[2])) break else: - _operator = OPERATOR_MAP[f[-2].casefold()] + _operator = self.OPERATOR_MAP[f[-2].casefold()] if len(f) == 4: table_object = self.get_table(f[0]) _field = table_object[f[1]] @@ -279,13 +305,13 @@ class Query: for key in filters: value = filters.get(key) - _operator = OPERATOR_MAP["="] + _operator = self.OPERATOR_MAP["="] if not isinstance(key, str): conditions = conditions.where(make_function(key, value)) continue if isinstance(value, (list, tuple)): - _operator = OPERATOR_MAP[value[0].casefold()] + _operator = self.OPERATOR_MAP[value[0].casefold()] conditions = conditions.where(_operator(Field(key), value[1])) else: if value is not None: diff --git a/frappe/utils/commands.py b/frappe/utils/commands.py index a610872f03..bbc09b3034 100644 --- a/frappe/utils/commands.py +++ b/frappe/utils/commands.py @@ -59,7 +59,7 @@ def log(message, colour=""): print(colour + message + end_line) -def warn(message, category=None): +def warn(message, category=None, stacklevel=2): from warnings import warn - warn(message=message, category=category, stacklevel=2) + warn(message=message, category=category, stacklevel=stacklevel)