diff --git a/frappe/core/report/permitted_documents_for_user/permitted_documents_for_user.py b/frappe/core/report/permitted_documents_for_user/permitted_documents_for_user.py index a7eff77ed0..2c92a72ab3 100644 --- a/frappe/core/report/permitted_documents_for_user/permitted_documents_for_user.py +++ b/frappe/core/report/permitted_documents_for_user/permitted_documents_for_user.py @@ -31,7 +31,7 @@ def execute(filters=None): def get_columns_and_fields(doctype): columns = [f"Name:Link/{doctype}:200"] - fields = ["`name`"] + fields = ["name"] for df in frappe.get_meta(doctype).fields: if df.in_list_view and df.fieldtype in data_fieldtypes: fields.append(f"`{df.fieldname}`") diff --git a/frappe/database/database.py b/frappe/database/database.py index 07f8162ef7..25e82135df 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -10,6 +10,7 @@ import traceback from contextlib import contextmanager from time import time +from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.terms import Criterion, NullValue import frappe @@ -175,6 +176,9 @@ class Database: {"name": "a%", "owner":"test@example.com"}) """ + if isinstance(query, (MySQLQueryBuilder, PostgreSQLQueryBuilder)): + frappe.errprint("Use run method to execute SQL queries generated by Query Engine") + debug = debug or getattr(self, "debug", False) query = str(query) if not run: @@ -790,7 +794,7 @@ class Database: if fields == "*" and not isinstance(fields, (list, tuple)) and not isinstance(fields, Criterion): as_dict = True - return self.sql(query, as_dict=as_dict, debug=debug, update=update, run=run, pluck=pluck) + return query.run(as_dict=as_dict, debug=debug, update=update, run=run, pluck=pluck) def _get_value_for_many_names( self, @@ -807,18 +811,15 @@ class Database: as_dict=False, ): if names := list(filter(None, names)): - return self.get_all( + return frappe.qb.engine.get_query( doctype, fields=field, filters=names, order_by=order_by, pluck=pluck, - debug=debug, - as_list=not as_dict, - run=run, distinct=distinct, - limit_page_length=limit, - ) + limit=limit, + ).run(debug=debug, run=run, as_dict=as_dict) return {} def update(self, *args, **kwargs): @@ -1069,10 +1070,9 @@ class Database: cache_count = frappe.cache().get_value(f"doctype:count:{dt}") if cache_count is not None: return cache_count - query = frappe.qb.engine.get_query( + count = frappe.qb.engine.get_query( table=dt, filters=filters, fields=Count("*"), distinct=distinct - ) - count = self.sql(query, debug=debug)[0][0] + ).run(debug=debug)[0][0] if not filters and cache: frappe.cache().set_value(f"doctype:count:{dt}", count, expires_in_sec=86400) return count diff --git a/frappe/database/query.py b/frappe/database/query.py index 8dbb564edc..5726b760ee 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -3,19 +3,27 @@ import re from ast import literal_eval from functools import cached_property from types import BuiltinFunctionType -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable + +from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder import frappe from frappe import _ +from frappe.database.utils import is_pypika_function_object from frappe.model.db_query import get_timespan_date_range from frappe.query_builder import Criterion, Field, Order, Table, functions from frappe.query_builder.functions import Function, SqlFunctions +from frappe.query_builder.utils import PseudoColumn + +if TYPE_CHECKING: + from frappe.query_builder import DocType TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") BRACKETS_PATTERN = re.compile(r"\(.*?\)|$") SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") +TABLE_PATTERN = re.compile(r"`\btab\w+") def like(key: Field, value: str) -> frappe.qb: @@ -41,6 +49,8 @@ def func_in(key: Field, value: list | tuple) -> frappe.qb: Returns: frappe.qb: `frappe.qb object with `IN` """ + if isinstance(value, str): + value = value.split(",") return key.isin(value) @@ -57,7 +67,7 @@ def not_like(key: Field, value: str) -> frappe.qb: return key.not_like(value) -def func_not_in(key: Field, value: list | tuple): +def func_not_in(key: Field, value: list | tuple | str): """Wrapper method for `NOT IN` Args: @@ -67,6 +77,8 @@ def func_not_in(key: Field, value: list | tuple): Returns: frappe.qb: `frappe.qb object with `NOT IN` """ + if isinstance(value, str): + value = value.split(",") return key.notin(value) @@ -115,19 +127,6 @@ def func_timespan(key: Field, value: str) -> frappe.qb: return func_between(key, get_timespan_date_range(value)) -def make_function(key: Any, value: int | str): - """returns fucntion query - - Args: - key (Any): field - value (Union[int, str]): criterion - - Returns: - frappe.qb: frappe.qb object - """ - return OPERATOR_MAP[value[0].casefold()](key, value[1]) - - def change_orderby(order: str): """Convert orderby to standart Order object @@ -155,6 +154,18 @@ def literal_eval_(literal): return literal +def has_function(field): + _field = field.casefold() if (isinstance(field, str) and "`" not in field) else field + if not issubclass(type(_field), Criterion): + if any([f"{func}(" in _field for func in SQL_FUNCTIONS]) or "(" in _field: + return True + + +def table_from_string(table: str) -> "DocType": + table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:] + return frappe.qb.DocType(table_name=table_name.replace("`", "")) + + # default operators OPERATOR_MAP: dict[str, Callable] = { "+": operator.add, @@ -297,7 +308,7 @@ class Engine: else: _operator = self.OPERATOR_MAP[filters[1].casefold()] if not isinstance(filters[0], str): - conditions = make_function(filters[0], filters[2]) + conditions = self.make_function_for_filters(filters[0], filters[2]) break conditions = conditions.where(_operator(Field(filters[0]), filters[2])) break @@ -315,6 +326,8 @@ class Engine: frappe.qb: conditions object """ conditions = self.get_condition(table, **kwargs) + if isinstance(table, str): + table = frappe.qb.DocType(table) if not filters: conditions = self.add_conditions(conditions, **kwargs) return conditions @@ -323,20 +336,23 @@ class Engine: if isinstance(value, bool): filters.update({key: str(int(value))}) + filters = { + (self.get_function_object(k) if has_function(k) else k): v for k, v in filters.items() + } for key in filters: value = filters.get(key) _operator = self.OPERATOR_MAP["="] if not isinstance(key, str): - conditions = conditions.where(make_function(key, value)) + conditions = conditions.where(self.make_function_for_filters(key, value)) continue if isinstance(value, (list, tuple)): _operator = self.OPERATOR_MAP[value[0].casefold()] _value = value[1] if value[1] else ("",) - conditions = conditions.where(_operator(Field(key), _value)) + conditions = conditions.where(_operator(getattr(table, key), _value)) else: if value is not None: - conditions = conditions.where(_operator(Field(key), value)) + conditions = conditions.where(_operator(getattr(table, key), value)) else: _table = conditions._from[0] field = getattr(_table, key) @@ -370,6 +386,12 @@ class Engine: return criterion + def make_function_for_filters(self, key, value: int | str): + value = list(value) + if isinstance(value[1], str) and has_function(value[1]): + value[1] = self.get_function_object(value[1]) + return OPERATOR_MAP[value[0].casefold()](key, value[1]) + def get_function_object(self, field: str) -> "Function": """Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object""" func = field.split("(", maxsplit=1)[0].capitalize() @@ -395,12 +417,21 @@ class Engine: *map(lambda field: Field(field.strip()), arg.split(_operator)), ) - field = Field(initial_fields) if not has_primitive_operator else field + field = ( + (Field(initial_fields) if "`" not in initial_fields else PseudoColumn(initial_fields)) + if not has_primitive_operator + else field + ) else: field = initial_fields _args.append(field) + + if alias and "`" in alias: + alias = alias.replace("`", "") try: + if func.casefold() == "now": + return getattr(functions, func)() return getattr(functions, func)(*_args, alias=alias or None) except AttributeError: # Fall back for functions not present in `SqlFunctions`` @@ -413,7 +444,7 @@ class Engine: def function_objects_from_list(self, fields): functions = [] for field in fields: - field = field.casefold() if isinstance(field, str) else field + field = field.casefold() if (isinstance(field, str) and "`" not in field) else field if not issubclass(type(field), Criterion): if any([f"{func}(" in field for func in SQL_FUNCTIONS]) or "(" in field: functions.append(field) @@ -422,11 +453,20 @@ class Engine: def remove_string_functions(self, fields, function_objects): """Remove string functions from fields which have already been converted to function objects""" + + def _remove_string_aliasing(function, fields: list | str): + if function.alias: + to_replace = " as " + function.alias.casefold() + if to_replace in fields: + fields = fields.replace(to_replace, "") + elif " as " + f"`{function.alias.casefold()}" in fields: + fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "") + return fields + for function in function_objects: if isinstance(fields, str): - if function.alias: - fields = fields.replace(" as " + function.alias.casefold(), "") - fields = BRACKETS_PATTERN.sub("", fields.replace(function.name.casefold(), "")) + fields = _remove_string_aliasing(function, fields) + fields = BRACKETS_PATTERN.sub("", re.sub(function.name, "", fields, flags=re.IGNORECASE)) # Check if only comma is left in fields after stripping functions. if "," in fields and (len(fields.strip()) == 1): fields = "" @@ -434,24 +474,47 @@ class Engine: updated_fields = [] for field in fields: if isinstance(field, str): - if function.alias: - field = field.replace(" as " + function.alias.casefold(), "") - field = ( - BRACKETS_PATTERN.sub("", field).strip().casefold().replace(function.name.casefold(), "") + field = _remove_string_aliasing(function, field) + substituted_string = ( + BRACKETS_PATTERN.sub("", field).strip().casefold() + if "`" not in field + else BRACKETS_PATTERN.sub("", field).strip() ) - updated_fields.append(field) + # This is done to avoid casefold of table name. + if substituted_string.casefold() == function.name.casefold(): + replaced_string = substituted_string.casefold().replace(function.name.casefold(), "") + else: + replaced_string = substituted_string.replace(function.name.casefold(), "") + updated_fields.append(replaced_string) + fields = [field for field in updated_fields if field] + return fields - fields = [field for field in updated_fields if field] + def get_fieldnames_from_child_table(self, doctype, fields): + # Hacky and flaky implementation of implicit joins. + # convert child_table.fieldname to `tabChild DocType`.`fieldname` + for idx, field in enumerate(fields, start=0): + if "." in field and "tab" not in field: + alias = None + if " as " in field: + field, alias = field.split(" as ") + self.fieldname, linked_fieldname = field.split(".") + linked_field = frappe.get_meta(doctype, cached=True).get_field(self.fieldname) + try: + self.linked_doctype = linked_field.options + except AttributeError: + return fields + field = f"`tab{self.linked_doctype}`.`{linked_fieldname}`" + if alias: + field = f"{field} as {alias}" + fields[idx] = field return fields - def set_fields(self, fields, **kwargs): + def set_fields(self, table, fields, **kwargs) -> list: fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" if isinstance(fields, list) and None in fields and Field not in fields: return None - function_objects = [] - is_list = isinstance(fields, (list, tuple, set)) if is_list and len(fields) == 1: fields = fields[0] @@ -462,7 +525,7 @@ class Engine: is_str = isinstance(fields, str) if is_str: - fields = fields.casefold() + fields = fields.casefold() if "`" not in fields else fields function_objects += self.function_objects_from_string(fields=fields) fields = self.remove_string_functions(fields, function_objects) @@ -474,9 +537,14 @@ class Engine: if is_str: if fields == "*": return fields - if " as " in fields: - fields, reference = fields.split(" as ") - fields = Field(fields).as_(reference) + if "`" in fields: + fields = PseudoColumn(fields) + if " as " in str(fields): + fields, reference = str(fields).split(" as ") + if "`" in str(fields): + fields = PseudoColumn(f"{fields} as {reference}") + else: + fields = Field(fields).as_(reference) if not is_str and fields: if issubclass(type(fields), Criterion): @@ -484,15 +552,22 @@ class Engine: updated_fields = [] if "*" in fields: return fields + # fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields) for field in fields: if not isinstance(field, Criterion) and field: if " as " in field: field, reference = field.split(" as ") - updated_fields.append(Field(field.strip()).as_(reference)) + if "`" in field: + updated_fields.append(PseudoColumn(f"{field} as {reference}")) + else: + updated_fields.append(Field(field.strip()).as_(reference)) + + elif "`" in str(field): + updated_fields.append(PseudoColumn(field.strip())) else: updated_fields.append(Field(field)) - fields = updated_fields + fields = updated_fields # Need to check instance again since fields modified. if not isinstance(fields, (list, tuple, set)): @@ -501,27 +576,65 @@ class Engine: fields.extend(function_objects) return fields + def join_(self, criterion, fields, table, join): + """Handles all join operations on criterion objects""" + has_join = False + if not isinstance(fields, Criterion): + for field in fields: + # Only perform this bit if foreign doctype in fields + if ( + not is_pypika_function_object(field) + and str(field).startswith("`tab") + and (f"`tab{table}`" not in str(field)) + ): + join_table = table_from_string(str(field)) + if self.fieldname: + criterion = criterion.left_join(join_table).on( + getattr(join_table, "name") == getattr(frappe.qb.DocType(table), self.fieldname) + ) + else: + criterion = criterion.left_join(join_table).on( + getattr(join_table, "parent") == getattr(frappe.qb.DocType(table), "name") + ) + has_join = True + + if has_join: + for idx, field in enumerate(fields): + if not is_pypika_function_object(field): + field = field if isinstance(field, str) else field.get_sql() + if not TABLE_PATTERN.search(str(field)): + fields[idx] = getattr(frappe.qb.DocType(table), field) + else: + field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args] + field.args[0] = getattr(frappe.qb.DocType(table), field.args[0].get_sql()) + fields[idx] = field + + if len(self.tables) > 1: + primary_table = self.tables.pop(table) + for table_object in self.tables.values(): + criterion = getattr(criterion, join)(table_object).on( + table_object.parent == primary_table.name + ) + has_join = True + + return criterion, fields + def get_query( self, table: str, fields: list | tuple, filters: dict[str, str | int] | str | int | list[list | str | int] = None, **kwargs, - ): + ) -> MySQLQueryBuilder | PostgreSQLQueryBuilder: # Clean up state before each query self.tables = {} + self.linked_doctype = None + self.fieldname = None + + fields = self.set_fields(table, kwargs.get("field_objects") or fields, **kwargs) criterion = self.build_conditions(table, filters, **kwargs) - fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs) - join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" - - if len(self.tables) > 1: - primary_table = self.tables[table] - del self.tables[table] - for table_object in self.tables.values(): - criterion = getattr(criterion, join)(table_object).on( - table_object.parent == primary_table.name - ) + criterion, fields = self.join_(criterion=criterion, fields=fields, table=table, join=join) if isinstance(fields, (list, tuple)): query = criterion.select(*fields) diff --git a/frappe/database/utils.py b/frappe/database/utils.py index c4d8cb4953..c1f70d388e 100644 --- a/frappe/database/utils.py +++ b/frappe/database/utils.py @@ -1,11 +1,16 @@ # Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors # License: MIT. See LICENSE +import typing from functools import cached_property from types import NoneType import frappe from frappe.query_builder.builder import MariaDB, Postgres +from frappe.query_builder.functions import Function + +if typing.TYPE_CHECKING: + from frappe.query_builder import DocType Query = str | MariaDB | Postgres QueryValues = tuple | list | dict | NoneType @@ -18,6 +23,10 @@ def is_query_type(query: str, query_type: str | tuple[str]) -> bool: return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type) +def is_pypika_function_object(field: str) -> bool: + return getattr(field, "__module__", None) == "pypika.functions" or isinstance(field, Function) + + class LazyString: def _setup(self) -> None: raise NotImplementedError diff --git a/frappe/integrations/doctype/webhook/__init__.py b/frappe/integrations/doctype/webhook/__init__.py index 192cd2fa12..b9c96190ca 100644 --- a/frappe/integrations/doctype/webhook/__init__.py +++ b/frappe/integrations/doctype/webhook/__init__.py @@ -25,7 +25,7 @@ def run_webhooks(doc, method): # query webhooks webhooks_list = frappe.get_all( "Webhook", - fields=["name", "`condition`", "webhook_docevent", "webhook_doctype"], + fields=["name", "condition", "webhook_docevent", "webhook_doctype"], filters={"enabled": True}, ) diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index 824de7fbf5..e725dff828 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -17,9 +17,21 @@ class Concat_ws(Function): class Locate(Function): def __init__(self, *terms, **kwargs): + terms = list(terms) + if not isinstance(terms[0], str): + terms[0] = terms[0].get_sql() super().__init__("LOCATE", *terms, **kwargs) +class Ifnull(IfNull): + def __init__(self, condition, term, **kwargs): + if not isinstance(condition, str): + condition = condition.get_sql() + if not isinstance(term, str): + term = term.get_sql() + super().__init__(condition, term, **kwargs) + + class Timestamp(Function): def __init__(self, term: str, time=None, alias=None): if time: @@ -105,6 +117,7 @@ class SqlFunctions(Enum): Min = "min" Abs = "abs" Timestamp = "timestamp" + IfNull = "ifnull" def _max(dt, fieldname, filters=None, **kwargs): diff --git a/frappe/share.py b/frappe/share.py index b142a5060e..711693cd50 100644 --- a/frappe/share.py +++ b/frappe/share.py @@ -104,12 +104,12 @@ def get_users(doctype, name): return frappe.get_all( "DocShare", fields=[ - "`name`", - "`user`", - "`read`", - "`write`", - "`submit`", - "`share`", + "name", + "user", + "read", + "write", + "submit", + "share", "everyone", "owner", "creation", diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 254489d281..ae9513a302 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -1,6 +1,6 @@ import frappe from frappe.query_builder import Field -from frappe.query_builder.functions import Abs, Count, Max, Timestamp +from frappe.query_builder.functions import Abs, Count, Ifnull, Max, Now, Timestamp from frappe.tests.test_query_builder import db_type_is, run_only_if from frappe.tests.utils import FrappeTestCase @@ -20,6 +20,7 @@ class TestQuery(FrappeTestCase): "SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", ) + @run_only_if(db_type_is.MARIADB) def test_string_fields(self): self.assertEqual( frappe.qb.engine.get_query( @@ -30,10 +31,9 @@ class TestQuery(FrappeTestCase): .where(Field("name") == "Administrator") .get_sql(), ) - self.assertEqual( frappe.qb.engine.get_query( - "User", fields=["name, email"], filters={"name": "Administrator"} + "User", fields=["`name`, `email`"], filters={"name": "Administrator"} ).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) @@ -41,6 +41,60 @@ class TestQuery(FrappeTestCase): .get_sql(), ) + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields=["`tabUser`.`name`", "`tabUser`.`email`"], filters={"name": "Administrator"} + ).run(), + frappe.qb.from_("User") + .select(Field("name"), Field("email")) + .where(Field("name") == "Administrator") + .run(), + ) + + self.assertEqual( + frappe.qb.engine.get_query( + "User", + fields=["`tabUser`.`name` as owner", "`tabUser`.`email`"], + filters={"name": "Administrator"}, + ).run(as_dict=1), + frappe.qb.from_("User") + .select(Field("name").as_("owner"), Field("email")) + .where(Field("name") == "Administrator") + .run(as_dict=1), + ) + + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields=["`tabUser`.`name`, Count(`name`) as count"], filters={"name": "Administrator"} + ).run(), + frappe.qb.from_("User") + .select(Field("name"), Count("name").as_("count")) + .where(Field("name") == "Administrator") + .run(), + ) + + self.assertEqual( + frappe.qb.engine.get_query( + "User", + fields=["`tabUser`.`name`, Count(`name`) as `count`"], + filters={"name": "Administrator"}, + ).run(), + frappe.qb.from_("User") + .select(Field("name"), Count("name").as_("count")) + .where(Field("name") == "Administrator") + .run(), + ) + + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields="`tabUser`.`name`, Count(`name`) as `count`", filters={"name": "Administrator"} + ).run(), + frappe.qb.from_("User") + .select(Field("name"), Count("name").as_("count")) + .where(Field("name") == "Administrator") + .run(), + ) + def test_functions_fields(self): self.assertEqual( frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), @@ -124,3 +178,26 @@ class TestQuery(FrappeTestCase): .select(user_doctype.email.as_("id"), Count(Field("name")).as_("count")) .get_sql(), ) + + @run_only_if(db_type_is.MARIADB) + def test_filters(self): + self.assertEqual( + frappe.qb.engine.get_query( + "User", filters={"IfNull(name, " ")": ("<", Now())}, fields=["Max(name)"] + ).run(), + frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(), + ) + + def test_implicit_join_query(self): + self.assertEqual( + frappe.qb.engine.get_query( + "Note", + filters={"name": "Test Note Title"}, + fields=["name", "`tabNote Seen By`.`user` as seen_by"], + ).run(as_dict=1), + frappe.get_list( + "Note", + filters={"name": "Test Note Title"}, + fields=["name", "`tabNote Seen By`.`user` as seen_by"], + ), + ) diff --git a/frappe/workflow/doctype/workflow_action/workflow_action.py b/frappe/workflow/doctype/workflow_action/workflow_action.py index 038a3021d2..545ad6ec77 100644 --- a/frappe/workflow/doctype/workflow_action/workflow_action.py +++ b/frappe/workflow/doctype/workflow_action/workflow_action.py @@ -290,7 +290,7 @@ def update_completed_workflow_actions_using_user(doc, user=None): def get_next_possible_transitions(workflow_name, state, doc=None): transitions = frappe.get_all( "Workflow Transition", - fields=["allowed", "action", "state", "allow_self_approval", "next_state", "`condition`"], + fields=["allowed", "action", "state", "allow_self_approval", "next_state", "condition"], filters=[["parent", "=", workflow_name], ["state", "=", state]], )