refactor: frappe.qb.engine
* feat: supporting empty iterables for Contains objects * fix: explicitly setting empty iterables as tuples to support more operators * feat: Added locate to frappe.qb Functions * feat: Added support for functions passed as strings in fields * feat: Included Criterion objects as fields * fix: picking up only function intended fields to pass to get_function_objects * feat: Added iterable for available functions, added support for Field objects * fix: fixed * passed in fields in lists
This commit is contained in:
parent
d1f5c49b02
commit
d0680941ad
3 changed files with 99 additions and 32 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import operator
|
||||
import re
|
||||
from ast import literal_eval
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
|
|
@ -325,7 +326,8 @@ class Query:
|
|||
continue
|
||||
if isinstance(value, (list, tuple)):
|
||||
_operator = self.OPERATOR_MAP[value[0].casefold()]
|
||||
conditions = conditions.where(_operator(Field(key), value[1]))
|
||||
_value = value[1] if value[1] else ("",)
|
||||
conditions = conditions.where(_operator(Field(key), _value))
|
||||
else:
|
||||
if value is not None:
|
||||
conditions = conditions.where(_operator(Field(key), value))
|
||||
|
|
@ -364,32 +366,102 @@ class Query:
|
|||
|
||||
def set_fields(self, fields, **kwargs):
|
||||
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
|
||||
if isinstance(fields, str) and "," in fields:
|
||||
if isinstance(fields, list) and None in fields and Field not in fields:
|
||||
return None
|
||||
|
||||
is_list = isinstance(fields, (list, tuple, set))
|
||||
is_str = isinstance(fields, str)
|
||||
|
||||
def add_functions(fields):
|
||||
from frappe.query_builder.functions import SqlFunctions
|
||||
|
||||
sql_functions = [sql_function.value for sql_function in SqlFunctions]
|
||||
|
||||
def get_function_objects(fields):
|
||||
from frappe.query_builder import functions
|
||||
|
||||
def literal_eval_(literal):
|
||||
try:
|
||||
return literal_eval(literal)
|
||||
except (ValueError, SyntaxError):
|
||||
return literal
|
||||
|
||||
func = fields.split("(")[0].casefold().split()
|
||||
func = [f for f in func if f in sql_functions][0]
|
||||
args = fields[len(func) + 1 : fields.index(")")].split(",")
|
||||
args = [literal_eval_(arg.strip()) for arg in args]
|
||||
return getattr(functions, func.capitalize())(*args)
|
||||
|
||||
if is_str and any(
|
||||
[func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions]
|
||||
):
|
||||
function_objects = []
|
||||
return function_objects or [get_function_objects(fields)]
|
||||
else:
|
||||
functions = []
|
||||
for field in fields:
|
||||
if not issubclass(type(field), Criterion):
|
||||
if any(
|
||||
[func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions]
|
||||
):
|
||||
functions.append(field.casefold())
|
||||
return [get_function_objects(function) for function in functions]
|
||||
|
||||
function_objects = (
|
||||
add_functions(fields=fields) if not issubclass(type(fields), Criterion) else []
|
||||
)
|
||||
for function in function_objects:
|
||||
if is_str:
|
||||
fields = re.sub(
|
||||
r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "")
|
||||
)
|
||||
|
||||
else:
|
||||
updated_fields = []
|
||||
for field in fields:
|
||||
if isinstance(field, str):
|
||||
updated_fields.append(
|
||||
re.sub(r"\(.*?\)", "", field)
|
||||
.strip()
|
||||
.casefold()
|
||||
.replace(str(type(function).__name__).strip().casefold(), "")
|
||||
)
|
||||
else:
|
||||
updated_fields.append(field)
|
||||
|
||||
fields = updated_fields
|
||||
|
||||
if is_str and "," in fields:
|
||||
fields = fields.split(",")
|
||||
fields = [field.replace(" ", "") if "as" not in field else field for field in fields]
|
||||
|
||||
if isinstance(fields, str):
|
||||
if is_str:
|
||||
if fields == "*":
|
||||
return fields
|
||||
if " as " in fields:
|
||||
fields, reference = fields.split(" as ")
|
||||
fields = Field(fields).as_(reference)
|
||||
else:
|
||||
|
||||
if not is_str and fields:
|
||||
if issubclass(type(fields), Criterion):
|
||||
return fields
|
||||
updated_fields = list()
|
||||
updated_fields = []
|
||||
if "*" in fields:
|
||||
return fields
|
||||
for field in fields:
|
||||
if not isinstance(field, Criterion) and field:
|
||||
if " as " in field:
|
||||
field, reference = field.split(" as ")
|
||||
updated_fields.append(Field(field).as_(reference))
|
||||
updated_fields.append(Field(field.strip()).as_(reference))
|
||||
else:
|
||||
updated_fields.append(Field(field))
|
||||
fields = updated_fields
|
||||
|
||||
if not isinstance(fields, (list, tuple, str, Criterion)):
|
||||
fields = list(fields)
|
||||
fields = updated_fields
|
||||
|
||||
if not is_list:
|
||||
fields = [fields] if fields else []
|
||||
|
||||
fields.extend(function_objects)
|
||||
return fields
|
||||
|
||||
def get_sql(
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class DatabaseQuery(object):
|
|||
if not self.columns:
|
||||
return []
|
||||
|
||||
result = self.build_and_run(ignore_permissions=ignore_permissions, pluck=pluck)
|
||||
result = self.build_and_run()
|
||||
|
||||
if with_comment_count and not as_list and self.doctype:
|
||||
self.add_comment_count(result)
|
||||
|
|
@ -177,7 +177,7 @@ class DatabaseQuery(object):
|
|||
|
||||
return result
|
||||
|
||||
def build_and_run(self, ignore_permissions, pluck):
|
||||
def build_and_run(self):
|
||||
args = self.prepare_args()
|
||||
args.limit = self.add_limit()
|
||||
|
||||
|
|
@ -202,27 +202,6 @@ class DatabaseQuery(object):
|
|||
%(limit)s"""
|
||||
% args
|
||||
)
|
||||
if ignore_permissions:
|
||||
sql = self.query.get_sql(
|
||||
self.doctype,
|
||||
fields=self.temp_fields,
|
||||
filters=self.temp_filters,
|
||||
pluck=pluck,
|
||||
join=self.join,
|
||||
orderby=self.order_by,
|
||||
groupby=self.group_by,
|
||||
distinct=self.distinct,
|
||||
limit=self.limit_page_length,
|
||||
offset=self.limit_start,
|
||||
)
|
||||
return sql.run(
|
||||
as_dict=not self.as_list,
|
||||
debug=self.debug,
|
||||
update=self.update,
|
||||
ignore_ddl=self.ignore_ddl,
|
||||
run=self.run,
|
||||
)
|
||||
|
||||
return frappe.db.sql(
|
||||
query,
|
||||
as_dict=not self.as_list,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from enum import Enum
|
||||
|
||||
from pypika.functions import *
|
||||
from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function
|
||||
|
||||
|
|
@ -14,6 +16,11 @@ class Concat_ws(Function):
|
|||
super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs)
|
||||
|
||||
|
||||
class Locate(Function):
|
||||
def __init__(self, *terms, **kwargs):
|
||||
super(Locate, self).__init__("LOCATE", *terms, **kwargs)
|
||||
|
||||
|
||||
GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG})
|
||||
|
||||
Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR})
|
||||
|
|
@ -81,6 +88,15 @@ def _aggregate(function, dt, fieldname, filters, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
class SqlFunctions(Enum):
|
||||
DayOfYear = "dayofyear"
|
||||
Extract = "extract"
|
||||
Locate = "locate"
|
||||
Count = "count"
|
||||
Sum = "sum"
|
||||
Avg = "avg"
|
||||
|
||||
|
||||
def _max(dt, fieldname, filters=None, **kwargs):
|
||||
return _aggregate(Max, dt, fieldname, filters, **kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue