Merge pull request #17131 from Aradhya-Tripathi/get-all-mod

feat: Add support for Query engine
This commit is contained in:
mergify[bot] 2022-06-29 05:36:33 +00:00 committed by GitHub
commit 3e69b562f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 257 additions and 44 deletions

View file

@ -22,7 +22,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import click
from werkzeug.local import Local, release_local
from frappe.query_builder import get_query_builder, patch_query_aggregation, patch_query_execute
from frappe.query_builder import (
get_qb_engine,
get_query_builder,
patch_query_aggregation,
patch_query_execute,
)
from frappe.utils.caching import request_cache
from frappe.utils.data import cstr, sbool
@ -240,7 +245,7 @@ def init(site, sites_path=None, new_site=False):
local.session = _dict()
local.dev_server = _dev_server
local.qb = get_query_builder(local.conf.db_type or "mariadb")
local.qb.engine = get_qb_engine()
setup_module_map()
if not _qb_patched.get(local.conf.db_type):

View file

@ -12,7 +12,7 @@ from contextlib import contextmanager
from time import time
from typing import Dict, List, Optional, Tuple, Union
from pypika.terms import Criterion, NullValue, PseudoColumn
from pypika.terms import Criterion, NullValue
import frappe
import frappe.defaults
@ -75,15 +75,6 @@ class Database(object):
self.password = password or frappe.conf.db_password
self.value_cache = {}
@property
def query(self):
if not hasattr(self, "_query"):
from .query import Query
self._query = Query()
del Query
return self._query
def setup_type_map(self):
pass
@ -600,7 +591,7 @@ class Database(object):
return [map(values.get, fields)]
else:
r = self.query.get_sql(
r = frappe.qb.engine.get_query(
"Singles",
filters={"field": ("in", tuple(fields)), "doctype": doctype},
fields=["field", "value"],
@ -633,7 +624,7 @@ class Database(object):
# Get column and value of the single doctype Accounts Settings
account_settings = frappe.db.get_singles_dict("Accounts Settings")
"""
queried_result = self.query.get_sql(
queried_result = frappe.qb.engine.get_query(
"Singles",
filters={"doctype": doctype},
fields=["field", "value"],
@ -706,7 +697,7 @@ class Database(object):
if cache and fieldname in self.value_cache[doctype]:
return self.value_cache[doctype][fieldname]
val = self.query.get_sql(
val = frappe.qb.engine.get_query(
table="Singles",
filters={"doctype": doctype, "field": fieldname},
fields="value",
@ -748,14 +739,7 @@ class Database(object):
):
field_objects = []
if not isinstance(fields, Criterion):
for field in fields:
if "(" in str(field) or " as " in str(field):
field_objects.append(PseudoColumn(field))
else:
field_objects.append(field)
query = self.query.get_sql(
query = frappe.qb.engine.get_query(
table=doctype,
filters=filters,
orderby=order_by,
@ -865,7 +849,7 @@ class Database(object):
frappe.clear_document_cache(dt, docname)
else:
query = self.query.build_conditions(table=dt, filters=dn, update=True)
query = frappe.qb.engine.build_conditions(table=dt, filters=dn, update=True)
# TODO: Fix this; doesn't work rn - gavin@frappe.io
# frappe.cache().hdel_keys(dt, "document_cache")
# Workaround: clear all document caches
@ -1066,7 +1050,7 @@ class Database(object):
cache_count = frappe.cache().get_value("doctype:count:{}".format(dt))
if cache_count is not None:
return cache_count
query = self.query.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
count = self.sql(query, debug=debug)[0][0]
if not filters and cache:
frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400)
@ -1206,7 +1190,7 @@ class Database(object):
Doctype name can be passed directly, it will be pre-pended with `tab`.
"""
filters = filters or kwargs.get("conditions")
query = self.query.build_conditions(table=doctype, filters=filters).delete()
query = frappe.qb.engine.build_conditions(table=doctype, filters=filters).delete()
if "debug" not in kwargs:
kwargs["debug"] = debug
return query.run(**kwargs)

View file

@ -1,16 +1,23 @@
import operator
import re
from ast import literal_eval
from functools import cached_property
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
import frappe
from frappe import _
from frappe.boot import get_additional_filters_from_hooks
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import SqlFunctions
TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
if TYPE_CHECKING:
from pypika.functions import Function
def like(key: Field, value: str) -> frappe.qb:
@ -93,7 +100,7 @@ def func_between(key: Field, value: Union[List, Tuple]) -> frappe.qb:
def func_is(key, value):
"Wrapper for IS"
return Field(key).isnotnull() if value.lower() == "set" else Field(key).isnull()
return key.isnotnull() if value.lower() == "set" else key.isnull()
def func_timespan(key: Field, value: str) -> frappe.qb:
@ -143,6 +150,13 @@ def change_orderby(order: str):
return order[0], Order.desc
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
# default operators
OPERATOR_MAP: Dict[str, Callable] = {
"+": operator.add,
@ -168,7 +182,7 @@ OPERATOR_MAP: Dict[str, Callable] = {
}
class Query:
class Engine:
tables: dict = {}
@cached_property
@ -238,7 +252,7 @@ class Query:
Returns:
conditions (frappe.qb): frappe.qb object
"""
if kwargs.get("orderby"):
if kwargs.get("orderby") and kwargs.get("orderby") != "KEEP_DEFAULT_ORDERING":
orderby = kwargs.get("orderby")
if isinstance(orderby, str) and len(orderby.split()) > 1:
for ordby in orderby.split(","):
@ -250,6 +264,7 @@ class Query:
if kwargs.get("limit"):
conditions = conditions.limit(kwargs.get("limit"))
conditions = conditions.offset(kwargs.get("offset", 0))
if kwargs.get("distinct"):
conditions = conditions.distinct()
@ -257,6 +272,9 @@ class Query:
if kwargs.get("for_update"):
conditions = conditions.for_update()
if kwargs.get("groupby"):
conditions = conditions.groupby(kwargs.get("groupby"))
return conditions
def misc_query(self, table: str, filters: Union[List, Tuple] = None, **kwargs):
@ -308,6 +326,10 @@ class Query:
conditions = self.add_conditions(conditions, **kwargs)
return conditions
for key, value in filters.items():
if isinstance(value, bool):
filters.update({key: str(int(value))})
for key in filters:
value = filters.get(key)
_operator = self.OPERATOR_MAP["="]
@ -317,7 +339,8 @@ class Query:
continue
if isinstance(value, (list, tuple)):
_operator = self.OPERATOR_MAP[value[0].casefold()]
conditions = conditions.where(_operator(Field(key), value[1]))
_value = value[1] if value[1] else ("",)
conditions = conditions.where(_operator(Field(key), _value))
else:
if value is not None:
conditions = conditions.where(_operator(Field(key), value))
@ -354,7 +377,117 @@ class Query:
return criterion
def get_sql(
def get_function_object(self, field: str) -> "Function":
"""Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object"""
func = field.split("(", maxsplit=1)[0].capitalize()
args_start, args_end = len(func) + 1, field.index(")")
args = field[args_start:args_end].split(",")
to_cast = "*" not in args
_args = []
for arg in args:
field = literal_eval_(arg.strip())
if to_cast:
field = Field(field)
_args.append(field)
return getattr(functions, func)(*_args)
def function_objects_from_string(self, fields):
functions = ""
for func in SQL_FUNCTIONS:
if f"{func}(" in fields:
functions = str(func) + str(BRACKETS_PATTERN.search(fields).group())
return [self.get_function_object(functions)]
if not functions:
return []
def function_objects_from_list(self, fields):
functions = []
for field in fields:
field = field.casefold() if isinstance(field, str) else field
if not issubclass(type(field), Criterion):
if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]):
functions.append(field)
return [self.get_function_object(function) for function in functions]
def remove_string_functions(self, fields, function_objects):
"""Remove string functions from fields which have already been converted to function objects"""
for function in function_objects:
if isinstance(fields, str):
fields = BRACKETS_PATTERN.sub("", fields.replace(function.name.casefold(), ""))
else:
updated_fields = []
for field in fields:
if isinstance(field, str):
updated_fields.append(
BRACKETS_PATTERN.sub("", field).strip().casefold().replace(function.name.casefold(), "")
)
else:
updated_fields.append(field)
fields = [field for field in updated_fields if field]
return fields
def set_fields(self, fields, **kwargs):
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
if isinstance(fields, list) and None in fields and Field not in fields:
return None
function_objects = []
is_list = isinstance(fields, (list, tuple, set))
if is_list and len(fields) == 1:
fields = fields[0]
is_list = False
if is_list:
function_objects += self.function_objects_from_list(fields=fields)
is_str = isinstance(fields, str)
if is_str:
fields = fields.casefold()
function_objects += self.function_objects_from_string(fields=fields)
fields = self.remove_string_functions(fields, function_objects)
if is_str and "," in fields:
fields = [field.replace(" ", "") if "as" not in field else field for field in fields.split(",")]
is_list, is_str = True, False
if is_str:
if fields == "*":
return fields
if " as " in fields:
fields, reference = fields.split(" as ")
fields = Field(fields).as_(reference)
if not is_str and fields:
if issubclass(type(fields), Criterion):
return fields
updated_fields = []
if "*" in fields:
return fields
for field in fields:
if not isinstance(field, Criterion) and field:
if " as " in field:
field, reference = field.split(" as ")
updated_fields.append(Field(field.strip()).as_(reference))
else:
updated_fields.append(Field(field))
fields = updated_fields
# Need to check instance again since fields modified.
if not isinstance(fields, (list, tuple, set)):
fields = [fields] if fields else []
fields.extend(function_objects)
return fields
def get_query(
self,
table: str,
fields: Union[List, Tuple],
@ -364,15 +497,20 @@ class Query:
# Clean up state before each query
self.tables = {}
criterion = self.build_conditions(table, filters, **kwargs)
fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs)
join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join"
if len(self.tables) > 1:
primary_table = self.tables[table]
del self.tables[table]
for table_object in self.tables.values():
criterion = criterion.left_join(table_object).on(table_object.parent == primary_table.name)
criterion = getattr(criterion, join)(table_object).on(
table_object.parent == primary_table.name
)
if isinstance(fields, (list, tuple)):
query = criterion.select(*kwargs.get("field_objects", fields))
query = criterion.select(*fields)
elif isinstance(fields, Criterion):
query = criterion.select(fields)

View file

@ -204,7 +204,7 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters):
if txt:
search_conditions = [numberCard[field].like("%{txt}%".format(txt=txt)) for field in searchfields]
condition_query = frappe.db.query.build_conditions(doctype, filters)
condition_query = frappe.qb.engine.build_conditions(doctype, filters)
return (
condition_query.select(numberCard.name, numberCard.label, numberCard.document_type)

View file

@ -37,7 +37,7 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> List[D
ToDo = DocType("ToDo")
User = DocType("User")
count = Count("*").as_("count")
filtered_records = frappe.db.query.build_conditions(doctype, current_filters).select("name")
filtered_records = frappe.qb.engine.build_conditions(doctype, current_filters).select("name")
return (
frappe.qb.from_(ToDo)

View file

@ -7,6 +7,7 @@ from frappe.query_builder.terms import ParameterizedFunction, ParameterizedValue
from frappe.query_builder.utils import (
Column,
DocType,
get_qb_engine,
get_query_builder,
patch_query_aggregation,
patch_query_execute,

View file

@ -1,3 +1,5 @@
import typing
from pypika import MySQLQuery, Order, PostgreSQLQuery, terms
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import QueryBuilder, Schema, Table
@ -13,6 +15,13 @@ class Base:
Schema = Schema
Table = Table
# Added dynamic type hints for engine attribute
# which is to be assigned later.
if typing.TYPE_CHECKING:
from frappe.database.query import Engine
engine: Engine
@staticmethod
def functions(name: str, *args, **kwargs) -> Function:
return Function(name, *args, **kwargs)

View file

@ -1,8 +1,9 @@
from enum import Enum
from pypika.functions import *
from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function
import frappe
from frappe.database.query import Query
from frappe.query_builder.custom import GROUP_CONCAT, MATCH, STRING_AGG, TO_TSVECTOR
from frappe.query_builder.utils import ImportMapper, db_type_is
@ -14,6 +15,11 @@ class Concat_ws(Function):
super(Concat_ws, self).__init__("CONCAT_WS", *terms, **kwargs)
class Locate(Function):
def __init__(self, *terms, **kwargs):
super(Locate, self).__init__("LOCATE", *terms, **kwargs)
GroupConcat = ImportMapper({db_type_is.MARIADB: GROUP_CONCAT, db_type_is.POSTGRES: STRING_AGG})
Match = ImportMapper({db_type_is.MARIADB: MATCH, db_type_is.POSTGRES: TO_TSVECTOR})
@ -73,14 +79,24 @@ class Cast_(Function):
def _aggregate(function, dt, fieldname, filters, **kwargs):
return (
Query()
.build_conditions(dt, filters)
frappe.qb.engine.build_conditions(dt, filters)
.select(function(PseudoColumn(fieldname)))
.run(**kwargs)[0][0]
or 0
)
class SqlFunctions(Enum):
DayOfYear = "dayofyear"
Extract = "extract"
Locate = "locate"
Count = "count"
Sum = "sum"
Avg = "avg"
Max = "max"
Min = "min"
def _max(dt, fieldname, filters=None, **kwargs):
return _aggregate(Max, dt, fieldname, filters, **kwargs)

View file

@ -45,6 +45,12 @@ def get_query_builder(type_of_db: str) -> Union[Postgres, MariaDB]:
return picks[db]
def get_qb_engine():
from frappe.database.query import Engine
return Engine()
def get_attr(method_string):
modulename = ".".join(method_string.split(".")[:-1])
methodname = method_string.split(".")[-1]

View file

@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase):
)
def test_none_filter(self):
query = frappe.db.query.get_sql("DocType", fields="name", filters={"restrict_to_domain": None})
query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None})
sql = str(query).replace("`", "").replace('"', "")
condition = "restrict_to_domain IS NULL"
self.assertIn(condition, sql)

View file

@ -1,14 +1,15 @@
import unittest
import frappe
from frappe.query_builder import Field
from frappe.tests.test_query_builder import db_type_is, run_only_if
@run_only_if(db_type_is.MARIADB)
class TestQuery(unittest.TestCase):
@run_only_if(db_type_is.MARIADB)
def test_multiple_tables_in_filters(self):
self.assertEqual(
frappe.db.query.get_sql(
frappe.qb.engine.get_query(
"DocType",
["*"],
[
@ -18,3 +19,56 @@ class TestQuery(unittest.TestCase):
).get_sql(),
"SELECT * FROM `tabDocType` LEFT JOIN `tabBOM Update Log` ON `tabBOM Update Log`.`parent`=`tabDocType`.`name` WHERE `tabBOM Update Log`.`name` LIKE 'f%' AND `tabDocType`.`parent`='something'",
)
def test_string_fields(self):
self.assertEqual(
frappe.qb.engine.get_query(
"User", fields="name, email", filters={"name": "Administrator"}
).get_sql(),
frappe.qb.from_("User")
.select(Field("name"), Field("email"))
.where(Field("name") == "Administrator")
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
"User", fields=["name, email"], filters={"name": "Administrator"}
).get_sql(),
frappe.qb.from_("User")
.select(Field("name"), Field("email"))
.where(Field("name") == "Administrator")
.get_sql(),
)
def test_functions_fields(self):
from frappe.query_builder.functions import Count, Max
self.assertEqual(
frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(),
frappe.qb.from_("User").select(Count("*")).get_sql(),
)
def test_qb_fields(self):
user_doctype = frappe.qb.DocType("User")
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields=[user_doctype.name, user_doctype.email], filters={}
).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(),
)

View file

@ -25,7 +25,7 @@ def get_monthly_results(
date_format = "%m-%Y" if frappe.db.db_type != "postgres" else "MM-YYYY"
return dict(
frappe.db.query.build_conditions(table=goal_doctype, filters=filters)
frappe.qb.engine.build_conditions(table=goal_doctype, filters=filters)
.select(
DateFormat(Table[date_col], date_format).as_("month_year"),
Function(aggregation, goal_field),