From 6db6be1f3c8b8dc9e8e458c6191e5e91e8de8943 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Tue, 21 Jun 2022 18:42:06 +0530 Subject: [PATCH] refactor: frappe.qb.engine * Small fixes in set_fields and clean code * Optimize casefolds * Fixed functions passed in List * get_sql => get_query - more expressive, less confusion * Updated tests --- frappe/database/database.py | 10 ++-- frappe/database/query.py | 94 +++++++++++++++++++---------------- frappe/tests/test_db_query.py | 2 +- frappe/tests/test_query.py | 43 ++++++++-------- 4 files changed, 79 insertions(+), 70 deletions(-) diff --git a/frappe/database/database.py b/frappe/database/database.py index fd94bcfe12..b418b5e9f6 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -586,7 +586,7 @@ class Database(object): return [map(values.get, fields)] else: - r = frappe.qb.engine.get_sql( + r = frappe.qb.engine.get_query( "Singles", filters={"field": ("in", tuple(fields)), "doctype": doctype}, fields=["field", "value"], @@ -616,7 +616,7 @@ class Database(object): # Get coulmn and value of the single doctype Accounts Settings account_settings = frappe.db.get_singles_dict("Accounts Settings") """ - queried_result = frappe.qb.engine.get_sql( + queried_result = frappe.qb.engine.get_query( "Singles", filters={"doctype": doctype}, fields=["field", "value"], @@ -672,7 +672,7 @@ class Database(object): if cache and fieldname in self.value_cache[doctype]: return self.value_cache[doctype][fieldname] - val = frappe.qb.engine.get_sql( + val = frappe.qb.engine.get_query( table="Singles", filters={"doctype": doctype, "field": fieldname}, fields="value", @@ -714,7 +714,7 @@ class Database(object): ): field_objects = [] - query = frappe.qb.engine.get_sql( + query = frappe.qb.engine.get_query( table=doctype, filters=filters, orderby=order_by, @@ -1025,7 +1025,7 @@ class Database(object): cache_count = frappe.cache().get_value("doctype:count:{}".format(dt)) if cache_count is not None: return cache_count - query = frappe.qb.engine.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct) + query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct) count = self.sql(query, debug=debug)[0][0] if not filters and cache: frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400) diff --git a/frappe/database/query.py b/frappe/database/query.py index abf5c6936f..155cc99aab 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -8,10 +8,13 @@ import frappe from frappe import _ from frappe.boot import get_additional_filters_from_hooks from frappe.model.db_query import get_timespan_date_range -from frappe.query_builder import Criterion, Field, Order, Table +from frappe.query_builder import Criterion, Field, Order, Table, functions +from frappe.query_builder.functions import SqlFunctions TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") +BRACKETS_PATTERN = re.compile(r"\(.*?\)") +SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] def like(key: Field, value: str) -> frappe.qb: @@ -144,6 +147,13 @@ def change_orderby(order: str): return order[0], Order.desc +def literal_eval_(literal): + try: + return literal_eval(literal) + except (ValueError, SyntaxError): + return literal + + # default operators OPERATOR_MAP: Dict[str, Callable] = { "+": operator.add, @@ -364,6 +374,34 @@ class Engine: return criterion + def get_function_objects(self, fields): + func = fields.split("(")[0].casefold().split() + func = [f for f in func if f in SQL_FUNCTIONS][0] + args = fields[len(func) + 1 : fields.index(")")].split(",") + args = [ + Field(literal_eval_((arg.strip()))) if "*" not in args else literal_eval_((arg.strip())) + for arg in args + ] + return getattr(functions, func.capitalize())(*args) + + def function_objects_to_fields(self, fields, is_str: bool): + if is_str: + functions = "" + for func in SQL_FUNCTIONS: + if f"{func}(" in fields: + functions = str(func) + str(BRACKETS_PATTERN.findall(fields)[0]) + return [self.get_function_objects(functions)] + if not functions: + return [] + else: + functions = [] + for field in fields: + field = field.casefold() if isinstance(field, str) else field + if not issubclass(type(field), Criterion): + if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]): + functions.append(field) + return [self.get_function_objects(function) for function in functions] + def set_fields(self, fields, **kwargs): fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" if isinstance(fields, list) and None in fields and Field not in fields: @@ -375,57 +413,26 @@ class Engine: is_list = False is_str = isinstance(fields, str) - - def add_functions(fields): - from frappe.query_builder.functions import SqlFunctions - - sql_functions = [sql_function.value for sql_function in SqlFunctions] - - def get_function_objects(fields): - from frappe.query_builder import functions - - def literal_eval_(literal): - try: - return literal_eval(literal) - except (ValueError, SyntaxError): - return literal - - func = fields.split("(")[0].casefold().split() - func = [f for f in func if f in sql_functions][0] - args = fields[len(func) + 1 : fields.index(")")].split(",") - args = [Field(literal_eval_((arg.strip()))) for arg in args] - return getattr(functions, func.capitalize())(*args) - - if is_str and any( - [func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions] - ): - function_objects = [] - return function_objects or [get_function_objects(fields)] - else: - functions = [] - for field in fields: - if not issubclass(type(field), Criterion): - if any( - [func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions] - ): - functions.append(field.casefold()) - return [get_function_objects(function) for function in functions] + if is_str: + fields = fields.casefold() function_objects = ( - add_functions(fields=fields) if not issubclass(type(fields), Criterion) else [] + self.function_objects_to_fields(fields=fields, is_str=is_str) + if not issubclass(type(fields), Criterion) + else [] ) + for function in function_objects: if is_str: - fields = re.sub( - r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "") + fields = BRACKETS_PATTERN.sub( + "", fields.replace(str(type(function).__name__).strip().casefold(), "") ) - else: updated_fields = [] for field in fields: if isinstance(field, str): updated_fields.append( - re.sub(r"\(.*?\)", "", field) + BRACKETS_PATTERN.sub("", field) .strip() .casefold() .replace(str(type(function).__name__).strip().casefold(), "") @@ -433,11 +440,12 @@ class Engine: else: updated_fields.append(field) - fields = updated_fields + fields = [field for field in updated_fields if field] if is_str and "," in fields: fields = fields.split(",") fields = [field.replace(" ", "") if "as" not in field else field for field in fields] + is_list, is_str = True, False if is_str: if fields == "*": @@ -469,7 +477,7 @@ class Engine: fields.extend(function_objects) return fields - def get_sql( + def get_query( self, table: str, fields: Union[List, Tuple], diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index 3a4e5b72bd..8727951f4a 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase): ) def test_none_filter(self): - query = frappe.qb.engine.get_sql("DocType", fields="name", filters={"restrict_to_domain": None}) + query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None}) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index dd74d5ad18..06550b44a4 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -9,7 +9,7 @@ class TestQuery(unittest.TestCase): @run_only_if(db_type_is.MARIADB) def test_multiple_tables_in_filters(self): self.assertEqual( - frappe.qb.engine.get_sql( + frappe.qb.engine.get_query( "DocType", ["*"], [ @@ -22,52 +22,53 @@ class TestQuery(unittest.TestCase): def test_string_fields(self): self.assertEqual( - frappe.qb.engine.get_sql("User", fields="name, email", filters={"name": "Administrator"}), + frappe.qb.engine.get_query( + "User", fields="name, email", filters={"name": "Administrator"} + ).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) - .where(Field("name") == "Administrator"), + .where(Field("name") == "Administrator") + .get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}), + frappe.qb.engine.get_query( + "User", fields=["name, email"], filters={"name": "Administrator"} + ).get_sql(), frappe.qb.from_("User") .select(Field("name"), Field("email")) - .where(Field("name") == "Administrator"), + .where(Field("name") == "Administrator") + .get_sql(), ) def test_functions_fields(self): from frappe.query_builder.functions import Count self.assertEqual( - frappe.qb.engine.get_sql("User", fields="Count(name)", filters={}), - frappe.qb.from_("User").select(Count(Field("name"))), + frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields="Count(name), Max(name)", filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(), + frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), - ) - - self.assertEqual( - frappe.qb.engine.get_sql("User", fields=[Count("*")], filters={}), - frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))), + frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(), + frappe.qb.from_("User").select(Count("*")).get_sql(), ) def test_qb_fields(self): user_doctype = frappe.qb.DocType("User") self.assertEqual( - frappe.qb.engine.get_sql( + frappe.qb.engine.get_query( user_doctype, fields=[user_doctype.name, user_doctype.email], filters={} - ), - frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email), + ).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(), ) self.assertEqual( - frappe.qb.engine.get_sql(user_doctype, fields=user_doctype.email, filters={}), - frappe.qb.from_(user_doctype).select(user_doctype.email), + frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(), + frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(), )