From 8cf6d42d173e4b7d6d7c2966fbfc2dadc918e72d Mon Sep 17 00:00:00 2001 From: Aradhya Date: Fri, 10 Jun 2022 08:33:03 +0530 Subject: [PATCH 01/10] feat: Trying to replace db_query operations directly with qb --- frappe/model/db_query.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index fe52818235..6ac3402235 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -89,6 +89,9 @@ class DatabaseQuery(object): ignore_ddl=False, parent_doctype=None, ) -> List: + from frappe.database.query import Query + + self.query = Query() if ( not ignore_permissions @@ -112,6 +115,8 @@ class DatabaseQuery(object): # if `filters` is a list of strings, its probably fields filters, fields = fields, filters + self.temp_filters, self.temp_fields = filters, fields + if fields: self.fields = fields else: @@ -158,7 +163,7 @@ class DatabaseQuery(object): if not self.columns: return [] - result = self.build_and_run() + result = self.build_and_run(ignore_permissions=ignore_permissions) if with_comment_count and not as_list and self.doctype: self.add_comment_count(result) @@ -172,7 +177,7 @@ class DatabaseQuery(object): return result - def build_and_run(self): + def build_and_run(self, ignore_permissions): args = self.prepare_args() args.limit = self.add_limit() @@ -197,6 +202,15 @@ class DatabaseQuery(object): %(limit)s""" % args ) + if ignore_permissions: + sql = self.query.get_sql(self.doctype, fields=self.temp_fields, filters=self.temp_filters) + return sql.run( + as_dict=not self.as_list, + debug=self.debug, + update=self.update, + ignore_ddl=self.ignore_ddl, + run=self.run, + ) return frappe.db.sql( query, From d1f5c49b0292c503dc6c1cc1c9c70b2ce7d6a360 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Fri, 10 Jun 2022 11:27:16 +0530 Subject: [PATCH 02/10] refactor(qb-engine): frappe.db.query * feat: Added support for True as filter and pluck in the query engine * feat: Added join support * fix: return if filters are None * feat: Added support for sets as filters and additional conditions * fix: fixed IS operator in query builder * feat: Added support for 'as' in query engine * fix: fixed 'as' for criterion objects passed directly * fix: fixed frappe.db.count * fix: fixed functions in fieldname * feat: Added support for multiple fields passed as a single string :) fixed None fields in a list * feat: Added support for "as" in single string fields * fix: fixed queries with invalid syntax --- frappe/database/query.py | 49 +++++++++++++++++-- .../integrations/doctype/webhook/__init__.py | 2 +- frappe/model/db_query.py | 17 +++++-- frappe/share.py | 12 ++--- .../workflow_action/workflow_action.py | 2 +- 5 files changed, 68 insertions(+), 14 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index f7cc143cf7..22b3074cad 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -93,7 +93,7 @@ def func_between(key: Field, value: Union[List, Tuple]) -> frappe.qb: def func_is(key, value): "Wrapper for IS" - return Field(key).isnotnull() if value.lower() == "set" else Field(key).isnull() + return key.isnotnull() if value.lower() == "set" else key.isnull() def func_timespan(key: Field, value: str) -> frappe.qb: @@ -238,7 +238,7 @@ class Query: Returns: conditions (frappe.qb): frappe.qb object """ - if kwargs.get("orderby"): + if kwargs.get("orderby") and kwargs.get("orderby") != "KEEP_DEFAULT_ORDERING": orderby = kwargs.get("orderby") if isinstance(orderby, str) and len(orderby.split()) > 1: for ordby in orderby.split(","): @@ -250,6 +250,7 @@ class Query: if kwargs.get("limit"): conditions = conditions.limit(kwargs.get("limit")) + conditions = conditions.offset(kwargs.get("offset", 0)) if kwargs.get("distinct"): conditions = conditions.distinct() @@ -257,6 +258,9 @@ class Query: if kwargs.get("for_update"): conditions = conditions.for_update() + if kwargs.get("groupby"): + conditions = conditions.groupby(kwargs.get("groupby")) + return conditions def misc_query(self, table: str, filters: Union[List, Tuple] = None, **kwargs): @@ -308,6 +312,10 @@ class Query: conditions = self.add_conditions(conditions, **kwargs) return conditions + for key, value in filters.items(): + if isinstance(value, bool): + filters.update({key: str(int(value))}) + for key in filters: value = filters.get(key) _operator = self.OPERATOR_MAP["="] @@ -354,6 +362,36 @@ class Query: return criterion + def set_fields(self, fields, **kwargs): + fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" + if isinstance(fields, str) and "," in fields: + fields = fields.split(",") + fields = [field.replace(" ", "") if "as" not in field else field for field in fields] + + if isinstance(fields, str): + if fields == "*": + return fields + if " as " in fields: + fields, reference = fields.split(" as ") + fields = Field(fields).as_(reference) + else: + if issubclass(type(fields), Criterion): + return fields + updated_fields = list() + for field in fields: + if not isinstance(field, Criterion) and field: + if " as " in field: + field, reference = field.split(" as ") + updated_fields.append(Field(field).as_(reference)) + else: + updated_fields.append(Field(field)) + fields = updated_fields + + if not isinstance(fields, (list, tuple, str, Criterion)): + fields = list(fields) + + return fields + def get_sql( self, table: str, @@ -364,12 +402,17 @@ class Query: # Clean up state before each query self.tables = {} criterion = self.build_conditions(table, filters, **kwargs) + fields = self.set_fields(fields, **kwargs) + + join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" if len(self.tables) > 1: primary_table = self.tables[table] del self.tables[table] for table_object in self.tables.values(): - criterion = criterion.left_join(table_object).on(table_object.parent == primary_table.name) + criterion = getattr(criterion, join)(table_object).on( + table_object.parent == primary_table.name + ) if isinstance(fields, (list, tuple)): query = criterion.select(*kwargs.get("field_objects", fields)) diff --git a/frappe/integrations/doctype/webhook/__init__.py b/frappe/integrations/doctype/webhook/__init__.py index 915d2819ee..077b39101e 100644 --- a/frappe/integrations/doctype/webhook/__init__.py +++ b/frappe/integrations/doctype/webhook/__init__.py @@ -25,7 +25,7 @@ def run_webhooks(doc, method): # query webhooks webhooks_list = frappe.get_all( "Webhook", - fields=["name", "`condition`", "webhook_docevent", "webhook_doctype"], + fields=["name", "condition", "webhook_docevent", "webhook_doctype"], filters={"enabled": True}, ) diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index 6ac3402235..49b324b794 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -163,7 +163,7 @@ class DatabaseQuery(object): if not self.columns: return [] - result = self.build_and_run(ignore_permissions=ignore_permissions) + result = self.build_and_run(ignore_permissions=ignore_permissions, pluck=pluck) if with_comment_count and not as_list and self.doctype: self.add_comment_count(result) @@ -177,7 +177,7 @@ class DatabaseQuery(object): return result - def build_and_run(self, ignore_permissions): + def build_and_run(self, ignore_permissions, pluck): args = self.prepare_args() args.limit = self.add_limit() @@ -203,7 +203,18 @@ class DatabaseQuery(object): % args ) if ignore_permissions: - sql = self.query.get_sql(self.doctype, fields=self.temp_fields, filters=self.temp_filters) + sql = self.query.get_sql( + self.doctype, + fields=self.temp_fields, + filters=self.temp_filters, + pluck=pluck, + join=self.join, + orderby=self.order_by, + groupby=self.group_by, + distinct=self.distinct, + limit=self.limit_page_length, + offset=self.limit_start, + ) return sql.run( as_dict=not self.as_list, debug=self.debug, diff --git a/frappe/share.py b/frappe/share.py index 3edcb1be38..2933017048 100644 --- a/frappe/share.py +++ b/frappe/share.py @@ -104,12 +104,12 @@ def get_users(doctype, name): return frappe.db.get_all( "DocShare", fields=[ - "`name`", - "`user`", - "`read`", - "`write`", - "`submit`", - "`share`", + "name", + "user", + "read", + "write", + "submit", + "share", "everyone", "owner", "creation", diff --git a/frappe/workflow/doctype/workflow_action/workflow_action.py b/frappe/workflow/doctype/workflow_action/workflow_action.py index 038a3021d2..545ad6ec77 100644 --- a/frappe/workflow/doctype/workflow_action/workflow_action.py +++ b/frappe/workflow/doctype/workflow_action/workflow_action.py @@ -290,7 +290,7 @@ def update_completed_workflow_actions_using_user(doc, user=None): def get_next_possible_transitions(workflow_name, state, doc=None): transitions = frappe.get_all( "Workflow Transition", - fields=["allowed", "action", "state", "allow_self_approval", "next_state", "`condition`"], + fields=["allowed", "action", "state", "allow_self_approval", "next_state", "condition"], filters=[["parent", "=", workflow_name], ["state", "=", state]], ) From d0680941adacb6ebaa9bc9334be784566cf97a82 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Mon, 13 Jun 2022 16:25:47 +0530 Subject: [PATCH 03/10] refactor: frappe.qb.engine * feat: supporting empty iterables for Contains objects * fix: explicitly setting empty iterables as tuples to support more operators * feat: Added locate to frappe.qb Functions * feat: Added support for functions passed as strings in fields * feat: Included Criterion objects as fields * fix: picking up only function intended fields to pass to get_function_objects * feat: Added iterable for available functions, added support for Field objects * fix: fixed * passed in fields in lists --- frappe/database/query.py | 90 +++++++++++++++++++++++++++---- frappe/model/db_query.py | 25 +-------- frappe/query_builder/functions.py | 16 ++++++ 3 files changed, 99 insertions(+), 32 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index 22b3074cad..b9de427c60 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,5 +1,6 @@ import operator import re +from ast import literal_eval from functools import cached_property from typing import Any, Callable, Dict, List, Tuple, Union @@ -325,7 +326,8 @@ class Query: continue if isinstance(value, (list, tuple)): _operator = self.OPERATOR_MAP[value[0].casefold()] - conditions = conditions.where(_operator(Field(key), value[1])) + _value = value[1] if value[1] else ("",) + conditions = conditions.where(_operator(Field(key), _value)) else: if value is not None: conditions = conditions.where(_operator(Field(key), value)) @@ -364,32 +366,102 @@ class Query: def set_fields(self, fields, **kwargs): fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" - if isinstance(fields, str) and "," in fields: + if isinstance(fields, list) and None in fields and Field not in fields: + return None + + is_list = isinstance(fields, (list, tuple, set)) + is_str = isinstance(fields, str) + + def add_functions(fields): + from frappe.query_builder.functions import SqlFunctions + + sql_functions = [sql_function.value for sql_function in SqlFunctions] + + def get_function_objects(fields): + from frappe.query_builder import functions + + def literal_eval_(literal): + try: + return literal_eval(literal) + except (ValueError, SyntaxError): + return literal + + func = fields.split("(")[0].casefold().split() + func = [f for f in func if f in sql_functions][0] + args = fields[len(func) + 1 : fields.index(")")].split(",") + args = [literal_eval_(arg.strip()) for arg in args] + return getattr(functions, func.capitalize())(*args) + + if is_str and any( + [func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions] + ): + function_objects = [] + return function_objects or [get_function_objects(fields)] + else: + functions = [] + for field in fields: + if not issubclass(type(field), Criterion): + if any( + [func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions] + ): + functions.append(field.casefold()) + return [get_function_objects(function) for function in functions] + + function_objects = ( + add_functions(fields=fields) if not issubclass(type(fields), Criterion) else [] + ) + for function in function_objects: + if is_str: + fields = re.sub( + r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "") + ) + + else: + updated_fields = [] + for field in fields: + if isinstance(field, str): + updated_fields.append( + re.sub(r"\(.*?\)", "", field) + .strip() + .casefold() + .replace(str(type(function).__name__).strip().casefold(), "") + ) + else: + updated_fields.append(field) + + fields = updated_fields + + if is_str and "," in fields: fields = fields.split(",") fields = [field.replace(" ", "") if "as" not in field else field for field in fields] - if isinstance(fields, str): + if is_str: if fields == "*": return fields if " as " in fields: fields, reference = fields.split(" as ") fields = Field(fields).as_(reference) - else: + + if not is_str and fields: if issubclass(type(fields), Criterion): return fields - updated_fields = list() + updated_fields = [] + if "*" in fields: + return fields for field in fields: if not isinstance(field, Criterion) and field: if " as " in field: field, reference = field.split(" as ") - updated_fields.append(Field(field).as_(reference)) + updated_fields.append(Field(field.strip()).as_(reference)) else: updated_fields.append(Field(field)) - fields = updated_fields - if not isinstance(fields, (list, tuple, str, Criterion)): - fields = list(fields) + fields = updated_fields + if not is_list: + fields = [fields] if fields else [] + + fields.extend(function_objects) return fields def get_sql( diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index 49b324b794..b9c757bfa6 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -163,7 +163,7 @@ class DatabaseQuery(object): if not self.columns: return [] - result = self.build_and_run(ignore_permissions=ignore_permissions, pluck=pluck) + result = self.build_and_run() if with_comment_count and not as_list and self.doctype: self.add_comment_count(result) @@ -177,7 +177,7 @@ class DatabaseQuery(object): return result - def build_and_run(self, ignore_permissions, pluck): + def build_and_run(self): args = self.prepare_args() args.limit = self.add_limit() @@ -202,27 +202,6 @@ class DatabaseQuery(object): %(limit)s""" % args ) - if ignore_permissions: - sql = self.query.get_sql( - self.doctype, - fields=self.temp_fields, - filters=self.temp_filters, - pluck=pluck, - join=self.join, - orderby=self.order_by, - groupby=self.group_by, - distinct=self.distinct, - limit=self.limit_page_length, - offset=self.limit_start, - ) - return sql.run( - as_dict=not self.as_list, - debug=self.debug, - update=self.update, - ignore_ddl=self.ignore_ddl, - run=self.run, - ) - return frappe.db.sql( query, as_dict=not self.as_list, diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index f03c139f57..7b260098d1 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -1,3 +1,5 @@ +from enum import Enum + from pypika.functions import * from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function @@ -14,6 +16,11 @@ class Concat_ws(Function): super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs) +class Locate(Function): + def __init__(self, *terms, **kwargs): + super(Locate, self).__init__("LOCATE", *terms, **kwargs) + + GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG}) Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR}) @@ -81,6 +88,15 @@ def _aggregate(function, dt, fieldname, filters, **kwargs): ) +class SqlFunctions(Enum): + DayOfYear = "dayofyear" + Extract = "extract" + Locate = "locate" + Count = "count" + Sum = "sum" + Avg = "avg" + + def _max(dt, fieldname, filters=None, **kwargs): return _aggregate(Max, dt, fieldname, filters, **kwargs) From fca026927e7fcbe1e7bd88f01bd742d8bbfef991 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Wed, 15 Jun 2022 12:34:02 +0530 Subject: [PATCH 04/10] refactor: moved all query logic to query class --- frappe/database/database.py | 7 ------- frappe/database/query.py | 10 +++++++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index 1de22af037..d88d3417d6 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -723,13 +723,6 @@ class Database(object): ): field_objects = [] - if not isinstance(fields, Criterion): - for field in fields: - if "(" in str(field) or " as " in str(field): - field_objects.append(PseudoColumn(field)) - else: - field_objects.append(field) - query = self.query.get_sql( table=doctype, filters=filters, diff --git a/frappe/database/query.py b/frappe/database/query.py index b9de427c60..c8e1ddabc6 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -370,6 +370,10 @@ class Query: return None is_list = isinstance(fields, (list, tuple, set)) + if is_list and len(fields) == 1: + fields = fields[0] + is_list = False + is_str = isinstance(fields, str) def add_functions(fields): @@ -389,7 +393,7 @@ class Query: func = fields.split("(")[0].casefold().split() func = [f for f in func if f in sql_functions][0] args = fields[len(func) + 1 : fields.index(")")].split(",") - args = [literal_eval_(arg.strip()) for arg in args] + args = [Field(literal_eval_((arg.strip()))) for arg in args] return getattr(functions, func.capitalize())(*args) if is_str and any( @@ -474,7 +478,7 @@ class Query: # Clean up state before each query self.tables = {} criterion = self.build_conditions(table, filters, **kwargs) - fields = self.set_fields(fields, **kwargs) + fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs) join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" @@ -487,7 +491,7 @@ class Query: ) if isinstance(fields, (list, tuple)): - query = criterion.select(*kwargs.get("field_objects", fields)) + query = criterion.select(*fields) elif isinstance(fields, Criterion): query = criterion.select(fields) From 960952cfc337f2764bd52d20458cf3f5c2b3bbba Mon Sep 17 00:00:00 2001 From: Aradhya Date: Wed, 15 Jun 2022 13:06:56 +0530 Subject: [PATCH 05/10] feat(qb-engine): Added Aggregation function support * Added Min, Max * Added tests --- frappe/database/query.py | 3 +- .../integrations/doctype/webhook/__init__.py | 2 +- frappe/model/db_query.py | 6 +- frappe/query_builder/functions.py | 2 + frappe/share.py | 12 ++-- frappe/tests/test_query.py | 55 ++++++++++++++++++- .../workflow_action/workflow_action.py | 2 +- 7 files changed, 67 insertions(+), 15 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index c8e1ddabc6..6295c4ec3e 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -462,7 +462,8 @@ class Query: fields = updated_fields - if not is_list: + # Need to check instance again since fields modified. + if not isinstance(fields, (list, tuple, set)): fields = [fields] if fields else [] fields.extend(function_objects) diff --git a/frappe/integrations/doctype/webhook/__init__.py b/frappe/integrations/doctype/webhook/__init__.py index 077b39101e..915d2819ee 100644 --- a/frappe/integrations/doctype/webhook/__init__.py +++ b/frappe/integrations/doctype/webhook/__init__.py @@ -25,7 +25,7 @@ def run_webhooks(doc, method): # query webhooks webhooks_list = frappe.get_all( "Webhook", - fields=["name", "condition", "webhook_docevent", "webhook_doctype"], + fields=["name", "`condition`", "webhook_docevent", "webhook_doctype"], filters={"enabled": True}, ) diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index b9c757bfa6..fe52818235 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -89,9 +89,6 @@ class DatabaseQuery(object): ignore_ddl=False, parent_doctype=None, ) -> List: - from frappe.database.query import Query - - self.query = Query() if ( not ignore_permissions @@ -115,8 +112,6 @@ class DatabaseQuery(object): # if `filters` is a list of strings, its probably fields filters, fields = fields, filters - self.temp_filters, self.temp_fields = filters, fields - if fields: self.fields = fields else: @@ -202,6 +197,7 @@ class DatabaseQuery(object): %(limit)s""" % args ) + return frappe.db.sql( query, as_dict=not self.as_list, diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index 7b260098d1..b55eff4298 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -95,6 +95,8 @@ class SqlFunctions(Enum): Count = "count" Sum = "sum" Avg = "avg" + Max = "max" + Min = "min" def _max(dt, fieldname, filters=None, **kwargs): diff --git a/frappe/share.py b/frappe/share.py index 2933017048..3edcb1be38 100644 --- a/frappe/share.py +++ b/frappe/share.py @@ -104,12 +104,12 @@ def get_users(doctype, name): return frappe.db.get_all( "DocShare", fields=[ - "name", - "user", - "read", - "write", - "submit", - "share", + "`name`", + "`user`", + "`read`", + "`write`", + "`submit`", + "`share`", "everyone", "owner", "creation", diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 949c3e9433..d7190fe7ae 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -1,11 +1,12 @@ import unittest import frappe +from frappe.query_builder import Field from frappe.tests.test_query_builder import db_type_is, run_only_if -@run_only_if(db_type_is.MARIADB) class TestQuery(unittest.TestCase): + @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( frappe.db.query.get_sql( @@ -18,3 +19,55 @@ class TestQuery(unittest.TestCase): ).get_sql(), "SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'", ) + + def test_string_fields(self): + self.assertEqual( + frappe.db.query.get_sql("User", fields="name, email", filters={"name": "Administrator"}), + frappe.qb.from_("User") + .select(Field("name"), Field("email")) + .where(Field("name") == "Administrator"), + ) + + self.assertEqual( + frappe.db.query.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}), + frappe.qb.from_("User") + .select(Field("name"), Field("email")) + .where(Field("name") == "Administrator"), + ) + + def test_functions_fields(self): + from frappe.query_builder.functions import Count + + self.assertEqual( + frappe.db.query.get_sql("User", fields="Count(name)", filters={}), + frappe.qb.from_("User").select(Count(Field("name"))), + ) + + self.assertEqual( + frappe.db.query.get_sql("User", fields="Count(name), Max(name)", filters={}), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + ) + + self.assertEqual( + frappe.db.query.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + ) + + self.assertEqual( + frappe.db.query.get_sql("User", fields=[Count("*")], filters={}), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + ) + + def test_qb_fields(self): + user_doctype = frappe.qb.DocType("User") + self.assertEqual( + frappe.db.query.get_sql( + user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} + ), + frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email), + ) + + self.assertEqual( + frappe.db.query.get_sql(user_doctype, fields=user_doctype.email, filters={}), + frappe.qb.from_(user_doctype).select(user_doctype.email), + ) diff --git a/frappe/workflow/doctype/workflow_action/workflow_action.py b/frappe/workflow/doctype/workflow_action/workflow_action.py index 545ad6ec77..038a3021d2 100644 --- a/frappe/workflow/doctype/workflow_action/workflow_action.py +++ b/frappe/workflow/doctype/workflow_action/workflow_action.py @@ -290,7 +290,7 @@ def update_completed_workflow_actions_using_user(doc, user=None): def get_next_possible_transitions(workflow_name, state, doc=None): transitions = frappe.get_all( "Workflow Transition", - fields=["allowed", "action", "state", "allow_self_approval", "next_state", "condition"], + fields=["allowed", "action", "state", "allow_self_approval", "next_state", "`condition`"], filters=[["parent", "=", workflow_name], ["state", "=", state]], ) From 7732accded4ca57ac78ec3c9f81f9069b3250947 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Tue, 21 Jun 2022 16:04:09 +0530 Subject: [PATCH 06/10] feat: Attached Engine object to qb & added dynamic type hints --- frappe/__init__.py | 9 +++++++-- frappe/database/database.py | 27 +++++++++------------------ frappe/database/query.py | 2 +- frappe/query_builder/__init__.py | 1 + frappe/query_builder/builder.py | 9 +++++++++ frappe/query_builder/functions.py | 4 +--- frappe/query_builder/utils.py | 6 ++++++ 7 files changed, 34 insertions(+), 24 deletions(-) diff --git a/frappe/__init__.py b/frappe/__init__.py index 51c6ba3a74..d36a13b3a3 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -21,7 +21,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import click from werkzeug.local import Local, release_local -from frappe.query_builder import get_query_builder, patch_query_aggregation, patch_query_execute +from frappe.query_builder import ( + get_qb_engine, + get_query_builder, + patch_query_aggregation, + patch_query_execute, +) from frappe.utils.caching import request_cache from frappe.utils.data import cstr, sbool @@ -234,7 +239,7 @@ def init(site, sites_path=None, new_site=False): local.session = _dict() local.dev_server = _dev_server local.qb = get_query_builder(local.conf.db_type or "mariadb") - + local.qb.engine = get_qb_engine() setup_module_map() patch_query_execute() patch_query_aggregation() diff --git a/frappe/database/database.py b/frappe/database/database.py index d88d3417d6..fd94bcfe12 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from time import time from typing import Dict, List, Optional, Tuple, Union -from pypika.terms import Criterion, NullValue, PseudoColumn +from pypika.terms import Criterion, NullValue import frappe import frappe.defaults @@ -69,15 +69,6 @@ class Database(object): self.password = password or frappe.conf.db_password self.value_cache = {} - @property - def query(self): - if not hasattr(self, "_query"): - from .query import Query - - self._query = Query() - del Query - return self._query - def setup_type_map(self): pass @@ -595,7 +586,7 @@ class Database(object): return [map(values.get, fields)] else: - r = self.query.get_sql( + r = frappe.qb.engine.get_sql( "Singles", filters={"field": ("in", tuple(fields)), "doctype": doctype}, fields=["field", "value"], @@ -625,14 +616,14 @@ class Database(object): # Get coulmn and value of the single doctype Accounts Settings account_settings = frappe.db.get_singles_dict("Accounts Settings") """ - result = self.query.get_sql( + queried_result = frappe.qb.engine.get_sql( "Singles", filters={"doctype": doctype}, fields=["field", "value"], for_update=for_update, ).run() - return frappe._dict(result) + return frappe._dict(queried_result) @staticmethod def get_all(*args, **kwargs): @@ -681,7 +672,7 @@ class Database(object): if cache and fieldname in self.value_cache[doctype]: return self.value_cache[doctype][fieldname] - val = self.query.get_sql( + val = frappe.qb.engine.get_sql( table="Singles", filters={"doctype": doctype, "field": fieldname}, fields="value", @@ -723,7 +714,7 @@ class Database(object): ): field_objects = [] - query = self.query.get_sql( + query = frappe.qb.engine.get_sql( table=doctype, filters=filters, orderby=order_by, @@ -833,7 +824,7 @@ class Database(object): frappe.clear_document_cache(dt, docname) else: - query = self.query.build_conditions(table=dt, filters=dn, update=True) + query = frappe.qb.engine.build_conditions(table=dt, filters=dn, update=True) # TODO: Fix this; doesn't work rn - gavin@frappe.io # frappe.cache().hdel_keys(dt, "document_cache") # Workaround: clear all document caches @@ -1034,7 +1025,7 @@ class Database(object): cache_count = frappe.cache().get_value("doctype:count:{}".format(dt)) if cache_count is not None: return cache_count - query = self.query.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct) + query = frappe.qb.engine.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct) count = self.sql(query, debug=debug)[0][0] if not filters and cache: frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400) @@ -1174,7 +1165,7 @@ class Database(object): Doctype name can be passed directly, it will be pre-pended with `tab`. """ filters = filters or kwargs.get("conditions") - query = self.query.build_conditions(table=doctype, filters=filters).delete() + query = frappe.qb.engine.build_conditions(table=doctype, filters=filters).delete() if "debug" not in kwargs: kwargs["debug"] = debug return query.run(**kwargs) diff --git a/frappe/database/query.py b/frappe/database/query.py index 6295c4ec3e..abf5c6936f 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -169,7 +169,7 @@ OPERATOR_MAP: Dict[str, Callable] = { } -class Query: +class Engine: tables: dict = {} @cached_property diff --git a/frappe/query_builder/__init__.py b/frappe/query_builder/__init__.py index 1bf9ec97d9..eb1d9df08f 100644 --- a/frappe/query_builder/__init__.py +++ b/frappe/query_builder/__init__.py @@ -7,6 +7,7 @@ from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValue from frappe.query_builder.utils import ( Column, DocType, + get_qb_engine, get_query_builder, patch_query_aggregation, patch_query_execute, diff --git a/frappe/query_builder/builder.py b/frappe/query_builder/builder.py index d2fdeab324..c23d76974c 100644 --- a/frappe/query_builder/builder.py +++ b/frappe/query_builder/builder.py @@ -1,3 +1,5 @@ +import typing + from pypika import MySQLQuery, Order, PostgreSQLQuery, terms from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.queries import QueryBuilder, Schema, Table @@ -13,6 +15,13 @@ class Base: Schema = Schema Table = Table + # Added dynamic type hints for engine attribute + # which is to be assigned later. + if typing.TYPE_CHECKING: + from frappe.database.query import Engine + + engine: Engine + @staticmethod def functions(name: str, *args, **kwargs) -> Function: return Function(name, *args, **kwargs) diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index b55eff4298..d2debd6da1 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -4,7 +4,6 @@ from pypika.functions import * from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function import frappe -from frappe.database.query import Query from frappe.query_builder.custom import GROUP_CONCAT, MATCH, STRING_AGG, TO_TSVECTOR from frappe.query_builder.utils import ImportMapper, db_type_is @@ -80,8 +79,7 @@ class Cast_(Function): def _aggregate(function, dt, fieldname, filters, **kwargs): return ( - Query() - .build_conditions(dt, filters) + frappe.qb.engine.build_conditions(dt, filters) .select(function(PseudoColumn(fieldname))) .run(**kwargs)[0][0] or 0 diff --git a/frappe/query_builder/utils.py b/frappe/query_builder/utils.py index 69aee9b350..573173ef03 100644 --- a/frappe/query_builder/utils.py +++ b/frappe/query_builder/utils.py @@ -45,6 +45,12 @@ def get_query_builder(type_of_db: str) -> Union[Postgres, MariaDB]: return picks[db] +def get_qb_engine(): + from frappe.database.query import Engine + + return Engine() + + def get_attr(method_string): modulename = ".".join(method_string.split(".")[:-1]) methodname = method_string.split(".")[-1] From 4af2e1e88605b54ab0a737c15c9ae75aa6f0b698 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Tue, 21 Jun 2022 16:39:15 +0530 Subject: [PATCH 07/10] refactor: replaced frappe.db.query with frappe.qb.engine --- frappe/desk/doctype/number_card/number_card.py | 2 +- frappe/desk/listview.py | 2 +- frappe/tests/test_db_query.py | 2 +- frappe/tests/test_query.py | 18 +++++++++--------- frappe/utils/goal.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/frappe/desk/doctype/number_card/number_card.py b/frappe/desk/doctype/number_card/number_card.py index d6d4f00b69..a9ef31cb2c 100644 --- a/frappe/desk/doctype/number_card/number_card.py +++ b/frappe/desk/doctype/number_card/number_card.py @@ -197,7 +197,7 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters): if txt: search_conditions = [numberCard[field].like("%{txt}%".format(txt=txt)) for field in searchfields] - condition_query = frappe.db.query.build_conditions(doctype, filters) + condition_query = frappe.qb.engine.build_conditions(doctype, filters) return ( condition_query.select(numberCard.name, numberCard.label, numberCard.document_type) diff --git a/frappe/desk/listview.py b/frappe/desk/listview.py index 5149f8bf86..11c985d1ff 100644 --- a/frappe/desk/listview.py +++ b/frappe/desk/listview.py @@ -37,7 +37,7 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> List[D ToDo = DocType("ToDo") User = DocType("User") count = Count("*").as_("count") - filtered_records = frappe.db.query.build_conditions(doctype, current_filters).select("name") + filtered_records = frappe.qb.engine.build_conditions(doctype, current_filters).select("name") return ( frappe.qb.from_(ToDo) diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index c1b2e05266..3a4e5b72bd 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase): ) def test_none_filter(self): - query = frappe.db.query.get_sql("DocType", fields="name", filters={"restrict_to_domain": None}) + query = frappe.qb.engine.get_sql("DocType", fields="name", filters={"restrict_to_domain": None}) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index d7190fe7ae..dd74d5ad18 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -9,7 +9,7 @@ class TestQuery(unittest.TestCase): @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( - frappe.db.query.get_sql( + frappe.qb.engine.get_sql( "DocType", ["*"], [ @@ -22,14 +22,14 @@ class TestQuery(unittest.TestCase): def test_string_fields(self): self.assertEqual( - frappe.db.query.get_sql("User", fields="name, email", filters={"name": "Administrator"}), + frappe.qb.engine.get_sql("User", fields="name, email", filters={"name": "Administrator"}), frappe.qb.from_("User") .select(Field("name"), Field("email")) .where(Field("name") == "Administrator"), ) self.assertEqual( - frappe.db.query.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}), + frappe.qb.engine.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}), frappe.qb.from_("User") .select(Field("name"), Field("email")) .where(Field("name") == "Administrator"), @@ -39,35 +39,35 @@ class TestQuery(unittest.TestCase): from frappe.query_builder.functions import Count self.assertEqual( - frappe.db.query.get_sql("User", fields="Count(name)", filters={}), + frappe.qb.engine.get_sql("User", fields="Count(name)", filters={}), frappe.qb.from_("User").select(Count(Field("name"))), ) self.assertEqual( - frappe.db.query.get_sql("User", fields="Count(name), Max(name)", filters={}), + frappe.qb.engine.get_sql("User", fields="Count(name), Max(name)", filters={}), frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), ) self.assertEqual( - frappe.db.query.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}), + frappe.qb.engine.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}), frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), ) self.assertEqual( - frappe.db.query.get_sql("User", fields=[Count("*")], filters={}), + frappe.qb.engine.get_sql("User", fields=[Count("*")], filters={}), frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), ) def test_qb_fields(self): user_doctype = frappe.qb.DocType("User") self.assertEqual( - frappe.db.query.get_sql( + frappe.qb.engine.get_sql( user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} ), frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email), ) self.assertEqual( - frappe.db.query.get_sql(user_doctype, fields=user_doctype.email, filters={}), + frappe.qb.engine.get_sql(user_doctype, fields=user_doctype.email, filters={}), frappe.qb.from_(user_doctype).select(user_doctype.email), ) diff --git a/frappe/utils/goal.py b/frappe/utils/goal.py index fb348496da..9273b83cbf 100644 --- a/frappe/utils/goal.py +++ b/frappe/utils/goal.py @@ -25,7 +25,7 @@ def get_monthly_results( date_format = "%m-%Y" if frappe.db.db_type != "postgres" else "MM-YYYY" return dict( - frappe.db.query.build_conditions(table=goal_doctype, filters=filters) + frappe.qb.engine.build_conditions(table=goal_doctype, filters=filters) .select( DateFormat(Table[date_col], date_format).as_("month_year"), Function(aggregation, goal_field), From 6db6be1f3c8b8dc9e8e458c6191e5e91e8de8943 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Tue, 21 Jun 2022 18:42:06 +0530 Subject: [PATCH 08/10] refactor: frappe.qb.engine * Small fixes in set_fields and clean code * Optimize casefolds * Fixed functions passed in List * get_sql => get_query - more expressive, less confusion * Updated tests --- frappe/database/database.py | 10 ++-- frappe/database/query.py | 94 +++++++++++++++++++---------------- frappe/tests/test_db_query.py | 2 +- frappe/tests/test_query.py | 43 ++++++++-------- 4 files changed, 79 insertions(+), 70 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index fd94bcfe12..b418b5e9f6 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -586,7 +586,7 @@ class Database(object): return [map(values.get, fields)] else: - r = frappe.qb.engine.get_sql( + r = frappe.qb.engine.get_query( "Singles", filters={"field": ("in", tuple(fields)), "doctype": doctype}, fields=["field", "value"], @@ -616,7 +616,7 @@ class Database(object): # Get coulmn and value of the single doctype Accounts Settings account_settings = frappe.db.get_singles_dict("Accounts Settings") """ - queried_result = frappe.qb.engine.get_sql( + queried_result = frappe.qb.engine.get_query( "Singles", filters={"doctype": doctype}, fields=["field", "value"], @@ -672,7 +672,7 @@ class Database(object): if cache and fieldname in self.value_cache[doctype]: return self.value_cache[doctype][fieldname] - val = frappe.qb.engine.get_sql( + val = frappe.qb.engine.get_query( table="Singles", filters={"doctype": doctype, "field": fieldname}, fields="value", @@ -714,7 +714,7 @@ class Database(object): ): field_objects = [] - query = frappe.qb.engine.get_sql( + query = frappe.qb.engine.get_query( table=doctype, filters=filters, orderby=order_by, @@ -1025,7 +1025,7 @@ class Database(object): cache_count = frappe.cache().get_value("doctype:count:{}".format(dt)) if cache_count is not None: return cache_count - query = frappe.qb.engine.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct) + query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct) count = self.sql(query, debug=debug)[0][0] if not filters and cache: frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400) diff --git a/frappe/database/query.py b/frappe/database/query.py index abf5c6936f..155cc99aab 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -8,10 +8,13 @@ import frappe from frappe import _ from frappe.boot import get_additional_filters_from_hooks from frappe.model.db_query import get_timespan_date_range -from frappe.query_builder import Criterion, Field, Order, Table +from frappe.query_builder import Criterion, Field, Order, Table, functions +from frappe.query_builder.functions import SqlFunctions TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") +BRACKETS_PATTERN = re.compile(r"\(.*?\)") +SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] def like(key: Field, value: str) -> frappe.qb: @@ -144,6 +147,13 @@ def change_orderby(order: str): return order[0], Order.desc +def literal_eval_(literal): + try: + return literal_eval(literal) + except (ValueError, SyntaxError): + return literal + + # default operators OPERATOR_MAP: Dict[str, Callable] = { "+": operator.add, @@ -364,6 +374,34 @@ class Engine: return criterion + def get_function_objects(self, fields): + func = fields.split("(")[0].casefold().split() + func = [f for f in func if f in SQL_FUNCTIONS][0] + args = fields[len(func) + 1 : fields.index(")")].split(",") + args = [ + Field(literal_eval_((arg.strip()))) if "*" not in args else literal_eval_((arg.strip())) + for arg in args + ] + return getattr(functions, func.capitalize())(*args) + + def function_objects_to_fields(self, fields, is_str: bool): + if is_str: + functions = "" + for func in SQL_FUNCTIONS: + if f"{func}(" in fields: + functions = str(func) + str(BRACKETS_PATTERN.findall(fields)[0]) + return [self.get_function_objects(functions)] + if not functions: + return [] + else: + functions = [] + for field in fields: + field = field.casefold() if isinstance(field, str) else field + if not issubclass(type(field), Criterion): + if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]): + functions.append(field) + return [self.get_function_objects(function) for function in functions] + def set_fields(self, fields, **kwargs): fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" if isinstance(fields, list) and None in fields and Field not in fields: @@ -375,57 +413,26 @@ class Engine: is_list = False is_str = isinstance(fields, str) - - def add_functions(fields): - from frappe.query_builder.functions import SqlFunctions - - sql_functions = [sql_function.value for sql_function in SqlFunctions] - - def get_function_objects(fields): - from frappe.query_builder import functions - - def literal_eval_(literal): - try: - return literal_eval(literal) - except (ValueError, SyntaxError): - return literal - - func = fields.split("(")[0].casefold().split() - func = [f for f in func if f in sql_functions][0] - args = fields[len(func) + 1 : fields.index(")")].split(",") - args = [Field(literal_eval_((arg.strip()))) for arg in args] - return getattr(functions, func.capitalize())(*args) - - if is_str and any( - [func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions] - ): - function_objects = [] - return function_objects or [get_function_objects(fields)] - else: - functions = [] - for field in fields: - if not issubclass(type(field), Criterion): - if any( - [func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions] - ): - functions.append(field.casefold()) - return [get_function_objects(function) for function in functions] + if is_str: + fields = fields.casefold() function_objects = ( - add_functions(fields=fields) if not issubclass(type(fields), Criterion) else [] + self.function_objects_to_fields(fields=fields, is_str=is_str) + if not issubclass(type(fields), Criterion) + else [] ) + for function in function_objects: if is_str: - fields = re.sub( - r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "") + fields = BRACKETS_PATTERN.sub( + "", fields.replace(str(type(function).__name__).strip().casefold(), "") ) - else: updated_fields = [] for field in fields: if isinstance(field, str): updated_fields.append( - re.sub(r"\(.*?\)", "", field) + BRACKETS_PATTERN.sub("", field) .strip() .casefold() .replace(str(type(function).__name__).strip().casefold(), "") @@ -433,11 +440,12 @@ class Engine: else: updated_fields.append(field) - fields = updated_fields + fields = [field for field in updated_fields if field] if is_str and "," in fields: fields = fields.split(",") fields = [field.replace(" ", "") if "as" not in field else field for field in fields] + is_list, is_str = True, False if is_str: if fields == "*": @@ -469,7 +477,7 @@ class Engine: fields.extend(function_objects) return fields - def get_sql( + def get_query( self, table: str, fields: Union[List, Tuple], diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index 3a4e5b72bd..8727951f4a 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase): ) def test_none_filter(self): - query = frappe.qb.engine.get_sql("DocType", fields="name", filters={"restrict_to_domain": None}) + query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None}) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index dd74d5ad18..06550b44a4 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -9,7 +9,7 @@ class TestQuery(unittest.TestCase): @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( - frappe.qb.engine.get_sql( + frappe.qb.engine.get_query( "DocType", ["*"], [ @@ -22,52 +22,53 @@ class TestQuery(unittest.TestCase): def test_string_fields(self): self.assertEqual( - frappe.qb.engine.get_sql("User", fields="name, email", filters={"name": "Administrator"}), + frappe.qb.engine.get_query( + "User", fields="name, email", filters={"name": "Administrator"} + ).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) - .where(Field("name") == "Administrator"), + .where(Field("name") == "Administrator") + .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}), + frappe.qb.engine.get_query( + "User", fields=["name, email"], filters={"name": "Administrator"} + ).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) - .where(Field("name") == "Administrator"), + .where(Field("name") == "Administrator") + .get_sql(), ) def test_functions_fields(self): from frappe.query_builder.functions import Count self.assertEqual( - frappe.qb.engine.get_sql("User", fields="Count(name)", filters={}), - frappe.qb.from_("User").select(Count(Field("name"))), + frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields="Count(name), Max(name)", filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), - ) - - self.assertEqual( - frappe.qb.engine.get_sql("User", fields=[Count("*")], filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(), + frappe.qb.from_("User").select(Count("*")).get_sql(), ) def test_qb_fields(self): user_doctype = frappe.qb.DocType("User") self.assertEqual( - frappe.qb.engine.get_sql( + frappe.qb.engine.get_query( user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} - ), - frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email), + ).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql(user_doctype, fields=user_doctype.email, filters={}), - frappe.qb.from_(user_doctype).select(user_doctype.email), + frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(), ) From 303d94494d24d88f8957afb66217ac171a251771 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Wed, 22 Jun 2022 16:58:27 +0530 Subject: [PATCH 09/10] refactor: atomic functions & removed complicated checks Co-authored-by: gavin --- frappe/database/query.py | 118 +++++++++++++++++++++------------------ 1 file changed, 64 insertions(+), 54 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index 155cc99aab..60eeaee8c2 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -2,7 +2,7 @@ import operator import re from ast import literal_eval from functools import cached_property -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union import frappe from frappe import _ @@ -13,9 +13,12 @@ from frappe.query_builder.functions import SqlFunctions TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") -BRACKETS_PATTERN = re.compile(r"\(.*?\)") +BRACKETS_PATTERN = re.compile(r"\(.*?\)|$") SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] +if TYPE_CHECKING: + from pypika.functions import Function + def like(key: Field, value: str) -> frappe.qb: """Wrapper method for `LIKE` @@ -374,77 +377,84 @@ class Engine: return criterion - def get_function_objects(self, fields): - func = fields.split("(")[0].casefold().split() - func = [f for f in func if f in SQL_FUNCTIONS][0] - args = fields[len(func) + 1 : fields.index(")")].split(",") - args = [ - Field(literal_eval_((arg.strip()))) if "*" not in args else literal_eval_((arg.strip())) - for arg in args - ] - return getattr(functions, func.capitalize())(*args) + def get_function_object(self, field: str) -> "Function": + """Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object""" + func = field.split("(", maxsplit=1)[0].capitalize() + args_start, args_end = len(func) + 1, field.index(")") + args = field[args_start:args_end].split(",") - def function_objects_to_fields(self, fields, is_str: bool): - if is_str: - functions = "" - for func in SQL_FUNCTIONS: - if f"{func}(" in fields: - functions = str(func) + str(BRACKETS_PATTERN.findall(fields)[0]) - return [self.get_function_objects(functions)] - if not functions: - return [] - else: - functions = [] - for field in fields: - field = field.casefold() if isinstance(field, str) else field - if not issubclass(type(field), Criterion): - if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]): - functions.append(field) - return [self.get_function_objects(function) for function in functions] + to_cast = "*" not in args + _args = [] - def set_fields(self, fields, **kwargs): - fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" - if isinstance(fields, list) and None in fields and Field not in fields: - return None + for arg in args: + field = literal_eval_(arg.strip()) + if to_cast: + field = Field(field) + _args.append(field) - is_list = isinstance(fields, (list, tuple, set)) - if is_list and len(fields) == 1: - fields = fields[0] - is_list = False + return getattr(functions, func)(*_args) - is_str = isinstance(fields, str) - if is_str: - fields = fields.casefold() + def function_objects_from_string(self, fields): + functions = "" + for func in SQL_FUNCTIONS: + if f"{func}(" in fields: + functions = str(func) + str(BRACKETS_PATTERN.search(fields).group()) + return [self.get_function_object(functions)] + if not functions: + return [] - function_objects = ( - self.function_objects_to_fields(fields=fields, is_str=is_str) - if not issubclass(type(fields), Criterion) - else [] - ) + def function_objects_from_list(self, fields): + functions = [] + for field in fields: + field = field.casefold() if isinstance(field, str) else field + if not issubclass(type(field), Criterion): + if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]): + functions.append(field) + return [self.get_function_object(function) for function in functions] + def remove_string_functions(self, fields, function_objects): + """Remove string functions from fields which have already been converted to function objects""" for function in function_objects: - if is_str: - fields = BRACKETS_PATTERN.sub( - "", fields.replace(str(type(function).__name__).strip().casefold(), "") - ) + if isinstance(fields, str): + fields = BRACKETS_PATTERN.sub("", fields.replace(function.name.casefold(), "")) else: updated_fields = [] for field in fields: if isinstance(field, str): updated_fields.append( - BRACKETS_PATTERN.sub("", field) - .strip() - .casefold() - .replace(str(type(function).__name__).strip().casefold(), "") + BRACKETS_PATTERN.sub("", field).strip().casefold().replace(function.name.casefold(), "") ) else: updated_fields.append(field) fields = [field for field in updated_fields if field] + return fields + + def set_fields(self, fields, **kwargs): + fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" + if isinstance(fields, list) and None in fields and Field not in fields: + return None + + function_objects = [] + + is_list = isinstance(fields, (list, tuple, set)) + if is_list and len(fields) == 1: + fields = fields[0] + is_list = False + + if is_list: + function_objects += self.function_objects_from_list(fields=fields) + + is_str = isinstance(fields, str) + if is_str: + fields = fields.casefold() + function_objects += self.function_objects_from_string(fields=fields) + + fields = self.remove_string_functions(fields, function_objects) + if is_str and "," in fields: - fields = fields.split(",") - fields = [field.replace(" ", "") if "as" not in field else field for field in fields] + fields = [field.replace(" ", "") if "as" not in field else field for field in fields.split(",")] is_list, is_str = True, False if is_str: From ee18694b1b1eb9bff1f2cf161976c201ba54bc75 Mon Sep 17 00:00:00 2001 From: Gavin D'souza Date: Wed, 29 Jun 2022 10:40:21 +0530 Subject: [PATCH 10/10] test(fix): Import Max for test_functions_fields --- frappe/tests/test_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 06550b44a4..e7682a0d0c 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -42,7 +42,7 @@ class TestQuery(unittest.TestCase): ) def test_functions_fields(self): - from frappe.query_builder.functions import Count + from frappe.query_builder.functions import Count, Max self.assertEqual( frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),