refactor: frappe.qb.engine

* feat: supporting empty iterables for Contains objects
* fix: explicitly setting empty iterables as tuples to support more operators
* feat: Added locate to frappe.qb Functions
* feat: Added support for functions passed as strings in fields
* feat: Included Criterion objects as fields
* fix: picking up only function intended fields to pass to get_function_objects
* feat: Added iterable for available functions, added support for Field objects
* fix: fixed * passed in fields in lists
This commit is contained in:
Aradhya 2022-06-13 16:25:47 +05:30 committed by Gavin D'souza
parent d1f5c49b02
commit d0680941ad
3 changed files with 99 additions and 32 deletions

View file

@ -1,5 +1,6 @@
import operator
import re
from ast import literal_eval
from functools import cached_property
from typing import Any, Callable, Dict, List, Tuple, Union
@ -325,7 +326,8 @@ class Query:
continue
if isinstance(value, (list, tuple)):
_operator = self.OPERATOR_MAP[value[0].casefold()]
conditions = conditions.where(_operator(Field(key), value[1]))
_value = value[1] if value[1] else ("",)
conditions = conditions.where(_operator(Field(key), _value))
else:
if value is not None:
conditions = conditions.where(_operator(Field(key), value))
@ -364,32 +366,102 @@ class Query:
def set_fields(self, fields, **kwargs):
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
if isinstance(fields, str) and "," in fields:
if isinstance(fields, list) and None in fields and Field not in fields:
return None
is_list = isinstance(fields, (list, tuple, set))
is_str = isinstance(fields, str)
def add_functions(fields):
from frappe.query_builder.functions import SqlFunctions
sql_functions = [sql_function.value for sql_function in SqlFunctions]
def get_function_objects(fields):
from frappe.query_builder import functions
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
func = fields.split("(")[0].casefold().split()
func = [f for f in func if f in sql_functions][0]
args = fields[len(func) + 1 : fields.index(")")].split(",")
args = [literal_eval_(arg.strip()) for arg in args]
return getattr(functions, func.capitalize())(*args)
if is_str and any(
[func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions]
):
function_objects = []
return function_objects or [get_function_objects(fields)]
else:
functions = []
for field in fields:
if not issubclass(type(field), Criterion):
if any(
[func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions]
):
functions.append(field.casefold())
return [get_function_objects(function) for function in functions]
function_objects = (
add_functions(fields=fields) if not issubclass(type(fields), Criterion) else []
)
for function in function_objects:
if is_str:
fields = re.sub(
r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "")
)
else:
updated_fields = []
for field in fields:
if isinstance(field, str):
updated_fields.append(
re.sub(r"\(.*?\)", "", field)
.strip()
.casefold()
.replace(str(type(function).__name__).strip().casefold(), "")
)
else:
updated_fields.append(field)
fields = updated_fields
if is_str and "," in fields:
fields = fields.split(",")
fields = [field.replace(" ", "") if "as" not in field else field for field in fields]
if isinstance(fields, str):
if is_str:
if fields == "*":
return fields
if " as " in fields:
fields, reference = fields.split(" as ")
fields = Field(fields).as_(reference)
else:
if not is_str and fields:
if issubclass(type(fields), Criterion):
return fields
updated_fields = list()
updated_fields = []
if "*" in fields:
return fields
for field in fields:
if not isinstance(field, Criterion) and field:
if " as " in field:
field, reference = field.split(" as ")
updated_fields.append(Field(field).as_(reference))
updated_fields.append(Field(field.strip()).as_(reference))
else:
updated_fields.append(Field(field))
fields = updated_fields
if not isinstance(fields, (list, tuple, str, Criterion)):
fields = list(fields)
fields = updated_fields
if not is_list:
fields = [fields] if fields else []
fields.extend(function_objects)
return fields
def get_sql(

View file

@ -163,7 +163,7 @@ class DatabaseQuery(object):
if not self.columns:
return []
result = self.build_and_run(ignore_permissions=ignore_permissions, pluck=pluck)
result = self.build_and_run()
if with_comment_count and not as_list and self.doctype:
self.add_comment_count(result)
@ -177,7 +177,7 @@ class DatabaseQuery(object):
return result
def build_and_run(self, ignore_permissions, pluck):
def build_and_run(self):
args = self.prepare_args()
args.limit = self.add_limit()
@ -202,27 +202,6 @@ class DatabaseQuery(object):
%(limit)s"""
% args
)
if ignore_permissions:
sql = self.query.get_sql(
self.doctype,
fields=self.temp_fields,
filters=self.temp_filters,
pluck=pluck,
join=self.join,
orderby=self.order_by,
groupby=self.group_by,
distinct=self.distinct,
limit=self.limit_page_length,
offset=self.limit_start,
)
return sql.run(
as_dict=not self.as_list,
debug=self.debug,
update=self.update,
ignore_ddl=self.ignore_ddl,
run=self.run,
)
return frappe.db.sql(
query,
as_dict=not self.as_list,

View file

@ -1,3 +1,5 @@
from enum import Enum
from pypika.functions import *
from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function
@ -14,6 +16,11 @@ class Concat_ws(Function):
super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs)
class Locate(Function):
def __init__(self, *terms, **kwargs):
super(Locate, self).__init__("LOCATE", *terms, **kwargs)
GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG})
Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR})
@ -81,6 +88,15 @@ def _aggregate(function, dt, fieldname, filters, **kwargs):
)
class SqlFunctions(Enum):
DayOfYear = "dayofyear"
Extract = "extract"
Locate = "locate"
Count = "count"
Sum = "sum"
Avg = "avg"
def _max(dt, fieldname, filters=None, **kwargs):
return _aggregate(Max, dt, fieldname, filters, **kwargs)