From d0680941adacb6ebaa9bc9334be784566cf97a82 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Mon, 13 Jun 2022 16:25:47 +0530 Subject: [PATCH] refactor: frappe.qb.engine * feat: supporting empty iterables for Contains objects * fix: explicitly setting empty iterables as tuples to support more operators * feat: Added locate to frappe.qb Functions * feat: Added support for functions passed as strings in fields * feat: Included Criterion objects as fields * fix: picking up only function intended fields to pass to get_function_objects * feat: Added iterable for available functions, added support for Field objects * fix: fixed * passed in fields in lists --- frappe/database/query.py | 90 +++++++++++++++++++++++++++---- frappe/model/db_query.py | 25 +-------- frappe/query_builder/functions.py | 16 ++++++ 3 files changed, 99 insertions(+), 32 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index 22b3074cad..b9de427c60 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,5 +1,6 @@ import operator import re +from ast import literal_eval from functools import cached_property from typing import Any, Callable, Dict, List, Tuple, Union @@ -325,7 +326,8 @@ class Query: continue if isinstance(value, (list, tuple)): _operator = self.OPERATOR_MAP[value[0].casefold()] - conditions = conditions.where(_operator(Field(key), value[1])) + _value = value[1] if value[1] else ("",) + conditions = conditions.where(_operator(Field(key), _value)) else: if value is not None: conditions = conditions.where(_operator(Field(key), value)) @@ -364,32 +366,102 @@ class Query: def set_fields(self, fields, **kwargs): fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" - if isinstance(fields, str) and "," in fields: + if isinstance(fields, list) and None in fields and Field not in fields: + return None + + is_list = isinstance(fields, (list, tuple, set)) + is_str = isinstance(fields, str) + + def add_functions(fields): + from frappe.query_builder.functions import SqlFunctions + + sql_functions = [sql_function.value for sql_function in SqlFunctions] + + def get_function_objects(fields): + from frappe.query_builder import functions + + def literal_eval_(literal): + try: + return literal_eval(literal) + except (ValueError, SyntaxError): + return literal + + func = fields.split("(")[0].casefold().split() + func = [f for f in func if f in sql_functions][0] + args = fields[len(func) + 1 : fields.index(")")].split(",") + args = [literal_eval_(arg.strip()) for arg in args] + return getattr(functions, func.capitalize())(*args) + + if is_str and any( + [func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions] + ): + function_objects = [] + return function_objects or [get_function_objects(fields)] + else: + functions = [] + for field in fields: + if not issubclass(type(field), Criterion): + if any( + [func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions] + ): + functions.append(field.casefold()) + return [get_function_objects(function) for function in functions] + + function_objects = ( + add_functions(fields=fields) if not issubclass(type(fields), Criterion) else [] + ) + for function in function_objects: + if is_str: + fields = re.sub( + r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "") + ) + + else: + updated_fields = [] + for field in fields: + if isinstance(field, str): + updated_fields.append( + re.sub(r"\(.*?\)", "", field) + .strip() + .casefold() + .replace(str(type(function).__name__).strip().casefold(), "") + ) + else: + updated_fields.append(field) + + fields = updated_fields + + if is_str and "," in fields: fields = fields.split(",") fields = [field.replace(" ", "") if "as" not in field else field for field in fields] - if isinstance(fields, str): + if is_str: if fields == "*": return fields if " as " in fields: fields, reference = fields.split(" as ") fields = Field(fields).as_(reference) - else: + + if not is_str and fields: if issubclass(type(fields), Criterion): return fields - updated_fields = list() + updated_fields = [] + if "*" in fields: + return fields for field in fields: if not isinstance(field, Criterion) and field: if " as " in field: field, reference = field.split(" as ") - updated_fields.append(Field(field).as_(reference)) + updated_fields.append(Field(field.strip()).as_(reference)) else: updated_fields.append(Field(field)) - fields = updated_fields - if not isinstance(fields, (list, tuple, str, Criterion)): - fields = list(fields) + fields = updated_fields + if not is_list: + fields = [fields] if fields else [] + + fields.extend(function_objects) return fields def get_sql( diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index 49b324b794..b9c757bfa6 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -163,7 +163,7 @@ class DatabaseQuery(object): if not self.columns: return [] - result = self.build_and_run(ignore_permissions=ignore_permissions, pluck=pluck) + result = self.build_and_run() if with_comment_count and not as_list and self.doctype: self.add_comment_count(result) @@ -177,7 +177,7 @@ class DatabaseQuery(object): return result - def build_and_run(self, ignore_permissions, pluck): + def build_and_run(self): args = self.prepare_args() args.limit = self.add_limit() @@ -202,27 +202,6 @@ class DatabaseQuery(object): %(limit)s""" % args ) - if ignore_permissions: - sql = self.query.get_sql( - self.doctype, - fields=self.temp_fields, - filters=self.temp_filters, - pluck=pluck, - join=self.join, - orderby=self.order_by, - groupby=self.group_by, - distinct=self.distinct, - limit=self.limit_page_length, - offset=self.limit_start, - ) - return sql.run( - as_dict=not self.as_list, - debug=self.debug, - update=self.update, - ignore_ddl=self.ignore_ddl, - run=self.run, - ) - return frappe.db.sql( query, as_dict=not self.as_list, diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index f03c139f57..7b260098d1 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -1,3 +1,5 @@ +from enum import Enum + from pypika.functions import * from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function @@ -14,6 +16,11 @@ class Concat_ws(Function): super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs) +class Locate(Function): + def __init__(self, *terms, **kwargs): + super(Locate, self).__init__("LOCATE", *terms, **kwargs) + + GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG}) Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR}) @@ -81,6 +88,15 @@ def _aggregate(function, dt, fieldname, filters, **kwargs): ) +class SqlFunctions(Enum): + DayOfYear = "dayofyear" + Extract = "extract" + Locate = "locate" + Count = "count" + Sum = "sum" + Avg = "avg" + + def _max(dt, fieldname, filters=None, **kwargs): return _aggregate(Max, dt, fieldname, filters, **kwargs)