From 726fcfdb796407e71c7a459a255b201417ef8c87 Mon Sep 17 00:00:00 2001 From: Faris Ansari Date: Sun, 25 Dec 2022 23:19:11 +0530 Subject: [PATCH] refactor: qb.engine - simplify - qb.engine.get_query -> qb.get_query - qb.engine.build_conditions -> qb.get_query --- frappe/__init__.py | 4 +- frappe/database/database.py | 22 +- frappe/database/query.py | 707 +++++++----------- .../desk/doctype/number_card/number_card.py | 2 +- frappe/desk/listview.py | 2 +- frappe/query_builder/__init__.py | 2 +- frappe/query_builder/functions.py | 2 +- frappe/query_builder/utils.py | 5 +- frappe/tests/test_db_query.py | 4 +- frappe/tests/test_query.py | 88 +-- frappe/utils/goal.py | 2 +- 11 files changed, 347 insertions(+), 493 deletions(-) diff --git a/frappe/__init__.py b/frappe/__init__.py index 2d491ca068..6b9d157003 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -23,7 +23,7 @@ import click from werkzeug.local import Local, release_local from frappe.query_builder import ( - get_qb_engine, + get_query, get_query_builder, patch_query_aggregation, patch_query_execute, @@ -244,7 +244,7 @@ def init(site: str, sites_path: str = ".", new_site: bool = False) -> None: local.session = _dict() local.dev_server = _dev_server local.qb = get_query_builder(local.conf.db_type or "mariadb") - local.qb.engine = get_qb_engine() + local.qb.get_query = get_query setup_module_map() if not _qb_patched.get(local.conf.db_type): diff --git a/frappe/database/database.py b/frappe/database/database.py index dfcc9dfe58..acbd28c9d7 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -620,7 +620,7 @@ class Database: return [map(values.get, fields)] else: - r = frappe.qb.engine.get_query( + r = frappe.qb.get_query( "Singles", filters={"field": ("in", tuple(fields)), "doctype": doctype}, fields=["field", "value"], @@ -653,7 +653,7 @@ class Database: # Get coulmn and value of the single doctype Accounts Settings account_settings = frappe.db.get_singles_dict("Accounts Settings") """ - queried_result = frappe.qb.engine.get_query( + queried_result = frappe.qb.get_query( "Singles", filters={"doctype": doctype}, fields=["field", "value"], @@ -726,7 +726,7 @@ class Database: if cache and fieldname in self.value_cache[doctype]: return self.value_cache[doctype][fieldname] - val = frappe.qb.engine.get_query( + val = frappe.qb.get_query( table="Singles", filters={"doctype": doctype, "field": fieldname}, fields="value", @@ -766,10 +766,10 @@ class Database: distinct=False, limit=None, ): - query = frappe.qb.engine.get_query( + query = frappe.qb.get_query( table=doctype, filters=filters, - orderby=order_by, + order_by=order_by, for_update=for_update, fields=fields, distinct=distinct, @@ -795,7 +795,7 @@ class Database: as_dict=False, ): if names := list(filter(None, names)): - return frappe.qb.engine.get_query( + return frappe.qb.get_query( doctype, fields=field, filters=names, @@ -852,7 +852,7 @@ class Database: frappe.clear_document_cache(dt, dt) else: - query = frappe.qb.engine.build_conditions(table=dt, filters=dn, update=True) + query = frappe.qb.get_query(table=dt, filters=dn, update=True) if isinstance(dn, str): frappe.clear_document_cache(dt, dn) @@ -1017,9 +1017,9 @@ class Database: cache_count = frappe.cache().get_value(f"doctype:count:{dt}") if cache_count is not None: return cache_count - count = frappe.qb.engine.get_query( - table=dt, filters=filters, fields=Count("*"), distinct=distinct - ).run(debug=debug)[0][0] + count = frappe.qb.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct).run( + debug=debug + )[0][0] if not filters and cache: frappe.cache().set_value(f"doctype:count:{dt}", count, expires_in_sec=86400) return count @@ -1160,7 +1160,7 @@ class Database: Doctype name can be passed directly, it will be pre-pended with `tab`. """ filters = filters or kwargs.get("conditions") - query = frappe.qb.engine.build_conditions(table=doctype, filters=filters).delete() + query = frappe.qb.get_query(table=doctype, filters=filters).delete() if "debug" not in kwargs: kwargs["debug"] = debug return query.run(**kwargs) diff --git a/frappe/database/query.py b/frappe/database/query.py index a9dab02744..88de7f7088 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable import sqlparse from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder +from pypika.queries import QueryBuilder import frappe from frappe import _ @@ -171,18 +172,16 @@ def table_from_string(table: str) -> "DocType": return frappe.qb.DocType(table_name=table_name) -def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str): - ref_doctype = table +def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str): + table = frappe.qb.DocType(doctype) try: - lft, rgt = ( - frappe.qb.from_(ref_doctype).select("lft", "rgt").where(Field("name") == field).run()[0] - ) + lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0] except IndexError: lft, rgt = None, None if hierarchy in ("descendants of", "not descendants of"): result = ( - frappe.qb.from_(ref_doctype) + frappe.qb.from_(table) .select(Field("name")) .where(Field("lft") > lft) .where(Field("rgt") < rgt) @@ -192,7 +191,7 @@ def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str): else: # Get ancestor elements of a DocType with a tree structure result = ( - frappe.qb.from_(ref_doctype) + frappe.qb.from_(table) .select(Field("name")) .where(Field("lft") < lft) .where(Field("rgt") > rgt) @@ -232,37 +231,67 @@ OPERATOR_MAP: dict[str, Callable] = { class Engine: tables: dict[str, str] = {} - @cached_property - def OPERATOR_MAP(self): - # default operators - all_operators = OPERATOR_MAP.copy() + def get_query( + self, + table: str, + fields: list | tuple | None = None, + filters: dict[str, str | int] | str | int | list[list | str | int] | None = None, + pluck: str | None = None, + order_by: str | None = None, + group_by: str | None = None, + limit: int | None = None, + offset: int | None = None, + distinct: bool = False, + for_update: bool = False, + update: bool = False, + into: bool = False, + ) -> MySQLQueryBuilder | PostgreSQLQueryBuilder: + # Clean up state before each query + self.is_mariadb = frappe.db.db_type == "mariadb" + self.is_postgres = frappe.db.db_type == "postgres" + self.tables = {} + self.implicit_joins = set() - # TODO: update with site-specific custom operators / removed previous buggy implementation - if frappe.get_hooks("filters_config"): - from frappe.utils.commands import warn + self.doctype = table + self.table = self.get_table(table) - warn( - "The 'filters_config' hook used to add custom operators is not yet implemented" - " in frappe.db.query engine. Use db_query (frappe.get_list) instead." - ) + if update: + self.query = frappe.qb.update(self.table) + elif into: + self.query = frappe.qb.into(self.table) + else: + self.query = frappe.qb.from_(self.table) - return all_operators + self.fields = self.parse_fields(fields) + if not self.fields: + self.fields = [getattr(self.table, pluck or "name")] - def get_condition(self, table: str | Table, **kwargs) -> frappe.qb: - """Get initial table object + for field in self.fields: + if isinstance(field, DynamicTableField): + self.query = field.apply(self.query) + else: + self.query = self.query.select(field) - Args: - table (str): DocType + self.apply_filters(filters) + self.apply_implicit_joins() + self.apply_order_by(order_by) - Returns: - frappe.qb: DocType with initial condition - """ - table_object = self.get_table(table) - if kwargs.get("update"): - return frappe.qb.update(table_object) - if kwargs.get("into"): - return frappe.qb.into(table_object) - return frappe.qb.from_(table_object) + if limit: + self.query = self.query.limit(limit) + + if offset: + self.query = self.query.offset(offset) + + if distinct: + self.query = self.query.distinct() + + if for_update: + self.query = self.query.for_update() + + if group_by: + self.query = self.query.groupby(group_by) + + return self.query def get_table(self, table_name: str | Table) -> Table: if isinstance(table_name, Table): @@ -272,178 +301,93 @@ class Engine: self.tables[table_name] = frappe.qb.DocType(table_name) return self.tables[table_name] - def criterion_query(self, table: str, criterion: Criterion, **kwargs) -> frappe.qb: - """Generate filters from Criterion objects - - Args: - table (str): DocType - criterion (Criterion): Filters - - Returns: - frappe.qb: condition object - """ - condition = self.add_conditions(self.get_condition(table, **kwargs), **kwargs) - return condition.where(criterion) - - def add_conditions(self, conditions: frappe.qb, **kwargs): - """Adding additional conditions - - Args: - conditions (frappe.qb): built conditions - - Returns: - conditions (frappe.qb): frappe.qb object - """ - if kwargs.get("orderby") and kwargs.get("orderby") != "KEEP_DEFAULT_ORDERING": - orderby = kwargs.get("orderby") - if isinstance(orderby, str) and len(orderby.split()) > 1: - for ordby in orderby.split(","): - if ordby := ordby.strip(): - orderby, order = change_orderby(ordby) - conditions = conditions.orderby(orderby, order=order) - else: - conditions = conditions.orderby(orderby, order=kwargs.get("order") or Order.desc) - - if kwargs.get("limit"): - conditions = conditions.limit(kwargs.get("limit")) - conditions = conditions.offset(kwargs.get("offset", 0)) - - if kwargs.get("distinct"): - conditions = conditions.distinct() - - if kwargs.get("for_update"): - conditions = conditions.for_update() - - if kwargs.get("groupby"): - conditions = conditions.groupby(kwargs.get("groupby")) - - return conditions - - def misc_query(self, table: str, filters: list | tuple = None, **kwargs): - """Build conditions using the given Lists or Tuple filters - - Args: - table (str): DocType - filters (Union[List, Tuple], optional): Filters. Defaults to None. - """ - conditions = self.get_condition(table, **kwargs) + def apply_filters( + self, filters: dict[str, str | int | list] | str | int | list[list] | None = None + ): if not filters: - return conditions - if isinstance(filters, list): - for f in filters: - if isinstance(f, (list, tuple)): - _operator = self.OPERATOR_MAP[f[-2].casefold()] - if len(f) == 4: - table_object = self.get_table(f[0]) - _field = table_object[f[1]] - else: - _field = Field(f[0]) - conditions = conditions.where(_operator(_field, f[-1])) - elif isinstance(f, dict): - conditions = self.dict_query(table, f, **kwargs) - else: - _operator = self.OPERATOR_MAP[filters[1].casefold()] - if not isinstance(filters[0], str): - conditions = self.make_function_for_filters(filters[0], filters[2]) - break - conditions = conditions.where(_operator(Field(filters[0]), filters[2])) - break + return - return self.add_conditions(conditions, **kwargs) - - def dict_query(self, table: str, filters: dict[str, str | int] = None, **kwargs) -> frappe.qb: - """Build conditions using the given dictionary filters - - Args: - table (str): DocType - filters (Dict[str, Union[str, int]], optional): Filters. Defaults to None. - - Returns: - frappe.qb: conditions object - """ - conditions = self.get_condition(table, **kwargs) - if isinstance(table, str): - table = frappe.qb.DocType(table) - if not filters: - conditions = self.add_conditions(conditions, **kwargs) - return conditions - - for key, value in filters.items(): - if isinstance(value, bool): - filters.update({key: str(int(value))}) - - filters = { - (self.get_function_object(k) if has_function(k) else k): v for k, v in filters.items() - } - for key in filters: - value = filters.get(key) - _operator = self.OPERATOR_MAP["="] - - if not isinstance(key, str): - conditions = conditions.where(self.make_function_for_filters(key, value)) - continue - # Nested set support - if isinstance(value, (list, tuple)): - if value[0] in self.OPERATOR_MAP["nested_set"]: - hierarchy, _field = value - result = get_nested_set_hierarchy_result(hierarchy, _field, table) - _operator = ( - self.OPERATOR_MAP["not in"] - if hierarchy in ("not ancestors of", "not descendants of") - else self.OPERATOR_MAP["in"] - ) - if result: - result = list(itertools.chain.from_iterable(result)) - conditions = conditions.where(_operator(getattr(table, key), result)) - else: - conditions = conditions.where(_operator(getattr(table, key), ("",))) - # Allow additional conditions - break - - _operator = self.OPERATOR_MAP[value[0].casefold()] - _value = value[1] if value[1] else ("",) - conditions = conditions.where(_operator(getattr(table, key), _value)) - else: - if value is not None: - conditions = conditions.where(_operator(getattr(table, key), value)) - else: - _table = conditions._from[0] - field = getattr(_table, key) - conditions = conditions.where(field.isnull()) - - return self.add_conditions(conditions, **kwargs) - - def build_conditions( - self, table: str, filters: dict[str, str | int] | str | int = None, **kwargs - ) -> frappe.qb: - """Build conditions for sql query - - Args: - filters (Union[Dict[str, Union[str, int]], str, int]): conditions in Dict - table (str): DocType - - Returns: - frappe.qb: frappe.qb conditions object - """ - if isinstance(filters, int) or isinstance(filters, str): + if isinstance(filters, (str, int)): filters = {"name": str(filters)} if isinstance(filters, Criterion): - criterion = self.criterion_query(table, filters, **kwargs) + self.query = self.query.where(filters) + + elif isinstance(filters, dict): + self.apply_dict_filters(filters) elif isinstance(filters, (list, tuple)): - criterion = self.misc_query(table, filters, **kwargs) + self.apply_list_filters(filters) + def apply_list_filters(self, filters: list[list]): + for filter in filters: + if len(filter) == 2: + field, value = filter + self._apply_filter(field, value) + elif len(filter) == 3: + field, operator, value = filter + self._apply_filter(field, value, operator) + elif len(filter) == 4: + doctype, field, operator, value = filter + self._apply_filter(field, value, operator, doctype) + + def apply_dict_filters(self, filters: dict[str, str | int | list]): + for key in filters: + value = filters.get(key) + self._apply_filter(key, value) + + def _apply_filter( + self, field: str, value: str | int | list | None, operator: str = "=", doctype: str | None = None + ): + _field = field + _value = value + _operator = operator + + if has_function(field): + _field = self.get_function_object(field) + elif not doctype or doctype == self.doctype: + _field = self.table[field] + elif doctype: + _field = self.get_table(doctype)[field] + + # keep track of implicit join if child table is referenced + if doctype and doctype != self.doctype: + meta = frappe.get_meta(doctype) + if meta.istable: + self.implicit_joins.add((doctype, "child")) + + if isinstance(_value, (str, int)): + _value = str(_value) + elif isinstance(_value, (list, tuple)): + _operator, _value = _value + elif isinstance(_value, bool): + _value = int(_value) + + if isinstance(_value, str) and has_function(_value): + _value = self.get_function_object(_value) + + # Nested set + if _operator in self.OPERATOR_MAP["nested_set"]: + hierarchy = _operator + docname = _value + result = get_nested_set_hierarchy_result(self.doctype, docname, hierarchy) + operator_fn = ( + self.OPERATOR_MAP["not in"] + if hierarchy in ("not ancestors of", "not descendants of") + else self.OPERATOR_MAP["in"] + ) + if result: + result = list(itertools.chain.from_iterable(result)) + self.query = self.query.where(operator_fn(_field, result)) + else: + self.query = self.query.where(operator_fn(_field, ("",))) + return + + operator_fn = self.OPERATOR_MAP[_operator.casefold()] + if _value is None and isinstance(_field, Field): + self.query = self.query.where(_field.isnull()) else: - criterion = self.dict_query(filters=filters, table=table, **kwargs) - - return criterion - - def make_function_for_filters(self, key, value: int | str): - value = list(value) - if isinstance(value[1], str) and has_function(value[1]): - value[1] = self.get_function_object(value[1]) - return OPERATOR_MAP[value[0].casefold()](key, value[1]) + self.query = self.query.where(operator_fn(_field, _value)) def get_function_object(self, field: str) -> "Function": """Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object""" @@ -495,84 +439,12 @@ class Engine: # Fall back for functions not present in `SqlFunctions`` return Function(func, *_args, alias=alias or None) - def function_objects_from_string(self, fields): - fields = list(map(lambda str: str.strip(), COMMA_PATTERN.split(fields))) - return self.function_objects_from_list(fields=fields) - - def function_objects_from_list(self, fields): - functions = [] - for field in fields: - field = field.casefold() if (isinstance(field, str) and "`" not in field) else field - if not issubclass(type(field), Criterion): - if any([f"{func}(" in field for func in SQL_FUNCTIONS]) or "(" in field: - functions.append(field) - - return [self.get_function_object(function) for function in functions] - - def remove_string_functions(self, fields, function_objects): - """Remove string functions from fields which have already been converted to function objects""" - - def _remove_string_aliasing(function, fields: list | str): - if function.alias: - to_replace = " as " + function.alias.casefold() - if to_replace in fields: - fields = fields.replace(to_replace, "") - elif " as " + f"`{function.alias.casefold()}" in fields: - fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "") - return fields - - for function in function_objects: - if isinstance(fields, str): - fields = _remove_string_aliasing(function, fields) - fields = BRACKETS_PATTERN.sub("", re.sub(function.name, "", fields, flags=re.IGNORECASE)) - # Check if only comma is left in fields after stripping functions. - if "," in fields and (len(fields.strip()) == 1): - fields = "" - else: - updated_fields = [] - for field in fields: - if isinstance(field, str): - field = _remove_string_aliasing(function, field) - substituted_string = ( - BRACKETS_PATTERN.sub("", field).strip().casefold() - if "`" not in field - else BRACKETS_PATTERN.sub("", field).strip() - ) - # This is done to avoid casefold of table name. - if substituted_string.casefold() == function.name.casefold(): - replaced_string = substituted_string.casefold().replace(function.name.casefold(), "") - else: - replaced_string = substituted_string.replace(function.name.casefold(), "") - updated_fields.append(replaced_string) - fields = [field for field in updated_fields if field] - return fields - - def get_fieldnames_from_child_table(self, doctype, fields): - # Hacky and flaky implementation of implicit joins. - # convert child_table.fieldname to `tabChild DocType`.`fieldname` - _fields = [] - for field in fields: - if "." in field and "tab" not in field: - alias = None - if " as " in field: - field, alias = field.split(" as ") - fieldname, linked_fieldname = field.split(".") - linked_doctype = frappe.get_meta(doctype).get_field(fieldname).options - - field = f"`tab{linked_doctype}`.`{linked_fieldname}`" - if alias: - field = f"{field} {alias}" - _fields.append(field) - return _fields - def sanitize_fields(self, fields: str | list | tuple): - is_mariadb = frappe.db.db_type == "mariadb" - def _sanitize_field(field: str): if not isinstance(field, str): return field stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") - if is_mariadb: + if self.is_mariadb: return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) return stripped_field @@ -583,174 +455,88 @@ class Engine: return fields - def get_list_fields(self, table: str, fields: list) -> list: - updated_fields = [] - if issubclass(type(fields), Criterion) or "*" in fields: - return fields - fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields) - for field in fields: - if not isinstance(field, Criterion) and field: - if " as " in field: - field, reference = field.split(" as ") - if "`" in field: - updated_fields.append(PseudoColumnMapper(f"{field} {reference}")) - else: - updated_fields.append(Field(field.strip()).as_(reference)) - elif "`" in str(field): - updated_fields.append(PseudoColumnMapper(field.strip())) - else: - updated_fields.append(Field(field)) - return updated_fields + def parse_string_field(self, field: str): + if field == "*": + return self.table.star + alias = None + if " as " in field: + field, alias = field.split(" as ") + if "`" in field: + if alias: + return PseudoColumnMapper(f"{field} {alias}") + return PseudoColumnMapper(field) + if alias: + return self.table[field].as_(alias) + return self.table[field] - def get_string_fields(self, fields: str) -> Field: - if fields == "*": - return fields - if "`" in fields: - fields = PseudoColumnMapper(fields) - if " as " in str(fields): - fields, reference = str(fields).split(" as ") - if "`" in str(fields): - fields = PseudoColumnMapper(f"{fields} {reference}") - else: - fields = Field(fields).as_(reference) - return fields - - def set_fields(self, table: str, fields, **kwargs) -> list: - fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" + def parse_fields(self, fields: str | list | tuple | None) -> list: + if not fields: + return [] fields = self.sanitize_fields(fields) - if isinstance(fields, list) and None in fields and Field not in fields: - return None - function_objects = [] - is_list = isinstance(fields, (list, tuple, set)) - if is_list and len(fields) == 1: - fields = fields[0] - is_list = False + if isinstance(fields, (list, tuple, set)) and None in fields and Field not in fields: + return [] - if is_list: - function_objects += self.function_objects_from_list(fields=fields) + if not isinstance(fields, (list, tuple)): + fields = [fields] - is_str = isinstance(fields, str) - if is_str: - fields = fields.casefold() if "`" not in fields else fields - function_objects += self.function_objects_from_string(fields=fields) - - fields = self.remove_string_functions(fields, function_objects) - - if is_str and "," in fields: - fields = [field.replace(" ", "") if "as" not in field else field for field in fields.split(",")] - is_list, is_str = True, False - - if is_str: - fields = self.get_string_fields(fields) - if not is_str and fields: - fields = self.get_list_fields(table, fields) - - # Need to check instance again since fields modified. - if not isinstance(fields, (list, tuple, set)): - fields = [fields] if fields else [] - - fields.extend(function_objects) - return fields - - def join_child_tables( - self, - criterion: Criterion, - join_type: str, - child_table: Table, - parent_table: Table, - ) -> Criterion: - if self.joined_tables.get(join_type) != child_table: - criterion = getattr(criterion, join_type)(child_table).on( - (child_table.parent == parent_table.name) - & (child_table.parenttype == TAB_PATTERN.sub("", parent_table._table_name)) - ) - self.joined_tables[join_type] = child_table - return criterion - - def join(self, criterion, fields, table, join_type): - """Handles all join operations on criterion objects""" - has_join = False - table_pattern = ( - re.compile(r"`\btab\w+") if frappe.db.db_type == "mariadb" else re.compile(r'"\btab\w+') - ) - - def _update_pypika_fields(field): - if not is_pypika_function_object(field): - field = field if isinstance(field, (str, PseudoColumnMapper)) else field.get_sql() - if not table_pattern.search(str(field)): - if isinstance(field, PseudoColumnMapper): - field = field.get_sql() - return getattr(frappe.qb.DocType(table), field) - else: - return field + def parse_field(field: str): + if has_function(field): + return self.get_function_object(field) + elif parsed := DynamicTableField.parse(field, self.doctype): + return parsed else: - field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args] - return field + return self.parse_string_field(field) - if not isinstance(fields, Criterion): - for field in fields: - # Only perform this bit if foreign doctype in fields - if ( - not is_pypika_function_object(field) - and (str(field).startswith('"tab') or str(field).startswith("`tab")) - and (f"`tab{table}`" not in str(field) and f'tab{table}"' not in str(field)) - ): - has_join = True - child_table = table_from_string(str(field)) - parent_table = frappe.qb.DocType(table) if not isinstance(table, Table) else table - criterion = self.join_child_tables( - criterion=criterion, - join_type=join_type, - child_table=child_table, - parent_table=parent_table, - ) + _fields = [] + for field in fields: + if isinstance(field, Criterion): + _fields.append(field) + elif isinstance(field, str): + if "," in field: + field = field.casefold() if "`" not in field else field + field_list = COMMA_PATTERN.split(field) + for field in field_list: + if _field := field.strip(): + _fields.append(parse_field(_field)) + else: + _fields.append(parse_field(field)) - if has_join: - fields = [_update_pypika_fields(field) for field in fields] + return _fields - if len(self.tables) > 1: - parent_table = self.tables[table] - child_tables = list(self.tables.values())[1:] - for child_table in child_tables: - criterion = self.join_child_tables( - criterion, - join_type=join_type, - child_table=child_table, - parent_table=parent_table, + def apply_implicit_joins(self): + for d in self.implicit_joins: + doctype, join_type = d + table = self.get_table(doctype) + if join_type == "child": + self.query = self.query.left_join(table).on( + (table.parent == self.table.name) & (table.parenttype == self.doctype) ) - return criterion, fields + def apply_order_by(self, order_by: str | None): + if not order_by or order_by == "KEEP_DEFAULT_ORDERING": + return + for declaration in order_by.split(","): + if _order_by := declaration.strip(): + parts = _order_by.split(" ") + order_field, order_direction = parts[0], parts[1] if len(parts) > 1 else "asc" + order_direction = Order.asc if order_direction.lower() == "asc" else Order.desc + self.query = self.query.orderby(order_field, order=order_direction) - def get_query( - self, - table: str, - fields: list | tuple, - filters: dict[str, str | int] | str | int | list[list | str | int] = None, - **kwargs, - ) -> MySQLQueryBuilder | PostgreSQLQueryBuilder: - # Clean up state before each query - self.tables = {} - self.joined_tables = {} - self.linked_doctype = None - self.fieldname = None + @cached_property + def OPERATOR_MAP(self): + # default operators + all_operators = OPERATOR_MAP.copy() - criterion = self.build_conditions(table, filters, **kwargs) - fields = self.set_fields(table, fields, **kwargs) - join_type = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" - criterion, fields = self.join( - criterion=criterion, fields=fields, table=table, join_type=join_type - ) + # TODO: update with site-specific custom operators / removed previous buggy implementation + if frappe.get_hooks("filters_config"): + from frappe.utils.commands import warn - if isinstance(fields, (list, tuple)): - query = criterion.select(*fields) + warn( + "The 'filters_config' hook used to add custom operators is not yet implemented" + " in frappe.db.query engine. Use db_query (frappe.get_list) instead." + ) - elif isinstance(fields, Criterion): - query = criterion.select(fields) - - else: - query = criterion.select(fields) - - return query + return all_operators class Permission: @@ -781,3 +567,80 @@ class Permission: @staticmethod def get_tables_from_query(query: str): return [table for table in WORDS_PATTERN.findall(query) if table.startswith("tab")] + + +class DynamicTableField: + def __init__( + self, + doctype: str, + fieldname: str, + parent_doctype: str, + alias: str | None = None, + ) -> None: + self.doctype = doctype + self.fieldname = fieldname + self.alias = alias + self.parent_doctype = parent_doctype + + def __str__(self) -> str: + table_name = f"`tab{self.doctype}`" + fieldname = f"`{self.fieldname}`" + if frappe.db.db_type == "postgres": + table_name = table_name.replace("`", '"') + fieldname = fieldname.replace("`", '"') + alias = f"AS {self.alias}" if self.alias else "" + return f"{table_name}.{fieldname} {alias}".strip() + + @staticmethod + def parse(field: str, doctype: str): + if "." in field: + alias = None + if " as " in field: + field, alias = field.split(" as ") + if field.startswith("`tab") or field.startswith('"tab'): + _, child_doctype, child_field = re.search(r'([`"])tab(.+?)\1.\1(.+)\1', field).groups() + if child_doctype == doctype: + return + return ChildTableField(child_doctype, child_field, doctype, alias=alias) + else: + linked_fieldname, fieldname = field.split(".") + linked_field = frappe.get_meta(doctype).get_field(linked_fieldname) + linked_doctype = linked_field.options + if linked_field.fieldtype == "Link": + return LinkTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias) + elif linked_field.fieldtype in frappe.model.table_fields: + return ChildTableField(linked_doctype, fieldname, doctype, alias=alias) + + def apply(self, query: QueryBuilder) -> QueryBuilder: + raise NotImplementedError + + +class ChildTableField(DynamicTableField): + def apply(self, query: QueryBuilder) -> QueryBuilder: + table = frappe.qb.DocType(self.doctype) + main_table = frappe.qb.DocType(self.parent_doctype) + if not query.is_joined(table): + query = query.left_join(table).on( + (table.parent == main_table.name) & (table.parenttype == self.parent_doctype) + ) + return query.select(getattr(table, self.fieldname).as_(self.alias or None)) + + +class LinkTableField(DynamicTableField): + def __init__( + self, + doctype: str, + fieldname: str, + parent_doctype: str, + link_fieldname: str, + alias: str | None = None, + ) -> None: + super().__init__(doctype, fieldname, parent_doctype, alias=alias) + self.link_fieldname = link_fieldname + + def apply(self, query: QueryBuilder) -> QueryBuilder: + table = frappe.qb.DocType(self.doctype) + main_table = frappe.qb.DocType(self.parent_doctype) + if not query.is_joined(table): + query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname)) + return query.select(getattr(table, self.fieldname).as_(self.alias or None)) diff --git a/frappe/desk/doctype/number_card/number_card.py b/frappe/desk/doctype/number_card/number_card.py index 8e808ff635..30ec99644a 100644 --- a/frappe/desk/doctype/number_card/number_card.py +++ b/frappe/desk/doctype/number_card/number_card.py @@ -200,7 +200,7 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters): if txt: search_conditions = [numberCard[field].like(f"%{txt}%") for field in searchfields] - condition_query = frappe.qb.engine.build_conditions(doctype, filters) + condition_query = frappe.qb.get_query(doctype, filters) return ( condition_query.select(numberCard.name, numberCard.label, numberCard.document_type) diff --git a/frappe/desk/listview.py b/frappe/desk/listview.py index ea6eb6259c..8b514444df 100644 --- a/frappe/desk/listview.py +++ b/frappe/desk/listview.py @@ -36,7 +36,7 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> list[d ToDo = DocType("ToDo") User = DocType("User") count = Count("*").as_("count") - filtered_records = frappe.qb.engine.build_conditions(doctype, current_filters).select("name") + filtered_records = frappe.qb.get_query(doctype, filters=current_filters).select("name") return ( frappe.qb.from_(ToDo) diff --git a/frappe/query_builder/__init__.py b/frappe/query_builder/__init__.py index eb1d9df08f..b1f242f78c 100644 --- a/frappe/query_builder/__init__.py +++ b/frappe/query_builder/__init__.py @@ -7,7 +7,7 @@ from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValue from frappe.query_builder.utils import ( Column, DocType, - get_qb_engine, + get_query, get_query_builder, patch_query_aggregation, patch_query_execute, diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index 24e2ee0e5f..b1e4e7eff1 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -103,7 +103,7 @@ class Cast_(Function): def _aggregate(function, dt, fieldname, filters, **kwargs): return ( - frappe.qb.engine.build_conditions(dt, filters) + frappe.qb.get_query(dt, filters=filters) .select(function(PseudoColumn(fieldname))) .run(**kwargs)[0][0] or 0 diff --git a/frappe/query_builder/utils.py b/frappe/query_builder/utils.py index be0403a291..f80dd4fc33 100644 --- a/frappe/query_builder/utils.py +++ b/frappe/query_builder/utils.py @@ -3,6 +3,7 @@ from importlib import import_module from typing import Any, Callable, get_type_hints from pypika import Query +from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.queries import Column from pypika.terms import PseudoColumn @@ -55,10 +56,10 @@ def get_query_builder(type_of_db: str) -> Postgres | MariaDB: return picks[db] -def get_qb_engine(): +def get_query(*args, **kwargs) -> MySQLQueryBuilder | PostgreSQLQueryBuilder: from frappe.database.query import Engine - return Engine() + return Engine().get_query(*args, **kwargs) def get_attr(method_string): diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index 162d3e9d8a..1a3e8735dc 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -229,9 +229,7 @@ class TestReportview(FrappeTestCase): ) def test_none_filter(self): - query = frappe.qb.engine.get_query( - "DocType", fields="name", filters={"restrict_to_domain": None} - ) + query = frappe.qb.get_query("DocType", fields="name", filters={"restrict_to_domain": None}) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 3f48882345..957d76d022 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -56,30 +56,28 @@ class TestQuery(FrappeTestCase): @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "DocType", ["*"], [ - ["BOM Update Log", "name", "like", "f%"], + ["DocField", "name", "like", "f%"], ["DocType", "parent", "=", "something"], ], ).get_sql(), - "SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` AND `tabBOM Update Log`.`parenttype`='DocType' WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", + "SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", ) @run_only_if(db_type_is.MARIADB) def test_string_fields(self): self.assertEqual( - frappe.qb.engine.get_query( - "User", fields="name, email", filters={"name": "Administrator"} - ).get_sql(), + frappe.qb.get_query("User", fields="name, email", filters={"name": "Administrator"}).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) .where(Field("name") == "Administrator") .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields=["`name`, `email`"], filters={"name": "Administrator"} ).get_sql(), frappe.qb.from_("User") @@ -89,7 +87,7 @@ class TestQuery(FrappeTestCase): ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields=["`tabUser`.`name`", "`tabUser`.`email`"], filters={"name": "Administrator"} ).run(), frappe.qb.from_("User") @@ -99,7 +97,7 @@ class TestQuery(FrappeTestCase): ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields=["`tabUser`.`name` as owner", "`tabUser`.`email`"], filters={"name": "Administrator"}, @@ -111,7 +109,7 @@ class TestQuery(FrappeTestCase): ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields=["`tabUser`.`name`, Count(`name`) as count"], filters={"name": "Administrator"} ).run(), frappe.qb.from_("User") @@ -121,7 +119,7 @@ class TestQuery(FrappeTestCase): ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields=["`tabUser`.`name`, Count(`name`) as `count`"], filters={"name": "Administrator"}, @@ -133,7 +131,7 @@ class TestQuery(FrappeTestCase): ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields="`tabUser`.`name`, Count(`name`) as `count`", filters={"name": "Administrator"} ).run(), frappe.qb.from_("User") @@ -144,38 +142,34 @@ class TestQuery(FrappeTestCase): def test_functions_fields(self): self.assertEqual( - frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), + frappe.qb.get_query("User", fields="Count(name)", filters={}).get_sql(), frappe.qb.from_("User").select(Count(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(), + frappe.qb.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(), frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( - "User", fields=["abs(name-email)", "Count(name)"], filters={} - ).get_sql(), + frappe.qb.get_query("User", fields=["abs(name-email)", "Count(name)"], filters={}).get_sql(), frappe.qb.from_("User") .select(Abs(Field("name") - Field("email")), Count(Field("name"))) .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(), + frappe.qb.get_query("User", fields=[Count("*")], filters={}).get_sql(), frappe.qb.from_("User").select(Count("*")).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( - "User", fields="timestamp(creation, modified)", filters={} - ).get_sql(), + frappe.qb.get_query("User", fields="timestamp(creation, modified)", filters={}).get_sql(), frappe.qb.from_("User").select(Timestamp(Field("creation"), Field("modified"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", fields="Count(name) as count, Max(email) as max_email", filters={} ).get_sql(), frappe.qb.from_("User") @@ -186,85 +180,83 @@ class TestQuery(FrappeTestCase): def test_qb_fields(self): user_doctype = frappe.qb.DocType("User") self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} ).get_sql(), frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(), + frappe.qb.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(), frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(), ) def test_aliasing(self): user_doctype = frappe.qb.DocType("User") self.assertEqual( - frappe.qb.engine.get_query( - user_doctype, fields=["name as owner", "email as id"], filters={} - ).get_sql(), + frappe.qb.get_query("User", fields=["name as owner", "email as id"], filters={}).get_sql(), frappe.qb.from_(user_doctype) .select(user_doctype.name.as_("owner"), user_doctype.email.as_("id")) .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( - user_doctype, fields="name as owner, email as id", filters={} - ).get_sql(), + frappe.qb.get_query(user_doctype, fields="name as owner, email as id", filters={}).get_sql(), frappe.qb.from_(user_doctype) .select(user_doctype.name.as_("owner"), user_doctype.email.as_("id")) .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( user_doctype, fields=["Count(name) as count", "email as id"], filters={} ).get_sql(), frappe.qb.from_(user_doctype) - .select(user_doctype.email.as_("id"), Count(Field("name")).as_("count")) + .select(Count(Field("name")).as_("count"), user_doctype.email.as_("id")) .get_sql(), ) @run_only_if(db_type_is.MARIADB) def test_filters(self): self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "User", filters={"IfNull(name, " ")": ("<", Now())}, fields=["Max(name)"] ).run(), frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(), ) def test_implicit_join_query(self): + self.maxDiff = None + self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "Note", filters={"name": "Test Note Title"}, fields=["name", "`tabNote Seen By`.`user` as seen_by"], ).get_sql(), - "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( + "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( "`", '"' if frappe.db.db_type == "postgres" else "`" ), ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "Note", filters={"name": "Test Note Title"}, fields=["name", "`tabNote Seen By`.`user` as seen_by", "`tabNote Seen By`.`idx` as idx"], ).get_sql(), - "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by,`tabNote Seen By`.`idx` idx FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( + "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by`,`tabNote Seen By`.`idx` `idx` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( "`", '"' if frappe.db.db_type == "postgres" else "`" ), ) self.assertEqual( - frappe.qb.engine.get_query( + frappe.qb.get_query( "Note", filters={"name": "Test Note Title"}, fields=["name", "seen_by.user as seen_by", "`tabNote Seen By`.`idx` as idx"], ).get_sql(), - "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by,`tabNote Seen By`.`idx` idx FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( + "SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by`,`tabNote Seen By`.`idx` `idx` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace( "`", '"' if frappe.db.db_type == "postgres" else "`" ), ) @@ -272,40 +264,40 @@ class TestQuery(FrappeTestCase): @run_only_if(db_type_is.MARIADB) def test_comment_stripping(self): self.assertNotIn( - "email", frappe.qb.engine.get_query("User", fields=["name", "#email"], filters={}).get_sql() + "email", frappe.qb.get_query("User", fields=["name", "#email"], filters={}).get_sql() ) def test_nestedset(self): frappe.db.sql("delete from `tabDocType` where `name` = 'Test Tree DocType'") frappe.db.sql_ddl("drop table if exists `tabTest Tree DocType`") create_tree_docs() - descendants_result = frappe.qb.engine.get_query( + descendants_result = frappe.qb.get_query( "Test Tree DocType", fields=["name"], filters={"name": ("descendants of", "Parent 1")}, - orderby="modified", + order_by="modified", ).run(as_list=1) # Format decendants result descendants_result = list(itertools.chain.from_iterable(descendants_result)) self.assertListEqual(descendants_result, get_descendants_of("Test Tree DocType", "Parent 1")) - ancestors_result = frappe.qb.engine.get_query( + ancestors_result = frappe.qb.get_query( "Test Tree DocType", fields=["name"], filters={"name": ("ancestors of", "Child 2")}, - orderby="modified", + order_by="modified", ).run(as_list=1) # Format ancestors result ancestors_result = list(itertools.chain.from_iterable(ancestors_result)) self.assertListEqual(ancestors_result, get_ancestors_of("Test Tree DocType", "Child 2")) - not_descendants_result = frappe.qb.engine.get_query( + not_descendants_result = frappe.qb.get_query( "Test Tree DocType", fields=["name"], filters={"name": ("not descendants of", "Parent 1")}, - orderby="modified", + order_by="modified", ).run(as_dict=1) self.assertListEqual( @@ -317,11 +309,11 @@ class TestQuery(FrappeTestCase): ), ) - not_ancestors_result = frappe.qb.engine.get_query( + not_ancestors_result = frappe.qb.get_query( "Test Tree DocType", fields=["name"], filters={"name": ("not ancestors of", "Child 2")}, - orderby="modified", + order_by="modified", ).run(as_dict=1) self.assertListEqual( diff --git a/frappe/utils/goal.py b/frappe/utils/goal.py index 13c4633031..0dcadc5ec6 100644 --- a/frappe/utils/goal.py +++ b/frappe/utils/goal.py @@ -24,7 +24,7 @@ def get_monthly_results( date_format = "%m-%Y" if frappe.db.db_type != "postgres" else "MM-YYYY" return dict( - frappe.qb.engine.build_conditions(table=goal_doctype, filters=filters) + frappe.qb.get_query(table=goal_doctype, filters=filters) .select( DateFormat(Table[date_col], date_format).as_("month_year"), Function(aggregation, goal_field),