refactor: frappe.qb.engine
* Small fixes in set_fields and clean code * Optimize casefolds * Fixed functions passed in List * get_sql => get_query - more expressive, less confusion * Updated tests
This commit is contained in:
parent
4af2e1e886
commit
6db6be1f3c
4 changed files with 79 additions and 70 deletions
|
|
@ -586,7 +586,7 @@ class Database(object):
|
|||
return [map(values.get, fields)]
|
||||
|
||||
else:
|
||||
r = frappe.qb.engine.get_sql(
|
||||
r = frappe.qb.engine.get_query(
|
||||
"Singles",
|
||||
filters={"field": ("in", tuple(fields)), "doctype": doctype},
|
||||
fields=["field", "value"],
|
||||
|
|
@ -616,7 +616,7 @@ class Database(object):
|
|||
# Get coulmn and value of the single doctype Accounts Settings
|
||||
account_settings = frappe.db.get_singles_dict("Accounts Settings")
|
||||
"""
|
||||
queried_result = frappe.qb.engine.get_sql(
|
||||
queried_result = frappe.qb.engine.get_query(
|
||||
"Singles",
|
||||
filters={"doctype": doctype},
|
||||
fields=["field", "value"],
|
||||
|
|
@ -672,7 +672,7 @@ class Database(object):
|
|||
if cache and fieldname in self.value_cache[doctype]:
|
||||
return self.value_cache[doctype][fieldname]
|
||||
|
||||
val = frappe.qb.engine.get_sql(
|
||||
val = frappe.qb.engine.get_query(
|
||||
table="Singles",
|
||||
filters={"doctype": doctype, "field": fieldname},
|
||||
fields="value",
|
||||
|
|
@ -714,7 +714,7 @@ class Database(object):
|
|||
):
|
||||
field_objects = []
|
||||
|
||||
query = frappe.qb.engine.get_sql(
|
||||
query = frappe.qb.engine.get_query(
|
||||
table=doctype,
|
||||
filters=filters,
|
||||
orderby=order_by,
|
||||
|
|
@ -1025,7 +1025,7 @@ class Database(object):
|
|||
cache_count = frappe.cache().get_value("doctype:count:{}".format(dt))
|
||||
if cache_count is not None:
|
||||
return cache_count
|
||||
query = frappe.qb.engine.get_sql(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
|
||||
query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct)
|
||||
count = self.sql(query, debug=debug)[0][0]
|
||||
if not filters and cache:
|
||||
frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ import frappe
|
|||
from frappe import _
|
||||
from frappe.boot import get_additional_filters_from_hooks
|
||||
from frappe.model.db_query import get_timespan_date_range
|
||||
from frappe.query_builder import Criterion, Field, Order, Table
|
||||
from frappe.query_builder import Criterion, Field, Order, Table, functions
|
||||
from frappe.query_builder.functions import SqlFunctions
|
||||
|
||||
TAB_PATTERN = re.compile("^tab")
|
||||
WORDS_PATTERN = re.compile(r"\w+")
|
||||
BRACKETS_PATTERN = re.compile(r"\(.*?\)")
|
||||
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
|
||||
|
||||
|
||||
def like(key: Field, value: str) -> frappe.qb:
|
||||
|
|
@ -144,6 +147,13 @@ def change_orderby(order: str):
|
|||
return order[0], Order.desc
|
||||
|
||||
|
||||
def literal_eval_(literal):
|
||||
try:
|
||||
return literal_eval(literal)
|
||||
except (ValueError, SyntaxError):
|
||||
return literal
|
||||
|
||||
|
||||
# default operators
|
||||
OPERATOR_MAP: Dict[str, Callable] = {
|
||||
"+": operator.add,
|
||||
|
|
@ -364,6 +374,34 @@ class Engine:
|
|||
|
||||
return criterion
|
||||
|
||||
def get_function_objects(self, fields):
|
||||
func = fields.split("(")[0].casefold().split()
|
||||
func = [f for f in func if f in SQL_FUNCTIONS][0]
|
||||
args = fields[len(func) + 1 : fields.index(")")].split(",")
|
||||
args = [
|
||||
Field(literal_eval_((arg.strip()))) if "*" not in args else literal_eval_((arg.strip()))
|
||||
for arg in args
|
||||
]
|
||||
return getattr(functions, func.capitalize())(*args)
|
||||
|
||||
def function_objects_to_fields(self, fields, is_str: bool):
|
||||
if is_str:
|
||||
functions = ""
|
||||
for func in SQL_FUNCTIONS:
|
||||
if f"{func}(" in fields:
|
||||
functions = str(func) + str(BRACKETS_PATTERN.findall(fields)[0])
|
||||
return [self.get_function_objects(functions)]
|
||||
if not functions:
|
||||
return []
|
||||
else:
|
||||
functions = []
|
||||
for field in fields:
|
||||
field = field.casefold() if isinstance(field, str) else field
|
||||
if not issubclass(type(field), Criterion):
|
||||
if any([func in field and f"{func}(" in field for func in SQL_FUNCTIONS]):
|
||||
functions.append(field)
|
||||
return [self.get_function_objects(function) for function in functions]
|
||||
|
||||
def set_fields(self, fields, **kwargs):
|
||||
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
|
||||
if isinstance(fields, list) and None in fields and Field not in fields:
|
||||
|
|
@ -375,57 +413,26 @@ class Engine:
|
|||
is_list = False
|
||||
|
||||
is_str = isinstance(fields, str)
|
||||
|
||||
def add_functions(fields):
|
||||
from frappe.query_builder.functions import SqlFunctions
|
||||
|
||||
sql_functions = [sql_function.value for sql_function in SqlFunctions]
|
||||
|
||||
def get_function_objects(fields):
|
||||
from frappe.query_builder import functions
|
||||
|
||||
def literal_eval_(literal):
|
||||
try:
|
||||
return literal_eval(literal)
|
||||
except (ValueError, SyntaxError):
|
||||
return literal
|
||||
|
||||
func = fields.split("(")[0].casefold().split()
|
||||
func = [f for f in func if f in sql_functions][0]
|
||||
args = fields[len(func) + 1 : fields.index(")")].split(",")
|
||||
args = [Field(literal_eval_((arg.strip()))) for arg in args]
|
||||
return getattr(functions, func.capitalize())(*args)
|
||||
|
||||
if is_str and any(
|
||||
[func in fields.casefold() and f"{func}(" in fields.casefold() for func in sql_functions]
|
||||
):
|
||||
function_objects = []
|
||||
return function_objects or [get_function_objects(fields)]
|
||||
else:
|
||||
functions = []
|
||||
for field in fields:
|
||||
if not issubclass(type(field), Criterion):
|
||||
if any(
|
||||
[func in field.casefold() and f"{func}(" in field.casefold() for func in sql_functions]
|
||||
):
|
||||
functions.append(field.casefold())
|
||||
return [get_function_objects(function) for function in functions]
|
||||
if is_str:
|
||||
fields = fields.casefold()
|
||||
|
||||
function_objects = (
|
||||
add_functions(fields=fields) if not issubclass(type(fields), Criterion) else []
|
||||
self.function_objects_to_fields(fields=fields, is_str=is_str)
|
||||
if not issubclass(type(fields), Criterion)
|
||||
else []
|
||||
)
|
||||
|
||||
for function in function_objects:
|
||||
if is_str:
|
||||
fields = re.sub(
|
||||
r"\(.*?\)", "", fields.casefold().replace(str(type(function).__name__).strip().casefold(), "")
|
||||
fields = BRACKETS_PATTERN.sub(
|
||||
"", fields.replace(str(type(function).__name__).strip().casefold(), "")
|
||||
)
|
||||
|
||||
else:
|
||||
updated_fields = []
|
||||
for field in fields:
|
||||
if isinstance(field, str):
|
||||
updated_fields.append(
|
||||
re.sub(r"\(.*?\)", "", field)
|
||||
BRACKETS_PATTERN.sub("", field)
|
||||
.strip()
|
||||
.casefold()
|
||||
.replace(str(type(function).__name__).strip().casefold(), "")
|
||||
|
|
@ -433,11 +440,12 @@ class Engine:
|
|||
else:
|
||||
updated_fields.append(field)
|
||||
|
||||
fields = updated_fields
|
||||
fields = [field for field in updated_fields if field]
|
||||
|
||||
if is_str and "," in fields:
|
||||
fields = fields.split(",")
|
||||
fields = [field.replace(" ", "") if "as" not in field else field for field in fields]
|
||||
is_list, is_str = True, False
|
||||
|
||||
if is_str:
|
||||
if fields == "*":
|
||||
|
|
@ -469,7 +477,7 @@ class Engine:
|
|||
fields.extend(function_objects)
|
||||
return fields
|
||||
|
||||
def get_sql(
|
||||
def get_query(
|
||||
self,
|
||||
table: str,
|
||||
fields: Union[List, Tuple],
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class TestReportview(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_none_filter(self):
|
||||
query = frappe.qb.engine.get_sql("DocType", fields="name", filters={"restrict_to_domain": None})
|
||||
query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None})
|
||||
sql = str(query).replace("`", "").replace('"', "")
|
||||
condition = "restrict_to_domain IS NULL"
|
||||
self.assertIn(condition, sql)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class TestQuery(unittest.TestCase):
|
|||
@run_only_if(db_type_is.MARIADB)
|
||||
def test_multiple_tables_in_filters(self):
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql(
|
||||
frappe.qb.engine.get_query(
|
||||
"DocType",
|
||||
["*"],
|
||||
[
|
||||
|
|
@ -22,52 +22,53 @@ class TestQuery(unittest.TestCase):
|
|||
|
||||
def test_string_fields(self):
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields="name, email", filters={"name": "Administrator"}),
|
||||
frappe.qb.engine.get_query(
|
||||
"User", fields="name, email", filters={"name": "Administrator"}
|
||||
).get_sql(),
|
||||
frappe.qb.from_("User")
|
||||
.select(Field("name"), Field("email"))
|
||||
.where(Field("name") == "Administrator"),
|
||||
.where(Field("name") == "Administrator")
|
||||
.get_sql(),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields=["name, email"], filters={"name": "Administrator"}),
|
||||
frappe.qb.engine.get_query(
|
||||
"User", fields=["name, email"], filters={"name": "Administrator"}
|
||||
).get_sql(),
|
||||
frappe.qb.from_("User")
|
||||
.select(Field("name"), Field("email"))
|
||||
.where(Field("name") == "Administrator"),
|
||||
.where(Field("name") == "Administrator")
|
||||
.get_sql(),
|
||||
)
|
||||
|
||||
def test_functions_fields(self):
|
||||
from frappe.query_builder.functions import Count
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields="Count(name)", filters={}),
|
||||
frappe.qb.from_("User").select(Count(Field("name"))),
|
||||
frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),
|
||||
frappe.qb.from_("User").select(Count(Field("name"))).get_sql(),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields="Count(name), Max(name)", filters={}),
|
||||
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
|
||||
frappe.qb.engine.get_query("User", fields=["Count(name)", "Max(name)"], filters={}).get_sql(),
|
||||
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields=["Count(name)", "Max(name)"], filters={}),
|
||||
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql("User", fields=[Count("*")], filters={}),
|
||||
frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))),
|
||||
frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(),
|
||||
frappe.qb.from_("User").select(Count("*")).get_sql(),
|
||||
)
|
||||
|
||||
def test_qb_fields(self):
|
||||
user_doctype = frappe.qb.DocType("User")
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql(
|
||||
frappe.qb.engine.get_query(
|
||||
user_doctype, fields=[user_doctype.name, user_doctype.email], filters={}
|
||||
),
|
||||
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email),
|
||||
).get_sql(),
|
||||
frappe.qb.from_(user_doctype).select(user_doctype.name, user_doctype.email).get_sql(),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
frappe.qb.engine.get_sql(user_doctype, fields=user_doctype.email, filters={}),
|
||||
frappe.qb.from_(user_doctype).select(user_doctype.email),
|
||||
frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
|
||||
frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue