refactor: qb.engine

- simplify
- qb.engine.get_query -> qb.get_query
- qb.engine.build_conditions -> qb.get_query
This commit is contained in:
Faris Ansari 2022-12-25 23:19:11 +05:30
parent 74bce62c62
commit 726fcfdb79
11 changed files with 347 additions and 493 deletions

View file

@ -23,7 +23,7 @@ import click
from werkzeug.local import Local, release_local
from frappe.query_builder import (
get_qb_engine,
get_query,
get_query_builder,
patch_query_aggregation,
patch_query_execute,
@ -244,7 +244,7 @@ def init(site: str, sites_path: str = ".", new_site: bool = False) -> None:
local.session = _dict()
local.dev_server = _dev_server
local.qb = get_query_builder(local.conf.db_type or "mariadb")
local.qb.engine = get_qb_engine()
local.qb.get_query = get_query
setup_module_map()
if not _qb_patched.get(local.conf.db_type):

View file

@ -620,7 +620,7 @@ class Database:
return [map(values.get, fields)]
else:
r = frappe.qb.engine.get_query(
r = frappe.qb.get_query(
"Singles",
filters={"field": ("in", tuple(fields)), "doctype": doctype},
fields=["field", "value"],
@ -653,7 +653,7 @@ class Database:
# Get coulmn and value of the single doctype Accounts Settings
account_settings = frappe.db.get_singles_dict("Accounts Settings")
"""
queried_result = frappe.qb.engine.get_query(
queried_result = frappe.qb.get_query(
"Singles",
filters={"doctype": doctype},
fields=["field", "value"],
@ -726,7 +726,7 @@ class Database:
if cache and fieldname in self.value_cache[doctype]:
return self.value_cache[doctype][fieldname]
val = frappe.qb.engine.get_query(
val = frappe.qb.get_query(
table="Singles",
filters={"doctype": doctype, "field": fieldname},
fields="value",
@ -766,10 +766,10 @@ class Database:
distinct=False,
limit=None,
):
query = frappe.qb.engine.get_query(
query = frappe.qb.get_query(
table=doctype,
filters=filters,
orderby=order_by,
order_by=order_by,
for_update=for_update,
fields=fields,
distinct=distinct,
@ -795,7 +795,7 @@ class Database:
as_dict=False,
):
if names := list(filter(None, names)):
return frappe.qb.engine.get_query(
return frappe.qb.get_query(
doctype,
fields=field,
filters=names,
@ -852,7 +852,7 @@ class Database:
frappe.clear_document_cache(dt, dt)
else:
query = frappe.qb.engine.build_conditions(table=dt, filters=dn, update=True)
query = frappe.qb.get_query(table=dt, filters=dn, update=True)
if isinstance(dn, str):
frappe.clear_document_cache(dt, dn)
@ -1017,9 +1017,9 @@ class Database:
cache_count = frappe.cache().get_value(f"doctype:count:{dt}")
if cache_count is not None:
return cache_count
count = frappe.qb.engine.get_query(
table=dt, filters=filters, fields=Count("*"), distinct=distinct
).run(debug=debug)[0][0]
count = frappe.qb.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct).run(
debug=debug
)[0][0]
if not filters and cache:
frappe.cache().set_value(f"doctype:count:{dt}", count, expires_in_sec=86400)
return count
@ -1160,7 +1160,7 @@ class Database:
Doctype name can be passed directly, it will be pre-pended with `tab`.
"""
filters = filters or kwargs.get("conditions")
query = frappe.qb.engine.build_conditions(table=doctype, filters=filters).delete()
query = frappe.qb.get_query(table=doctype, filters=filters).delete()
if "debug" not in kwargs:
kwargs["debug"] = debug
return query.run(**kwargs)

View file

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable
import sqlparse
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import QueryBuilder
import frappe
from frappe import _
@ -171,18 +172,16 @@ def table_from_string(table: str) -> "DocType":
return frappe.qb.DocType(table_name=table_name)
def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str):
ref_doctype = table
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str):
table = frappe.qb.DocType(doctype)
try:
lft, rgt = (
frappe.qb.from_(ref_doctype).select("lft", "rgt").where(Field("name") == field).run()[0]
)
lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
except IndexError:
lft, rgt = None, None
if hierarchy in ("descendants of", "not descendants of"):
result = (
frappe.qb.from_(ref_doctype)
frappe.qb.from_(table)
.select(Field("name"))
.where(Field("lft") > lft)
.where(Field("rgt") < rgt)
@ -192,7 +191,7 @@ def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str):
else:
# Get ancestor elements of a DocType with a tree structure
result = (
frappe.qb.from_(ref_doctype)
frappe.qb.from_(table)
.select(Field("name"))
.where(Field("lft") < lft)
.where(Field("rgt") > rgt)
@ -232,37 +231,67 @@ OPERATOR_MAP: dict[str, Callable] = {
class Engine:
tables: dict[str, str] = {}
@cached_property
def OPERATOR_MAP(self):
# default operators
all_operators = OPERATOR_MAP.copy()
def get_query(
self,
table: str,
fields: list | tuple | None = None,
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
pluck: str | None = None,
order_by: str | None = None,
group_by: str | None = None,
limit: int | None = None,
offset: int | None = None,
distinct: bool = False,
for_update: bool = False,
update: bool = False,
into: bool = False,
) -> MySQLQueryBuilder | PostgreSQLQueryBuilder:
# Clean up state before each query
self.is_mariadb = frappe.db.db_type == "mariadb"
self.is_postgres = frappe.db.db_type == "postgres"
self.tables = {}
self.implicit_joins = set()
# TODO: update with site-specific custom operators / removed previous buggy implementation
if frappe.get_hooks("filters_config"):
from frappe.utils.commands import warn
self.doctype = table
self.table = self.get_table(table)
warn(
"The 'filters_config' hook used to add custom operators is not yet implemented"
" in frappe.db.query engine. Use db_query (frappe.get_list) instead."
)
if update:
self.query = frappe.qb.update(self.table)
elif into:
self.query = frappe.qb.into(self.table)
else:
self.query = frappe.qb.from_(self.table)
return all_operators
self.fields = self.parse_fields(fields)
if not self.fields:
self.fields = [getattr(self.table, pluck or "name")]
def get_condition(self, table: str | Table, **kwargs) -> frappe.qb:
"""Get initial table object
for field in self.fields:
if isinstance(field, DynamicTableField):
self.query = field.apply(self.query)
else:
self.query = self.query.select(field)
Args:
table (str): DocType
self.apply_filters(filters)
self.apply_implicit_joins()
self.apply_order_by(order_by)
Returns:
frappe.qb: DocType with initial condition
"""
table_object = self.get_table(table)
if kwargs.get("update"):
return frappe.qb.update(table_object)
if kwargs.get("into"):
return frappe.qb.into(table_object)
return frappe.qb.from_(table_object)
if limit:
self.query = self.query.limit(limit)
if offset:
self.query = self.query.offset(offset)
if distinct:
self.query = self.query.distinct()
if for_update:
self.query = self.query.for_update()
if group_by:
self.query = self.query.groupby(group_by)
return self.query
def get_table(self, table_name: str | Table) -> Table:
if isinstance(table_name, Table):
@ -272,178 +301,93 @@ class Engine:
self.tables[table_name] = frappe.qb.DocType(table_name)
return self.tables[table_name]
def criterion_query(self, table: str, criterion: Criterion, **kwargs) -> frappe.qb:
"""Generate filters from Criterion objects
Args:
table (str): DocType
criterion (Criterion): Filters
Returns:
frappe.qb: condition object
"""
condition = self.add_conditions(self.get_condition(table, **kwargs), **kwargs)
return condition.where(criterion)
def add_conditions(self, conditions: frappe.qb, **kwargs):
"""Adding additional conditions
Args:
conditions (frappe.qb): built conditions
Returns:
conditions (frappe.qb): frappe.qb object
"""
if kwargs.get("orderby") and kwargs.get("orderby") != "KEEP_DEFAULT_ORDERING":
orderby = kwargs.get("orderby")
if isinstance(orderby, str) and len(orderby.split()) > 1:
for ordby in orderby.split(","):
if ordby := ordby.strip():
orderby, order = change_orderby(ordby)
conditions = conditions.orderby(orderby, order=order)
else:
conditions = conditions.orderby(orderby, order=kwargs.get("order") or Order.desc)
if kwargs.get("limit"):
conditions = conditions.limit(kwargs.get("limit"))
conditions = conditions.offset(kwargs.get("offset", 0))
if kwargs.get("distinct"):
conditions = conditions.distinct()
if kwargs.get("for_update"):
conditions = conditions.for_update()
if kwargs.get("groupby"):
conditions = conditions.groupby(kwargs.get("groupby"))
return conditions
def misc_query(self, table: str, filters: list | tuple = None, **kwargs):
"""Build conditions using the given Lists or Tuple filters
Args:
table (str): DocType
filters (Union[List, Tuple], optional): Filters. Defaults to None.
"""
conditions = self.get_condition(table, **kwargs)
def apply_filters(
self, filters: dict[str, str | int | list] | str | int | list[list] | None = None
):
if not filters:
return conditions
if isinstance(filters, list):
for f in filters:
if isinstance(f, (list, tuple)):
_operator = self.OPERATOR_MAP[f[-2].casefold()]
if len(f) == 4:
table_object = self.get_table(f[0])
_field = table_object[f[1]]
else:
_field = Field(f[0])
conditions = conditions.where(_operator(_field, f[-1]))
elif isinstance(f, dict):
conditions = self.dict_query(table, f, **kwargs)
else:
_operator = self.OPERATOR_MAP[filters[1].casefold()]
if not isinstance(filters[0], str):
conditions = self.make_function_for_filters(filters[0], filters[2])
break
conditions = conditions.where(_operator(Field(filters[0]), filters[2]))
break
return
return self.add_conditions(conditions, **kwargs)
def dict_query(self, table: str, filters: dict[str, str | int] = None, **kwargs) -> frappe.qb:
"""Build conditions using the given dictionary filters
Args:
table (str): DocType
filters (Dict[str, Union[str, int]], optional): Filters. Defaults to None.
Returns:
frappe.qb: conditions object
"""
conditions = self.get_condition(table, **kwargs)
if isinstance(table, str):
table = frappe.qb.DocType(table)
if not filters:
conditions = self.add_conditions(conditions, **kwargs)
return conditions
for key, value in filters.items():
if isinstance(value, bool):
filters.update({key: str(int(value))})
filters = {
(self.get_function_object(k) if has_function(k) else k): v for k, v in filters.items()
}
for key in filters:
value = filters.get(key)
_operator = self.OPERATOR_MAP["="]
if not isinstance(key, str):
conditions = conditions.where(self.make_function_for_filters(key, value))
continue
# Nested set support
if isinstance(value, (list, tuple)):
if value[0] in self.OPERATOR_MAP["nested_set"]:
hierarchy, _field = value
result = get_nested_set_hierarchy_result(hierarchy, _field, table)
_operator = (
self.OPERATOR_MAP["not in"]
if hierarchy in ("not ancestors of", "not descendants of")
else self.OPERATOR_MAP["in"]
)
if result:
result = list(itertools.chain.from_iterable(result))
conditions = conditions.where(_operator(getattr(table, key), result))
else:
conditions = conditions.where(_operator(getattr(table, key), ("",)))
# Allow additional conditions
break
_operator = self.OPERATOR_MAP[value[0].casefold()]
_value = value[1] if value[1] else ("",)
conditions = conditions.where(_operator(getattr(table, key), _value))
else:
if value is not None:
conditions = conditions.where(_operator(getattr(table, key), value))
else:
_table = conditions._from[0]
field = getattr(_table, key)
conditions = conditions.where(field.isnull())
return self.add_conditions(conditions, **kwargs)
def build_conditions(
self, table: str, filters: dict[str, str | int] | str | int = None, **kwargs
) -> frappe.qb:
"""Build conditions for sql query
Args:
filters (Union[Dict[str, Union[str, int]], str, int]): conditions in Dict
table (str): DocType
Returns:
frappe.qb: frappe.qb conditions object
"""
if isinstance(filters, int) or isinstance(filters, str):
if isinstance(filters, (str, int)):
filters = {"name": str(filters)}
if isinstance(filters, Criterion):
criterion = self.criterion_query(table, filters, **kwargs)
self.query = self.query.where(filters)
elif isinstance(filters, dict):
self.apply_dict_filters(filters)
elif isinstance(filters, (list, tuple)):
criterion = self.misc_query(table, filters, **kwargs)
self.apply_list_filters(filters)
def apply_list_filters(self, filters: list[list]):
for filter in filters:
if len(filter) == 2:
field, value = filter
self._apply_filter(field, value)
elif len(filter) == 3:
field, operator, value = filter
self._apply_filter(field, value, operator)
elif len(filter) == 4:
doctype, field, operator, value = filter
self._apply_filter(field, value, operator, doctype)
def apply_dict_filters(self, filters: dict[str, str | int | list]):
for key in filters:
value = filters.get(key)
self._apply_filter(key, value)
def _apply_filter(
self, field: str, value: str | int | list | None, operator: str = "=", doctype: str | None = None
):
_field = field
_value = value
_operator = operator
if has_function(field):
_field = self.get_function_object(field)
elif not doctype or doctype == self.doctype:
_field = self.table[field]
elif doctype:
_field = self.get_table(doctype)[field]
# keep track of implicit join if child table is referenced
if doctype and doctype != self.doctype:
meta = frappe.get_meta(doctype)
if meta.istable:
self.implicit_joins.add((doctype, "child"))
if isinstance(_value, (str, int)):
_value = str(_value)
elif isinstance(_value, (list, tuple)):
_operator, _value = _value
elif isinstance(_value, bool):
_value = int(_value)
if isinstance(_value, str) and has_function(_value):
_value = self.get_function_object(_value)
# Nested set
if _operator in self.OPERATOR_MAP["nested_set"]:
hierarchy = _operator
docname = _value
result = get_nested_set_hierarchy_result(self.doctype, docname, hierarchy)
operator_fn = (
self.OPERATOR_MAP["not in"]
if hierarchy in ("not ancestors of", "not descendants of")
else self.OPERATOR_MAP["in"]
)
if result:
result = list(itertools.chain.from_iterable(result))
self.query = self.query.where(operator_fn(_field, result))
else:
self.query = self.query.where(operator_fn(_field, ("",)))
return
operator_fn = self.OPERATOR_MAP[_operator.casefold()]
if _value is None and isinstance(_field, Field):
self.query = self.query.where(_field.isnull())
else:
criterion = self.dict_query(filters=filters, table=table, **kwargs)
return criterion
def make_function_for_filters(self, key, value: int | str):
value = list(value)
if isinstance(value[1], str) and has_function(value[1]):
value[1] = self.get_function_object(value[1])
return OPERATOR_MAP[value[0].casefold()](key, value[1])
self.query = self.query.where(operator_fn(_field, _value))
def get_function_object(self, field: str) -> "Function":
"""Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object"""
@ -495,84 +439,12 @@ class Engine:
# Fall back for functions not present in `SqlFunctions``
return Function(func, *_args, alias=alias or None)
def function_objects_from_string(self, fields):
fields = list(map(lambda str: str.strip(), COMMA_PATTERN.split(fields)))
return self.function_objects_from_list(fields=fields)
def function_objects_from_list(self, fields):
functions = []
for field in fields:
field = field.casefold() if (isinstance(field, str) and "`" not in field) else field
if not issubclass(type(field), Criterion):
if any([f"{func}(" in field for func in SQL_FUNCTIONS]) or "(" in field:
functions.append(field)
return [self.get_function_object(function) for function in functions]
def remove_string_functions(self, fields, function_objects):
"""Remove string functions from fields which have already been converted to function objects"""
def _remove_string_aliasing(function, fields: list | str):
if function.alias:
to_replace = " as " + function.alias.casefold()
if to_replace in fields:
fields = fields.replace(to_replace, "")
elif " as " + f"`{function.alias.casefold()}" in fields:
fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "")
return fields
for function in function_objects:
if isinstance(fields, str):
fields = _remove_string_aliasing(function, fields)
fields = BRACKETS_PATTERN.sub("", re.sub(function.name, "", fields, flags=re.IGNORECASE))
# Check if only comma is left in fields after stripping functions.
if "," in fields and (len(fields.strip()) == 1):
fields = ""
else:
updated_fields = []
for field in fields:
if isinstance(field, str):
field = _remove_string_aliasing(function, field)
substituted_string = (
BRACKETS_PATTERN.sub("", field).strip().casefold()
if "`" not in field
else BRACKETS_PATTERN.sub("", field).strip()
)
# This is done to avoid casefold of table name.
if substituted_string.casefold() == function.name.casefold():
replaced_string = substituted_string.casefold().replace(function.name.casefold(), "")
else:
replaced_string = substituted_string.replace(function.name.casefold(), "")
updated_fields.append(replaced_string)
fields = [field for field in updated_fields if field]
return fields
def get_fieldnames_from_child_table(self, doctype, fields):
# Hacky and flaky implementation of implicit joins.
# convert child_table.fieldname to `tabChild DocType`.`fieldname`
_fields = []
for field in fields:
if "." in field and "tab" not in field:
alias = None
if " as " in field:
field, alias = field.split(" as ")
fieldname, linked_fieldname = field.split(".")
linked_doctype = frappe.get_meta(doctype).get_field(fieldname).options
field = f"`tab{linked_doctype}`.`{linked_fieldname}`"
if alias:
field = f"{field} {alias}"
_fields.append(field)
return _fields
def sanitize_fields(self, fields: str | list | tuple):
is_mariadb = frappe.db.db_type == "mariadb"
def _sanitize_field(field: str):
if not isinstance(field, str):
return field
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
if is_mariadb:
if self.is_mariadb:
return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
return stripped_field
@ -583,174 +455,88 @@ class Engine:
return fields
def get_list_fields(self, table: str, fields: list) -> list:
updated_fields = []
if issubclass(type(fields), Criterion) or "*" in fields:
return fields
fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields)
for field in fields:
if not isinstance(field, Criterion) and field:
if " as " in field:
field, reference = field.split(" as ")
if "`" in field:
updated_fields.append(PseudoColumnMapper(f"{field} {reference}"))
else:
updated_fields.append(Field(field.strip()).as_(reference))
elif "`" in str(field):
updated_fields.append(PseudoColumnMapper(field.strip()))
else:
updated_fields.append(Field(field))
return updated_fields
def parse_string_field(self, field: str):
if field == "*":
return self.table.star
alias = None
if " as " in field:
field, alias = field.split(" as ")
if "`" in field:
if alias:
return PseudoColumnMapper(f"{field} {alias}")
return PseudoColumnMapper(field)
if alias:
return self.table[field].as_(alias)
return self.table[field]
def get_string_fields(self, fields: str) -> Field:
if fields == "*":
return fields
if "`" in fields:
fields = PseudoColumnMapper(fields)
if " as " in str(fields):
fields, reference = str(fields).split(" as ")
if "`" in str(fields):
fields = PseudoColumnMapper(f"{fields} {reference}")
else:
fields = Field(fields).as_(reference)
return fields
def set_fields(self, table: str, fields, **kwargs) -> list:
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
def parse_fields(self, fields: str | list | tuple | None) -> list:
if not fields:
return []
fields = self.sanitize_fields(fields)
if isinstance(fields, list) and None in fields and Field not in fields:
return None
function_objects = []
is_list = isinstance(fields, (list, tuple, set))
if is_list and len(fields) == 1:
fields = fields[0]
is_list = False
if isinstance(fields, (list, tuple, set)) and None in fields and Field not in fields:
return []
if is_list:
function_objects += self.function_objects_from_list(fields=fields)
if not isinstance(fields, (list, tuple)):
fields = [fields]
is_str = isinstance(fields, str)
if is_str:
fields = fields.casefold() if "`" not in fields else fields
function_objects += self.function_objects_from_string(fields=fields)
fields = self.remove_string_functions(fields, function_objects)
if is_str and "," in fields:
fields = [field.replace(" ", "") if "as" not in field else field for field in fields.split(",")]
is_list, is_str = True, False
if is_str:
fields = self.get_string_fields(fields)
if not is_str and fields:
fields = self.get_list_fields(table, fields)
# Need to check instance again since fields modified.
if not isinstance(fields, (list, tuple, set)):
fields = [fields] if fields else []
fields.extend(function_objects)
return fields
def join_child_tables(
self,
criterion: Criterion,
join_type: str,
child_table: Table,
parent_table: Table,
) -> Criterion:
if self.joined_tables.get(join_type) != child_table:
criterion = getattr(criterion, join_type)(child_table).on(
(child_table.parent == parent_table.name)
& (child_table.parenttype == TAB_PATTERN.sub("", parent_table._table_name))
)
self.joined_tables[join_type] = child_table
return criterion
def join(self, criterion, fields, table, join_type):
"""Handles all join operations on criterion objects"""
has_join = False
table_pattern = (
re.compile(r"`\btab\w+") if frappe.db.db_type == "mariadb" else re.compile(r'"\btab\w+')
)
def _update_pypika_fields(field):
if not is_pypika_function_object(field):
field = field if isinstance(field, (str, PseudoColumnMapper)) else field.get_sql()
if not table_pattern.search(str(field)):
if isinstance(field, PseudoColumnMapper):
field = field.get_sql()
return getattr(frappe.qb.DocType(table), field)
else:
return field
def parse_field(field: str):
if has_function(field):
return self.get_function_object(field)
elif parsed := DynamicTableField.parse(field, self.doctype):
return parsed
else:
field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args]
return field
return self.parse_string_field(field)
if not isinstance(fields, Criterion):
for field in fields:
# Only perform this bit if foreign doctype in fields
if (
not is_pypika_function_object(field)
and (str(field).startswith('"tab') or str(field).startswith("`tab"))
and (f"`tab{table}`" not in str(field) and f'tab{table}"' not in str(field))
):
has_join = True
child_table = table_from_string(str(field))
parent_table = frappe.qb.DocType(table) if not isinstance(table, Table) else table
criterion = self.join_child_tables(
criterion=criterion,
join_type=join_type,
child_table=child_table,
parent_table=parent_table,
)
_fields = []
for field in fields:
if isinstance(field, Criterion):
_fields.append(field)
elif isinstance(field, str):
if "," in field:
field = field.casefold() if "`" not in field else field
field_list = COMMA_PATTERN.split(field)
for field in field_list:
if _field := field.strip():
_fields.append(parse_field(_field))
else:
_fields.append(parse_field(field))
if has_join:
fields = [_update_pypika_fields(field) for field in fields]
return _fields
if len(self.tables) > 1:
parent_table = self.tables[table]
child_tables = list(self.tables.values())[1:]
for child_table in child_tables:
criterion = self.join_child_tables(
criterion,
join_type=join_type,
child_table=child_table,
parent_table=parent_table,
def apply_implicit_joins(self):
for d in self.implicit_joins:
doctype, join_type = d
table = self.get_table(doctype)
if join_type == "child":
self.query = self.query.left_join(table).on(
(table.parent == self.table.name) & (table.parenttype == self.doctype)
)
return criterion, fields
def apply_order_by(self, order_by: str | None):
if not order_by or order_by == "KEEP_DEFAULT_ORDERING":
return
for declaration in order_by.split(","):
if _order_by := declaration.strip():
parts = _order_by.split(" ")
order_field, order_direction = parts[0], parts[1] if len(parts) > 1 else "asc"
order_direction = Order.asc if order_direction.lower() == "asc" else Order.desc
self.query = self.query.orderby(order_field, order=order_direction)
def get_query(
self,
table: str,
fields: list | tuple,
filters: dict[str, str | int] | str | int | list[list | str | int] = None,
**kwargs,
) -> MySQLQueryBuilder | PostgreSQLQueryBuilder:
# Clean up state before each query
self.tables = {}
self.joined_tables = {}
self.linked_doctype = None
self.fieldname = None
@cached_property
def OPERATOR_MAP(self):
# default operators
all_operators = OPERATOR_MAP.copy()
criterion = self.build_conditions(table, filters, **kwargs)
fields = self.set_fields(table, fields, **kwargs)
join_type = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join"
criterion, fields = self.join(
criterion=criterion, fields=fields, table=table, join_type=join_type
)
# TODO: update with site-specific custom operators / removed previous buggy implementation
if frappe.get_hooks("filters_config"):
from frappe.utils.commands import warn
if isinstance(fields, (list, tuple)):
query = criterion.select(*fields)
warn(
"The 'filters_config' hook used to add custom operators is not yet implemented"
" in frappe.db.query engine. Use db_query (frappe.get_list) instead."
)
elif isinstance(fields, Criterion):
query = criterion.select(fields)
else:
query = criterion.select(fields)
return query
return all_operators
class Permission:
@ -781,3 +567,80 @@ class Permission:
@staticmethod
def get_tables_from_query(query: str):
return [table for table in WORDS_PATTERN.findall(query) if table.startswith("tab")]
class DynamicTableField:
def __init__(
self,
doctype: str,
fieldname: str,
parent_doctype: str,
alias: str | None = None,
) -> None:
self.doctype = doctype
self.fieldname = fieldname
self.alias = alias
self.parent_doctype = parent_doctype
def __str__(self) -> str:
table_name = f"`tab{self.doctype}`"
fieldname = f"`{self.fieldname}`"
if frappe.db.db_type == "postgres":
table_name = table_name.replace("`", '"')
fieldname = fieldname.replace("`", '"')
alias = f"AS {self.alias}" if self.alias else ""
return f"{table_name}.{fieldname} {alias}".strip()
@staticmethod
def parse(field: str, doctype: str):
if "." in field:
alias = None
if " as " in field:
field, alias = field.split(" as ")
if field.startswith("`tab") or field.startswith('"tab'):
_, child_doctype, child_field = re.search(r'([`"])tab(.+?)\1.\1(.+)\1', field).groups()
if child_doctype == doctype:
return
return ChildTableField(child_doctype, child_field, doctype, alias=alias)
else:
linked_fieldname, fieldname = field.split(".")
linked_field = frappe.get_meta(doctype).get_field(linked_fieldname)
linked_doctype = linked_field.options
if linked_field.fieldtype == "Link":
return LinkTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias)
elif linked_field.fieldtype in frappe.model.table_fields:
return ChildTableField(linked_doctype, fieldname, doctype, alias=alias)
def apply(self, query: QueryBuilder) -> QueryBuilder:
raise NotImplementedError
class ChildTableField(DynamicTableField):
def apply(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
main_table = frappe.qb.DocType(self.parent_doctype)
if not query.is_joined(table):
query = query.left_join(table).on(
(table.parent == main_table.name) & (table.parenttype == self.parent_doctype)
)
return query.select(getattr(table, self.fieldname).as_(self.alias or None))
class LinkTableField(DynamicTableField):
def __init__(
self,
doctype: str,
fieldname: str,
parent_doctype: str,
link_fieldname: str,
alias: str | None = None,
) -> None:
super().__init__(doctype, fieldname, parent_doctype, alias=alias)
self.link_fieldname = link_fieldname
def apply(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
main_table = frappe.qb.DocType(self.parent_doctype)
if not query.is_joined(table):
query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname))
return query.select(getattr(table, self.fieldname).as_(self.alias or None))

View file

@ -200,7 +200,7 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters):
if txt:
search_conditions = [numberCard[field].like(f"%{txt}%") for field in searchfields]
condition_query = frappe.qb.engine.build_conditions(doctype, filters)
condition_query = frappe.qb.get_query(doctype, filters)
return (
condition_query.select(numberCard.name, numberCard.label, numberCard.document_type)

View file

@ -36,7 +36,7 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> list[d
ToDo = DocType("ToDo")
User = DocType("User")
count = Count("*").as_("count")
filtered_records = frappe.qb.engine.build_conditions(doctype, current_filters).select("name")
filtered_records = frappe.qb.get_query(doctype, filters=current_filters).select("name")
return (
frappe.qb.from_(ToDo)

View file

@ -7,7 +7,7 @@ from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValue
from frappe.query_builder.utils import (
Column,
DocType,
get_qb_engine,
get_query,
get_query_builder,
patch_query_aggregation,
patch_query_execute,

View file

@ -103,7 +103,7 @@ class Cast_(Function):
def _aggregate(function, dt, fieldname, filters, **kwargs):
return (
frappe.qb.engine.build_conditions(dt, filters)
frappe.qb.get_query(dt, filters=filters)
.select(function(PseudoColumn(fieldname)))
.run(**kwargs)[0][0]
or 0

View file

@ -3,6 +3,7 @@ from importlib import import_module
from typing import Any, Callable, get_type_hints
from pypika import Query
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import Column
from pypika.terms import PseudoColumn
@ -55,10 +56,10 @@ def get_query_builder(type_of_db: str) -> Postgres | MariaDB:
return picks[db]
def get_qb_engine():
def get_query(*args, **kwargs) -> MySQLQueryBuilder | PostgreSQLQueryBuilder:
from frappe.database.query import Engine
return Engine()
return Engine().get_query(*args, **kwargs)
def get_attr(method_string):

View file

@ -229,9 +229,7 @@ class TestReportview(FrappeTestCase):
)
def test_none_filter(self):
query = frappe.qb.engine.get_query(
"DocType", fields="name", filters={"restrict_to_domain": None}
)
query = frappe.qb.get_query("DocType", fields="name", filters={"restrict_to_domain": None})
sql = str(query).replace("`", "").replace('"', "")
condition = "restrict_to_domain IS NULL"
self.assertIn(condition, sql)

View file

@ -56,30 +56,28 @@ class TestQuery(FrappeTestCase):
@run_only_if(db_type_is.MARIADB)
def test_multiple_tables_in_filters(self):
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"DocType",
["*"],
[
["BOM Update Log", "name", "like", "f%"],
["DocField", "name", "like", "f%"],
["DocType", "parent", "=", "something"],
],
).get_sql(),
"SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` AND `tabBOM Update Log`.`parenttype`='DocType' WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'",
"SELECT `tabDocType`.* FROM `tabDocType` LEFT JOIN `tabDocField` ON `tabDocField`.`parent`=`tabDocType`.`name` AND `tabDocField`.`parenttype`='DocType' WHERE `tabDocField`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'",
)
@run_only_if(db_type_is.MARIADB)
def test_string_fields(self):
self.assertEqual(
frappe.qb.engine.get_query(
"User", fields="name, email", filters={"name": "Administrator"}
).get_sql(),
frappe.qb.get_query("User", fields="name, email", filters={"name": "Administrator"}).get_sql(),
frappe.qb.from_("User")
.select(Field("name"), Field("email"))
.where(Field("name") == "Administrator")
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", fields=["`name`, `email`"], filters={"name": "Administrator"}
).get_sql(),
frappe.qb.from_("User")
@ -89,7 +87,7 @@ class TestQuery(FrappeTestCase):
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", fields=["`tabUser`.`name`", "`tabUser`.`email`"], filters={"name": "Administrator"}
).run(),
frappe.qb.from_("User")
@ -99,7 +97,7 @@ class TestQuery(FrappeTestCase):
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User",
fields=["`tabUser`.`name` as owner", "`tabUser`.`email`"],
filters={"name": "Administrator"},
@ -111,7 +109,7 @@ class TestQuery(FrappeTestCase):
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", fields=["`tabUser`.`name`, Count(`name`) as count"], filters={"name": "Administrator"}
).run(),
frappe.qb.from_("User")
@ -121,7 +119,7 @@ class TestQuery(FrappeTestCase):
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User",
fields=["`tabUser`.`name`, Count(`name`) as `count`"],
filters={"name": "Administrator"},
@ -133,7 +131,7 @@ class TestQuery(FrappeTestCase):
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", fields="`tabUser`.`name`, Count(`name`) as `count`", filters={"name": "Administrator"}
).run(),
frappe.qb.from_("User")
@ -144,38 +142,34 @@ class TestQuery(FrappeTestCase):
def test_functions_fields(self):
self.assertEqual(
frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),
frappe.qb.get_query("User", fields="Count(name)", filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(),
frappe.qb.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
"User", fields=["abs(name-email)", "Count(name)"], filters={}
).get_sql(),
frappe.qb.get_query("User", fields=["abs(name-email)", "Count(name)"], filters={}).get_sql(),
frappe.qb.from_("User")
.select(Abs(Field("name") - Field("email")), Count(Field("name")))
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(),
frappe.qb.get_query("User", fields=[Count("*")], filters={}).get_sql(),
frappe.qb.from_("User").select(Count("*")).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
"User", fields="timestamp(creation, modified)", filters={}
).get_sql(),
frappe.qb.get_query("User", fields="timestamp(creation, modified)", filters={}).get_sql(),
frappe.qb.from_("User").select(Timestamp(Field("creation"), Field("modified"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", fields="Count(name) as count, Max(email) as max_email", filters={}
).get_sql(),
frappe.qb.from_("User")
@ -186,85 +180,83 @@ class TestQuery(FrappeTestCase):
def test_qb_fields(self):
user_doctype = frappe.qb.DocType("User")
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
user_doctype, fields=[user_doctype.name, user_doctype.email], filters={}
).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
frappe.qb.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(),
)
def test_aliasing(self):
user_doctype = frappe.qb.DocType("User")
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields=["name as owner", "email as id"], filters={}
).get_sql(),
frappe.qb.get_query("User", fields=["name as owner", "email as id"], filters={}).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.name.as_("owner"), user_doctype.email.as_("id"))
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields="name as owner, email as id", filters={}
).get_sql(),
frappe.qb.get_query(user_doctype, fields="name as owner, email as id", filters={}).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.name.as_("owner"), user_doctype.email.as_("id"))
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
user_doctype, fields=["Count(name) as count", "email as id"], filters={}
).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.email.as_("id"), Count(Field("name")).as_("count"))
.select(Count(Field("name")).as_("count"), user_doctype.email.as_("id"))
.get_sql(),
)
@run_only_if(db_type_is.MARIADB)
def test_filters(self):
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"User", filters={"IfNull(name, " ")": ("<", Now())}, fields=["Max(name)"]
).run(),
frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(),
)
def test_implicit_join_query(self):
self.maxDiff = None
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"Note",
filters={"name": "Test Note Title"},
fields=["name", "`tabNote Seen By`.`user` as seen_by"],
).get_sql(),
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"`", '"' if frappe.db.db_type == "postgres" else "`"
),
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"Note",
filters={"name": "Test Note Title"},
fields=["name", "`tabNote Seen By`.`user` as seen_by", "`tabNote Seen By`.`idx` as idx"],
).get_sql(),
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by,`tabNote Seen By`.`idx` idx FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by`,`tabNote Seen By`.`idx` `idx` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"`", '"' if frappe.db.db_type == "postgres" else "`"
),
)
self.assertEqual(
frappe.qb.engine.get_query(
frappe.qb.get_query(
"Note",
filters={"name": "Test Note Title"},
fields=["name", "seen_by.user as seen_by", "`tabNote Seen By`.`idx` as idx"],
).get_sql(),
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` seen_by,`tabNote Seen By`.`idx` idx FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"SELECT `tabNote`.`name`,`tabNote Seen By`.`user` `seen_by`,`tabNote Seen By`.`idx` `idx` FROM `tabNote` LEFT JOIN `tabNote Seen By` ON `tabNote Seen By`.`parent`=`tabNote`.`name` AND `tabNote Seen By`.`parenttype`='Note' WHERE `tabNote`.`name`='Test Note Title'".replace(
"`", '"' if frappe.db.db_type == "postgres" else "`"
),
)
@ -272,40 +264,40 @@ class TestQuery(FrappeTestCase):
@run_only_if(db_type_is.MARIADB)
def test_comment_stripping(self):
self.assertNotIn(
"email", frappe.qb.engine.get_query("User", fields=["name", "#email"], filters={}).get_sql()
"email", frappe.qb.get_query("User", fields=["name", "#email"], filters={}).get_sql()
)
def test_nestedset(self):
frappe.db.sql("delete from `tabDocType` where `name` = 'Test Tree DocType'")
frappe.db.sql_ddl("drop table if exists `tabTest Tree DocType`")
create_tree_docs()
descendants_result = frappe.qb.engine.get_query(
descendants_result = frappe.qb.get_query(
"Test Tree DocType",
fields=["name"],
filters={"name": ("descendants of", "Parent 1")},
orderby="modified",
order_by="modified",
).run(as_list=1)
# Format decendants result
descendants_result = list(itertools.chain.from_iterable(descendants_result))
self.assertListEqual(descendants_result, get_descendants_of("Test Tree DocType", "Parent 1"))
ancestors_result = frappe.qb.engine.get_query(
ancestors_result = frappe.qb.get_query(
"Test Tree DocType",
fields=["name"],
filters={"name": ("ancestors of", "Child 2")},
orderby="modified",
order_by="modified",
).run(as_list=1)
# Format ancestors result
ancestors_result = list(itertools.chain.from_iterable(ancestors_result))
self.assertListEqual(ancestors_result, get_ancestors_of("Test Tree DocType", "Child 2"))
not_descendants_result = frappe.qb.engine.get_query(
not_descendants_result = frappe.qb.get_query(
"Test Tree DocType",
fields=["name"],
filters={"name": ("not descendants of", "Parent 1")},
orderby="modified",
order_by="modified",
).run(as_dict=1)
self.assertListEqual(
@ -317,11 +309,11 @@ class TestQuery(FrappeTestCase):
),
)
not_ancestors_result = frappe.qb.engine.get_query(
not_ancestors_result = frappe.qb.get_query(
"Test Tree DocType",
fields=["name"],
filters={"name": ("not ancestors of", "Child 2")},
orderby="modified",
order_by="modified",
).run(as_dict=1)
self.assertListEqual(

View file

@ -24,7 +24,7 @@ def get_monthly_results(
date_format = "%m-%Y" if frappe.db.db_type != "postgres" else "MM-YYYY"
return dict(
frappe.qb.engine.build_conditions(table=goal_doctype, filters=filters)
frappe.qb.get_query(table=goal_doctype, filters=filters)
.select(
DateFormat(Table[date_col], date_format).as_("month_year"),
Function(aggregation, goal_field),