From 73eb7806a85aacf641374fd356d08ebe87a2e2d8 Mon Sep 17 00:00:00 2001 From: Aradhya-Tripathi Date: Thu, 9 Sep 2021 18:05:44 +0530 Subject: [PATCH] refactor: removed aggregation from database.py refactor: moved aggregate to frappe.query --- .../package_release/package_release.py | 10 ++- frappe/database/database.py | 77 ------------------- frappe/tests/test_db.py | 6 -- frappe/utils/safe_exec.py | 7 +- 4 files changed, 8 insertions(+), 92 deletions(-) diff --git a/frappe/core/doctype/package_release/package_release.py b/frappe/core/doctype/package_release/package_release.py index 1fb8796882..fdad2edfc6 100644 --- a/frappe/core/doctype/package_release/package_release.py +++ b/frappe/core/doctype/package_release/package_release.py @@ -7,15 +7,19 @@ from frappe.modules.export_file import export_doc import os import subprocess + class PackageRelease(Document): def set_version(self): # set the next patch release by default + from frappe.query_builder.functions import Max + from frappe.query_builder import Field + if not self.major: - self.major = frappe.db.max('Package Release', 'major', dict(package=self.package)) + self.major = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("major")).run()[0][0] or 0 if not self.minor: - self.minor = frappe.db.max('Package Release', 'minor', dict(package=self.package)) + self.minor = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("minor")).run()[0][0] or 0 if not self.patch: - self.patch = frappe.db.max('Package Release', 'patch', dict(package=self.package)) + 1 + self.patch = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("patch")).run()[0][0] or 1 def autoname(self): self.set_version() diff --git a/frappe/database/database.py b/frappe/database/database.py index acc6f804b4..2b7c52cd47 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -314,59 +314,6 @@ class Database(object): nres.append(nr) return nres - def build_conditions(self, filters): - """Convert filters sent as dict, lists to SQL conditions. filter's key - is passed by map function, build conditions like: - - * ifnull(`fieldname`, default_value) = %(fieldname)s - * `fieldname` [=, !=, >, >=, <, <=] %(fieldname)s - """ - conditions = [] - values = {} - def _build_condition(key): - """ - filter's key is passed by map function - build conditions like: - * ifnull(`fieldname`, default_value) = %(fieldname)s - * `fieldname` [=, !=, >, >=, <, <=] %(fieldname)s - """ - _operator = "=" - _rhs = " %(" + key + ")s" - value = filters.get(key) - values[key] = value - if isinstance(value, (list, tuple)): - # value is a tuple like ("!=", 0) - _operator = value[0].lower() - values[key] = value[1] - if isinstance(value[1], (tuple, list)): - # value is a list in tuple ("in", ("A", "B")) - _rhs = " ({0})".format(", ".join(self.escape(v) for v in value[1])) - del values[key] - - if _operator not in ["=", "!=", ">", ">=", "<", "<=", "like", "in", "not in", "not like"]: - _operator = "=" - - if "[" in key: - split_key = key.split("[") - condition = "coalesce(`" + split_key[0] + "`, " + split_key[1][:-1] + ") " \ - + _operator + _rhs - else: - condition = "`" + key + "` " + _operator + _rhs - - conditions.append(condition) - - if isinstance(filters, int): - # docname is a number, convert to string - filters = str(filters) - - if isinstance(filters, str): - filters = { "name": filters } - - for f in filters: - _build_condition(f) - - return " and ".join(conditions), values - def get(self, doctype, filters=None, as_dict=True, cache=False): """Returns `get_value` with fieldname='*'""" return self.get_value(doctype, filters, "*", as_dict=as_dict, cache=cache) @@ -822,30 +769,6 @@ class Database(object): frappe.cache().set_value('doctype:count:{}'.format(dt), count, expires_in_sec = 86400) return count - def sum(self, dt, fieldname, filters=None): - return self._get_aggregation('SUM', dt, fieldname, filters) - - def avg(self, dt, fieldname, filters=None): - return self._get_aggregation('AVG', dt, fieldname, filters) - - def min(self, dt, fieldname, filters=None): - return self._get_aggregation('MIN', dt, fieldname, filters) - - def max(self, dt, fieldname, filters=None): - return self._get_aggregation('MAX', dt, fieldname, filters) - - def _get_aggregation(self, function, dt, fieldname, filters=None): - if not self.has_column(dt, fieldname): - frappe.throw(frappe._('Invalid column'), self.InvalidColumnName) - - query = f'SELECT {function}({fieldname}) AS value FROM `tab{dt}`' - values = () - if filters: - conditions, values = self.build_conditions(filters) - query = f"{query} WHERE {conditions}" - - return self.sql(query, values)[0][0] or 0 - @staticmethod def format_date(date): return getdate(date).strftime("%Y-%m-%d") diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 72bec78db7..20f38dc964 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -46,12 +46,6 @@ class TestDB(unittest.TestCase): def test_escape(self): frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8")) - def test_aggregation(self): - self.assertTrue(type(frappe.db.sum('DocField', 'permlevel', dict(parent=('like', 'doc')))) in (int, float)) - self.assertTrue(type(frappe.db.avg('DocField', 'permlevel')) in (int, float)) - self.assertTrue(type(frappe.db.min('DocField', 'permlevel')) in (int, float)) - self.assertTrue(type(frappe.db.max('DocField', 'permlevel')) in (int, float)) - def test_get_single_value(self): #setup values_dict = { diff --git a/frappe/utils/safe_exec.py b/frappe/utils/safe_exec.py index e18c498b3c..7ccd80e346 100644 --- a/frappe/utils/safe_exec.py +++ b/frappe/utils/safe_exec.py @@ -147,12 +147,7 @@ def get_safe_globals(): get_single_value = frappe.db.get_single_value, get_default = frappe.db.get_default, escape = frappe.db.escape, - sql = read_sql, - sum = frappe.db.sum, - avg = frappe.db.avg, - count = frappe.db.count, - min = frappe.db.min, - max = frappe.db.max + sql = read_sql ) if frappe.response: