refactor: frappe.qb.engine

* Small fixes in set_fields and clean code
* Optimize casefolds
* Fixed functions passed in List
* get_sql => get_query - more expressive, less confusion
* Updated tests
This commit is contained in:
Aradhya 2022-06-21 18:42:06 +05:30 committed by Gavin D'souza
parent 4af2e1e886
commit 6db6be1f3c
4 changed files with 79 additions and 70 deletions

View file

@ -586,7 +586,7 @@ class Database(object):
return [map(values.get, fields)]
else:
r = frappe.qb.engine.get_sql(
r = frappe.qb.engine.get_query(
"Singles",
filters={"field": ("in", tuple(fields)), "doctype": doctype},
fields=["field", "value"],
@ -616,7 +616,7 @@ class Database(object):
# Get coulmn and value of the single doctype Accounts Settings
account_settings = frappe.db.get_singles_dict("Accounts Settings")
"""
queried_result = frappe.qb.engine.get_sql(
queried_result = frappe.qb.engine.get_query(
"Singles",
filters={"doctype": doctype},
fields=["field", "value"],
@ -672,7 +672,7 @@ class Database(object):
if cache and fieldname in self.value_cache[doctype]:
return self.value_cache[doctype][fieldname]
val = frappe.qb.engine.get_sql(
val = frappe.qb.engine.get_query(
table="Singles",
filters={"doctype": doctype, "field": fieldname},
fields="value",
@ -714,7 +714,7 @@ class Database(object):
):
field_objects = []
query = frappe.qb.engine.get_sql(
query = frappe.qb.engine.get_query(
table=doctype,
filters=filters,
orderby=order_by,
@ -1025,7 +1025,7 @@ class Database(object):
cache_count = frappe.cache().get_value("doctype:count:{}".format(dt))
if cache_count is not None:
return cache_count
query = frappe.qb.engine.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
count = self.sql(query, debug=debug)[0][0]
if not filters and cache:
frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400)

View file

@ -8,10 +8,13 @@ import frappe
from frappe import _
from frappe.boot import get_additional_filters_from_hooks
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import SqlFunctions
TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)")
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
def like(key: Field, value: str) -> frappe.qb:
@ -144,6 +147,13 @@ def change_orderby(order: str):
return order[0], Order.desc
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
# default operators
OPERATOR_MAP: Dict[str, Callable] = {
"+": operator.add,
@ -364,6 +374,34 @@ class Engine:
return criterion
def get_function_objects(self, fields):
func = fields.split("(")[0].casefold().split()
func = [f for f in func if f in SQL_FUNCTIONS][0]
args = fields[len(func) + 1 : fields.index(")")].split(",")
args = [
Field(literal_eval_((arg.strip()))) if "*" not in args else literal_eval_((arg.strip()))
for arg in args
]
return getattr(functions, func.capitalize())(*args)
def function_objects_to_fields(self, fields, is_str: bool):
if is_str:
functions = ""
for func in SQL_FUNCTIONS:
if f"{func}(" in fields:
functions = str(func) + str(BRACKETS_PATTERN.findall(fields)[0])
return [self.get_function_objects(functions)]
if not functions:
return []
else:
functions = []
for field in fields:
field = field.casefold() if isinstance(field, str) else field
if not issubclass(type(field), Criterion):
if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]):
functions.append(field)
return [self.get_function_objects(function) for function in functions]
def set_fields(self, fields, **kwargs):
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
if isinstance(fields, list) and None in fields and Field not in fields:
@ -375,57 +413,26 @@ class Engine:
is_list = False
is_str = isinstance(fields, str)
def add_functions(fields):
from frappe.query_builder.functions import SqlFunctions
sql_functions = [sql_function.value for sql_function in SqlFunctions]
def get_function_objects(fields):
from frappe.query_builder import functions
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
func = fields.split("(")[0].casefold().split()
func = [f for f in func if f in sql_functions][0]
args = fields[len(func) + 1 : fields.index(")")].split(",")
args = [Field(literal_eval_((arg.strip()))) for arg in args]
return getattr(functions, func.capitalize())(*args)
if is_str and any(
[func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions]
):
function_objects = []
return function_objects or [get_function_objects(fields)]
else:
functions = []
for field in fields:
if not issubclass(type(field), Criterion):
if any(
[func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions]
):
functions.append(field.casefold())
return [get_function_objects(function) for function in functions]
if is_str:
fields = fields.casefold()
function_objects = (
add_functions(fields=fields) if not issubclass(type(fields), Criterion) else []
self.function_objects_to_fields(fields=fields, is_str=is_str)
if not issubclass(type(fields), Criterion)
else []
)
for function in function_objects:
if is_str:
fields = re.sub(
r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "")
fields = BRACKETS_PATTERN.sub(
"", fields.replace(str(type(function).__name__).strip().casefold(), "")
)
else:
updated_fields = []
for field in fields:
if isinstance(field, str):
updated_fields.append(
re.sub(r"\(.*?\)", "", field)
BRACKETS_PATTERN.sub("", field)
.strip()
.casefold()
.replace(str(type(function).__name__).strip().casefold(), "")
@ -433,11 +440,12 @@ class Engine:
else:
updated_fields.append(field)
fields = updated_fields
fields = [field for field in updated_fields if field]
if is_str and "," in fields:
fields = fields.split(",")
fields = [field.replace(" ", "") if "as" not in field else field for field in fields]
is_list, is_str = True, False
if is_str:
if fields == "*":
@ -469,7 +477,7 @@ class Engine:
fields.extend(function_objects)
return fields
def get_sql(
def get_query(
self,
table: str,
fields: Union[List, Tuple],

View file

@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase):
)
def test_none_filter(self):
query = frappe.qb.engine.get_sql("DocType", fields="name", filters={"restrict_to_domain": None})
query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None})
sql = str(query).replace("`", "").replace('"', "")
condition = "restrict_to_domain IS NULL"
self.assertIn(condition, sql)

View file

@ -9,7 +9,7 @@ class TestQuery(unittest.TestCase):
@run_only_if(db_type_is.MARIADB)
def test_multiple_tables_in_filters(self):
self.assertEqual(
frappe.qb.engine.get_sql(
frappe.qb.engine.get_query(
"DocType",
["*"],
[
@ -22,52 +22,53 @@ class TestQuery(unittest.TestCase):
def test_string_fields(self):
self.assertEqual(
frappe.qb.engine.get_sql("User", fields="name, email", filters={"name": "Administrator"}),
frappe.qb.engine.get_query(
"User", fields="name, email", filters={"name": "Administrator"}
).get_sql(),
frappe.qb.from_("User")
.select(Field("name"), Field("email"))
.where(Field("name") == "Administrator"),
.where(Field("name") == "Administrator")
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}),
frappe.qb.engine.get_query(
"User", fields=["name, email"], filters={"name": "Administrator"}
).get_sql(),
frappe.qb.from_("User")
.select(Field("name"), Field("email"))
.where(Field("name") == "Administrator"),
.where(Field("name") == "Administrator")
.get_sql(),
)
def test_functions_fields(self):
from frappe.query_builder.functions import Count
self.assertEqual(
frappe.qb.engine.get_sql("User", fields="Count(name)", filters={}),
frappe.qb.from_("User").select(Count(Field("name"))),
frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_sql("User", fields="Count(name), Max(name)", filters={}),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
)
self.assertEqual(
frappe.qb.engine.get_sql("User", fields=[Count("*")], filters={}),
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(),
frappe.qb.from_("User").select(Count("*")).get_sql(),
)
def test_qb_fields(self):
user_doctype = frappe.qb.DocType("User")
self.assertEqual(
frappe.qb.engine.get_sql(
frappe.qb.engine.get_query(
user_doctype, fields=[user_doctype.name, user_doctype.email], filters={}
),
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email),
).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_sql(user_doctype, fields=user_doctype.email, filters={}),
frappe.qb.from_(user_doctype).select(user_doctype.email),
frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(),
)