refactor: removed aggregation from database.py

refactor: moved aggregate to frappe.query
This commit is contained in:
Aradhya-Tripathi 2021-09-09 18:05:44 +05:30
parent 1c2e470792
commit 73eb7806a8
4 changed files with 8 additions and 92 deletions

View file

@ -7,15 +7,19 @@ from frappe.modules.export_file import export_doc
import os
import subprocess
class PackageRelease(Document):
def set_version(self):
# set the next patch release by default
from frappe.query_builder.functions import Max
from frappe.query_builder import Field
if not self.major:
self.major = frappe.db.max('Package Release', 'major', dict(package=self.package))
self.major = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("major")).run()[0][0] or 0
if not self.minor:
self.minor = frappe.db.max('Package Release', 'minor', dict(package=self.package))
self.minor = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("minor")).run()[0][0] or 0
if not self.patch:
self.patch = frappe.db.max('Package Release', 'patch', dict(package=self.package)) + 1
self.patch = frappe.qb.from_("Package Release").where(Field(self.package) == "package").select(Max("patch")).run()[0][0] or 1
def autoname(self):
self.set_version()

View file

@ -314,59 +314,6 @@ class Database(object):
nres.append(nr)
return nres
def build_conditions(self, filters):
"""Convert filters sent as dict, lists to SQL conditions. filter's key
is passed by map function, build conditions like:
* ifnull(`fieldname`, default_value) = %(fieldname)s
* `fieldname` [=, !=, >, >=, <, <=] %(fieldname)s
"""
conditions = []
values = {}
def _build_condition(key):
"""
filter's key is passed by map function
build conditions like:
* ifnull(`fieldname`, default_value) = %(fieldname)s
* `fieldname` [=, !=, >, >=, <, <=] %(fieldname)s
"""
_operator = "="
_rhs = " %(" + key + ")s"
value = filters.get(key)
values[key] = value
if isinstance(value, (list, tuple)):
# value is a tuple like ("!=", 0)
_operator = value[0].lower()
values[key] = value[1]
if isinstance(value[1], (tuple, list)):
# value is a list in tuple ("in", ("A", "B"))
_rhs = " ({0})".format(", ".join(self.escape(v) for v in value[1]))
del values[key]
if _operator not in ["=", "!=", ">", ">=", "<", "<=", "like", "in", "not in", "not like"]:
_operator = "="
if "[" in key:
split_key = key.split("[")
condition = "coalesce(`" + split_key[0] + "`, " + split_key[1][:-1] + ") " \
+ _operator + _rhs
else:
condition = "`" + key + "` " + _operator + _rhs
conditions.append(condition)
if isinstance(filters, int):
# docname is a number, convert to string
filters = str(filters)
if isinstance(filters, str):
filters = { "name": filters }
for f in filters:
_build_condition(f)
return " and ".join(conditions), values
def get(self, doctype, filters=None, as_dict=True, cache=False):
"""Returns `get_value` with fieldname='*'"""
return self.get_value(doctype, filters, "*", as_dict=as_dict, cache=cache)
@ -822,30 +769,6 @@ class Database(object):
frappe.cache().set_value('doctype:count:{}'.format(dt), count, expires_in_sec = 86400)
return count
def sum(self, dt, fieldname, filters=None):
return self._get_aggregation('SUM', dt, fieldname, filters)
def avg(self, dt, fieldname, filters=None):
return self._get_aggregation('AVG', dt, fieldname, filters)
def min(self, dt, fieldname, filters=None):
return self._get_aggregation('MIN', dt, fieldname, filters)
def max(self, dt, fieldname, filters=None):
return self._get_aggregation('MAX', dt, fieldname, filters)
def _get_aggregation(self, function, dt, fieldname, filters=None):
if not self.has_column(dt, fieldname):
frappe.throw(frappe._('Invalid column'), self.InvalidColumnName)
query = f'SELECT {function}({fieldname}) AS value FROM `tab{dt}`'
values = ()
if filters:
conditions, values = self.build_conditions(filters)
query = f"{query} WHERE {conditions}"
return self.sql(query, values)[0][0] or 0
@staticmethod
def format_date(date):
return getdate(date).strftime("%Y-%m-%d")

View file

@ -46,12 +46,6 @@ class TestDB(unittest.TestCase):
def test_escape(self):
frappe.db.escape("香港濟生堂製藥有限公司 - IT".encode("utf-8"))
def test_aggregation(self):
self.assertTrue(type(frappe.db.sum('DocField', 'permlevel', dict(parent=('like', 'doc')))) in (int, float))
self.assertTrue(type(frappe.db.avg('DocField', 'permlevel')) in (int, float))
self.assertTrue(type(frappe.db.min('DocField', 'permlevel')) in (int, float))
self.assertTrue(type(frappe.db.max('DocField', 'permlevel')) in (int, float))
def test_get_single_value(self):
#setup
values_dict = {

View file

@ -147,12 +147,7 @@ def get_safe_globals():
get_single_value = frappe.db.get_single_value,
get_default = frappe.db.get_default,
escape = frappe.db.escape,
sql = read_sql,
sum = frappe.db.sum,
avg = frappe.db.avg,
count = frappe.db.count,
min = frappe.db.min,
max = frappe.db.max
sql = read_sql
)
if frappe.response: