diff --git a/frappe/__init__.py b/frappe/__init__.py index 063a62c84b..0f55854535 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -22,7 +22,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import click from werkzeug.local import Local, release_local -from frappe.query_builder import get_query_builder, patch_query_aggregation, patch_query_execute +from frappe.query_builder import ( + get_qb_engine, + get_query_builder, + patch_query_aggregation, + patch_query_execute, +) from frappe.utils.caching import request_cache from frappe.utils.data import cstr, sbool @@ -240,7 +245,7 @@ def init(site, sites_path=None, new_site=False): local.session = _dict() local.dev_server = _dev_server local.qb = get_query_builder(local.conf.db_type or "mariadb") - + local.qb.engine = get_qb_engine() setup_module_map() if not _qb_patched.get(local.conf.db_type): diff --git a/frappe/database/database.py b/frappe/database/database.py index c68368ba38..2986414af6 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from time import time from typing import Dict, List, Optional, Tuple, Union -from pypika.terms import Criterion, NullValue, PseudoColumn +from pypika.terms import Criterion, NullValue import frappe import frappe.defaults @@ -75,15 +75,6 @@ class Database(object): self.password = password or frappe.conf.db_password self.value_cache = {} - @property - def query(self): - if not hasattr(self, "_query"): - from .query import Query - - self._query = Query() - del Query - return self._query - def setup_type_map(self): pass @@ -600,7 +591,7 @@ class Database(object): return [map(values.get, fields)] else: - r = self.query.get_sql( + r = frappe.qb.engine.get_query( "Singles", filters={"field": ("in", tuple(fields)), "doctype": doctype}, fields=["field", "value"], @@ -633,7 +624,7 @@ class Database(object): # Get coulmn and value of the single doctype Accounts Settings account_settings = frappe.db.get_singles_dict("Accounts Settings") """ - queried_result = self.query.get_sql( + queried_result = frappe.qb.engine.get_query( "Singles", filters={"doctype": doctype}, fields=["field", "value"], @@ -706,7 +697,7 @@ class Database(object): if cache and fieldname in self.value_cache[doctype]: return self.value_cache[doctype][fieldname] - val = self.query.get_sql( + val = frappe.qb.engine.get_query( table="Singles", filters={"doctype": doctype, "field": fieldname}, fields="value", @@ -748,14 +739,7 @@ class Database(object): ): field_objects = [] - if not isinstance(fields, Criterion): - for field in fields: - if "(" in str(field) or " as " in str(field): - field_objects.append(PseudoColumn(field)) - else: - field_objects.append(field) - - query = self.query.get_sql( + query = frappe.qb.engine.get_query( table=doctype, filters=filters, orderby=order_by, @@ -865,7 +849,7 @@ class Database(object): frappe.clear_document_cache(dt, docname) else: - query = self.query.build_conditions(table=dt, filters=dn, update=True) + query = frappe.qb.engine.build_conditions(table=dt, filters=dn, update=True) # TODO: Fix this; doesn't work rn - gavin@frappe.io # frappe.cache().hdel_keys(dt, "document_cache") # Workaround: clear all document caches @@ -1066,7 +1050,7 @@ class Database(object): cache_count = frappe.cache().get_value("doctype:count:{}".format(dt)) if cache_count is not None: return cache_count - query = self.query.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct) + query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct) count = self.sql(query, debug=debug)[0][0] if not filters and cache: frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400) @@ -1206,7 +1190,7 @@ class Database(object): Doctype name can be passed directly, it will be pre-pended with `tab`. """ filters = filters or kwargs.get("conditions") - query = self.query.build_conditions(table=doctype, filters=filters).delete() + query = frappe.qb.engine.build_conditions(table=doctype, filters=filters).delete() if "debug" not in kwargs: kwargs["debug"] = debug return query.run(**kwargs) diff --git a/frappe/database/query.py b/frappe/database/query.py index f7cc143cf7..60eeaee8c2 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,16 +1,23 @@ import operator import re +from ast import literal_eval from functools import cached_property -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union import frappe from frappe import _ from frappe.boot import get_additional_filters_from_hooks from frappe.model.db_query import get_timespan_date_range -from frappe.query_builder import Criterion, Field, Order, Table +from frappe.query_builder import Criterion, Field, Order, Table, functions +from frappe.query_builder.functions import SqlFunctions TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") +BRACKETS_PATTERN = re.compile(r"\(.*?\)|$") +SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] + +if TYPE_CHECKING: + from pypika.functions import Function def like(key: Field, value: str) -> frappe.qb: @@ -93,7 +100,7 @@ def func_between(key: Field, value: Union[List, Tuple]) -> frappe.qb: def func_is(key, value): "Wrapper for IS" - return Field(key).isnotnull() if value.lower() == "set" else Field(key).isnull() + return key.isnotnull() if value.lower() == "set" else key.isnull() def func_timespan(key: Field, value: str) -> frappe.qb: @@ -143,6 +150,13 @@ def change_orderby(order: str): return order[0], Order.desc +def literal_eval_(literal): + try: + return literal_eval(literal) + except (ValueError, SyntaxError): + return literal + + # default operators OPERATOR_MAP: Dict[str, Callable] = { "+": operator.add, @@ -168,7 +182,7 @@ OPERATOR_MAP: Dict[str, Callable] = { } -class Query: +class Engine: tables: dict = {} @cached_property @@ -238,7 +252,7 @@ class Query: Returns: conditions (frappe.qb): frappe.qb object """ - if kwargs.get("orderby"): + if kwargs.get("orderby") and kwargs.get("orderby") != "KEEP_DEFAULT_ORDERING": orderby = kwargs.get("orderby") if isinstance(orderby, str) and len(orderby.split()) > 1: for ordby in orderby.split(","): @@ -250,6 +264,7 @@ class Query: if kwargs.get("limit"): conditions = conditions.limit(kwargs.get("limit")) + conditions = conditions.offset(kwargs.get("offset", 0)) if kwargs.get("distinct"): conditions = conditions.distinct() @@ -257,6 +272,9 @@ class Query: if kwargs.get("for_update"): conditions = conditions.for_update() + if kwargs.get("groupby"): + conditions = conditions.groupby(kwargs.get("groupby")) + return conditions def misc_query(self, table: str, filters: Union[List, Tuple] = None, **kwargs): @@ -308,6 +326,10 @@ class Query: conditions = self.add_conditions(conditions, **kwargs) return conditions + for key, value in filters.items(): + if isinstance(value, bool): + filters.update({key: str(int(value))}) + for key in filters: value = filters.get(key) _operator = self.OPERATOR_MAP["="] @@ -317,7 +339,8 @@ class Query: continue if isinstance(value, (list, tuple)): _operator = self.OPERATOR_MAP[value[0].casefold()] - conditions = conditions.where(_operator(Field(key), value[1])) + _value = value[1] if value[1] else ("",) + conditions = conditions.where(_operator(Field(key), _value)) else: if value is not None: conditions = conditions.where(_operator(Field(key), value)) @@ -354,7 +377,117 @@ class Query: return criterion - def get_sql( + def get_function_object(self, field: str) -> "Function": + """Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object""" + func = field.split("(", maxsplit=1)[0].capitalize() + args_start, args_end = len(func) + 1, field.index(")") + args = field[args_start:args_end].split(",") + + to_cast = "*" not in args + _args = [] + + for arg in args: + field = literal_eval_(arg.strip()) + if to_cast: + field = Field(field) + _args.append(field) + + return getattr(functions, func)(*_args) + + def function_objects_from_string(self, fields): + functions = "" + for func in SQL_FUNCTIONS: + if f"{func}(" in fields: + functions = str(func) + str(BRACKETS_PATTERN.search(fields).group()) + return [self.get_function_object(functions)] + if not functions: + return [] + + def function_objects_from_list(self, fields): + functions = [] + for field in fields: + field = field.casefold() if isinstance(field, str) else field + if not issubclass(type(field), Criterion): + if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]): + functions.append(field) + return [self.get_function_object(function) for function in functions] + + def remove_string_functions(self, fields, function_objects): + """Remove string functions from fields which have already been converted to function objects""" + for function in function_objects: + if isinstance(fields, str): + fields = BRACKETS_PATTERN.sub("", fields.replace(function.name.casefold(), "")) + else: + updated_fields = [] + for field in fields: + if isinstance(field, str): + updated_fields.append( + BRACKETS_PATTERN.sub("", field).strip().casefold().replace(function.name.casefold(), "") + ) + else: + updated_fields.append(field) + + fields = [field for field in updated_fields if field] + + return fields + + def set_fields(self, fields, **kwargs): + fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" + if isinstance(fields, list) and None in fields and Field not in fields: + return None + + function_objects = [] + + is_list = isinstance(fields, (list, tuple, set)) + if is_list and len(fields) == 1: + fields = fields[0] + is_list = False + + if is_list: + function_objects += self.function_objects_from_list(fields=fields) + + is_str = isinstance(fields, str) + if is_str: + fields = fields.casefold() + function_objects += self.function_objects_from_string(fields=fields) + + fields = self.remove_string_functions(fields, function_objects) + + if is_str and "," in fields: + fields = [field.replace(" ", "") if "as" not in field else field for field in fields.split(",")] + is_list, is_str = True, False + + if is_str: + if fields == "*": + return fields + if " as " in fields: + fields, reference = fields.split(" as ") + fields = Field(fields).as_(reference) + + if not is_str and fields: + if issubclass(type(fields), Criterion): + return fields + updated_fields = [] + if "*" in fields: + return fields + for field in fields: + if not isinstance(field, Criterion) and field: + if " as " in field: + field, reference = field.split(" as ") + updated_fields.append(Field(field.strip()).as_(reference)) + else: + updated_fields.append(Field(field)) + + fields = updated_fields + + # Need to check instance again since fields modified. + if not isinstance(fields, (list, tuple, set)): + fields = [fields] if fields else [] + + fields.extend(function_objects) + return fields + + def get_query( self, table: str, fields: Union[List, Tuple], @@ -364,15 +497,20 @@ class Query: # Clean up state before each query self.tables = {} criterion = self.build_conditions(table, filters, **kwargs) + fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs) + + join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" if len(self.tables) > 1: primary_table = self.tables[table] del self.tables[table] for table_object in self.tables.values(): - criterion = criterion.left_join(table_object).on(table_object.parent == primary_table.name) + criterion = getattr(criterion, join)(table_object).on( + table_object.parent == primary_table.name + ) if isinstance(fields, (list, tuple)): - query = criterion.select(*kwargs.get("field_objects", fields)) + query = criterion.select(*fields) elif isinstance(fields, Criterion): query = criterion.select(fields) diff --git a/frappe/desk/doctype/number_card/number_card.py b/frappe/desk/doctype/number_card/number_card.py index 74c8e9eb99..8d031aac01 100644 --- a/frappe/desk/doctype/number_card/number_card.py +++ b/frappe/desk/doctype/number_card/number_card.py @@ -204,7 +204,7 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters): if txt: search_conditions = [numberCard[field].like("%{txt}%".format(txt=txt)) for field in searchfields] - condition_query = frappe.db.query.build_conditions(doctype, filters) + condition_query = frappe.qb.engine.build_conditions(doctype, filters) return ( condition_query.select(numberCard.name, numberCard.label, numberCard.document_type) diff --git a/frappe/desk/listview.py b/frappe/desk/listview.py index 5149f8bf86..11c985d1ff 100644 --- a/frappe/desk/listview.py +++ b/frappe/desk/listview.py @@ -37,7 +37,7 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> List[D ToDo = DocType("ToDo") User = DocType("User") count = Count("*").as_("count") - filtered_records = frappe.db.query.build_conditions(doctype, current_filters).select("name") + filtered_records = frappe.qb.engine.build_conditions(doctype, current_filters).select("name") return ( frappe.qb.from_(ToDo) diff --git a/frappe/query_builder/__init__.py b/frappe/query_builder/__init__.py index 1bf9ec97d9..eb1d9df08f 100644 --- a/frappe/query_builder/__init__.py +++ b/frappe/query_builder/__init__.py @@ -7,6 +7,7 @@ from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValue from frappe.query_builder.utils import ( Column, DocType, + get_qb_engine, get_query_builder, patch_query_aggregation, patch_query_execute, diff --git a/frappe/query_builder/builder.py b/frappe/query_builder/builder.py index d2fdeab324..c23d76974c 100644 --- a/frappe/query_builder/builder.py +++ b/frappe/query_builder/builder.py @@ -1,3 +1,5 @@ +import typing + from pypika import MySQLQuery, Order, PostgreSQLQuery, terms from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.queries import QueryBuilder, Schema, Table @@ -13,6 +15,13 @@ class Base: Schema = Schema Table = Table + # Added dynamic type hints for engine attribute + # which is to be assigned later. + if typing.TYPE_CHECKING: + from frappe.database.query import Engine + + engine: Engine + @staticmethod def functions(name: str, *args, **kwargs) -> Function: return Function(name, *args, **kwargs) diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index f03c139f57..d2debd6da1 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -1,8 +1,9 @@ +from enum import Enum + from pypika.functions import * from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function import frappe -from frappe.database.query import Query from frappe.query_builder.custom import GROUP_CONCAT, MATCH, STRING_AGG, TO_TSVECTOR from frappe.query_builder.utils import ImportMapper, db_type_is @@ -14,6 +15,11 @@ class Concat_ws(Function): super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs) +class Locate(Function): + def __init__(self, *terms, **kwargs): + super(Locate, self).__init__("LOCATE", *terms, **kwargs) + + GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG}) Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR}) @@ -73,14 +79,24 @@ class Cast_(Function): def _aggregate(function, dt, fieldname, filters, **kwargs): return ( - Query() - .build_conditions(dt, filters) + frappe.qb.engine.build_conditions(dt, filters) .select(function(PseudoColumn(fieldname))) .run(**kwargs)[0][0] or 0 ) +class SqlFunctions(Enum): + DayOfYear = "dayofyear" + Extract = "extract" + Locate = "locate" + Count = "count" + Sum = "sum" + Avg = "avg" + Max = "max" + Min = "min" + + def _max(dt, fieldname, filters=None, **kwargs): return _aggregate(Max, dt, fieldname, filters, **kwargs) diff --git a/frappe/query_builder/utils.py b/frappe/query_builder/utils.py index 10bab38a63..b601665dee 100644 --- a/frappe/query_builder/utils.py +++ b/frappe/query_builder/utils.py @@ -45,6 +45,12 @@ def get_query_builder(type_of_db: str) -> Union[Postgres, MariaDB]: return picks[db] +def get_qb_engine(): + from frappe.database.query import Engine + + return Engine() + + def get_attr(method_string): modulename = ".".join(method_string.split(".")[:-1]) methodname = method_string.split(".")[-1] diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index c1b2e05266..8727951f4a 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase): ) def test_none_filter(self): - query = frappe.db.query.get_sql("DocType", fields="name", filters={"restrict_to_domain": None}) + query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None}) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 949c3e9433..e7682a0d0c 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -1,14 +1,15 @@ import unittest import frappe +from frappe.query_builder import Field from frappe.tests.test_query_builder import db_type_is, run_only_if -@run_only_if(db_type_is.MARIADB) class TestQuery(unittest.TestCase): + @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( - frappe.db.query.get_sql( + frappe.qb.engine.get_query( "DocType", ["*"], [ @@ -18,3 +19,56 @@ class TestQuery(unittest.TestCase): ).get_sql(), "SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", ) + + def test_string_fields(self): + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields="name, email", filters={"name": "Administrator"} + ).get_sql(), + frappe.qb.from_("User") + .select(Field("name"), Field("email")) + .where(Field("name") == "Administrator") + .get_sql(), + ) + + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields=["name, email"], filters={"name": "Administrator"} + ).get_sql(), + frappe.qb.from_("User") + .select(Field("name"), Field("email")) + .where(Field("name") == "Administrator") + .get_sql(), + ) + + def test_functions_fields(self): + from frappe.query_builder.functions import Count, Max + + self.assertEqual( + frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name"))).get_sql(), + ) + + self.assertEqual( + frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(), + ) + + self.assertEqual( + frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(), + frappe.qb.from_("User").select(Count("*")).get_sql(), + ) + + def test_qb_fields(self): + user_doctype = frappe.qb.DocType("User") + self.assertEqual( + frappe.qb.engine.get_query( + user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} + ).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(), + ) + + self.assertEqual( + frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(), + ) diff --git a/frappe/utils/goal.py b/frappe/utils/goal.py index fb348496da..9273b83cbf 100644 --- a/frappe/utils/goal.py +++ b/frappe/utils/goal.py @@ -25,7 +25,7 @@ def get_monthly_results( date_format = "%m-%Y" if frappe.db.db_type != "postgres" else "MM-YYYY" return dict( - frappe.db.query.build_conditions(table=goal_doctype, filters=filters) + frappe.qb.engine.build_conditions(table=goal_doctype, filters=filters) .select( DateFormat(Table[date_col], date_format).as_("month_year"), Function(aggregation, goal_field),