diff --git a/frappe/database/database.py b/frappe/database/database.py index 3d48f1ebe9..4c5d54f857 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -1050,7 +1050,9 @@ class Database(object): cache_count = frappe.cache().get_value("doctype:count:{}".format(dt)) if cache_count is not None: return cache_count - query = frappe.qb.engine.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct) + query = frappe.qb.engine.get_query( + table=dt, filters=filters, fields=Count("*"), distinct=distinct + ) count = self.sql(query, debug=debug)[0][0] if not filters and cache: frappe.cache().set_value("doctype:count:{}".format(dt), count, expires_in_sec=86400) diff --git a/frappe/database/query.py b/frappe/database/query.py index 10cb62b4a2..b69ab7958f 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -387,16 +387,20 @@ class Engine: _args = [] for arg in args: - field = literal_eval_(arg.strip()) + has_operator = False + initial_fields = literal_eval_(arg.strip()) if to_cast: - try: - operator_fields = arg.split() - field = OPERATOR_MAP[operator_fields[1]]( - Field(operator_fields[0]), - Field(operator_fields[2]), - ) - except IndexError: - field = Field(field) + for _operator in OPERATOR_MAP.keys(): + if _operator in initial_fields: + has_operator = True + field = OPERATOR_MAP[_operator]( + *map(lambda field: Field(field.strip()), arg.split(_operator)) + ) + + field = Field(initial_fields) if not has_operator else field + else: + field = initial_fields + _args.append(field) return getattr(functions, func)(*_args) diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index 8727951f4a..ad9f59b3cd 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -143,7 +143,9 @@ class TestReportview(unittest.TestCase): ) def test_none_filter(self): - query = frappe.qb.engine.get_query("DocType", fields="name", filters={"restrict_to_domain": None}) + query = frappe.qb.engine.get_query( + "DocType", fields="name", filters={"restrict_to_domain": None} + ) sql = str(query).replace("`", "").replace('"', "") condition = "restrict_to_domain IS NULL" self.assertIn(condition, sql) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index e7682a0d0c..88a631ca67 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -42,7 +42,7 @@ class TestQuery(unittest.TestCase): ) def test_functions_fields(self): - from frappe.query_builder.functions import Count, Max + from frappe.query_builder.functions import Abs, Count, Max self.assertEqual( frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(), @@ -54,6 +54,15 @@ class TestQuery(unittest.TestCase): frappe.qb.from_("User").select(Count(Field("name")), Max(Field("name"))).get_sql(), ) + self.assertEqual( + frappe.qb.engine.get_query( + "User", fields=["abs(name-email)", "Count(name)"], filters={} + ).get_sql(), + frappe.qb.from_("User") + .select(Abs(Field("name") - Field("email")), Count(Field("name"))) + .get_sql(), + ) + self.assertEqual( frappe.qb.engine.get_query("User", fields=[Count("*")], filters={}).get_sql(), frappe.qb.from_("User").select(Count("*")).get_sql(),