diff --git a/frappe/desk/doctype/dashboard_chart/test_dashboard_chart.py b/frappe/desk/doctype/dashboard_chart/test_dashboard_chart.py index 5e39998e62..13fea8282d 100644 --- a/frappe/desk/doctype/dashboard_chart/test_dashboard_chart.py +++ b/frappe/desk/doctype/dashboard_chart/test_dashboard_chart.py @@ -4,9 +4,9 @@ from __future__ import unicode_literals import unittest, frappe -from frappe.utils import getdate, formatdate, get_last_day -from frappe.desk.doctype.dashboard_chart.dashboard_chart import (get, - get_period_ending) +from frappe.utils import getdate, formatdate +from frappe.utils.dateutils import get_period_ending, get_period_beginning, get_period +from frappe.desk.doctype.dashboard_chart.dashboard_chart import get from datetime import datetime from dateutil.relativedelta import relativedelta @@ -53,15 +53,11 @@ class TestDashboardChart(unittest.TestCase): cur_date = datetime.now() - relativedelta(years=1) result = get(chart_name='Test Dashboard Chart', refresh=1) - self.assertEqual(result.get('labels')[0], formatdate(cur_date.strftime('%Y-%m-%d'))) - - if formatdate(cur_date.strftime('%Y-%m-%d')) == formatdate(get_last_day(cur_date).strftime('%Y-%m-%d')): - cur_date += relativedelta(months=1) + self.assertEqual(result.get('labels')[0], get_period(cur_date)) for idx in range(1, 13): - month = get_last_day(cur_date) month = formatdate(month.strftime('%Y-%m-%d')) - self.assertEqual(result.get('labels')[idx], month) + self.assertEqual(result.get('labels')[idx], get_period(month)) cur_date += relativedelta(months=1) frappe.db.rollback() @@ -87,15 +83,11 @@ class TestDashboardChart(unittest.TestCase): cur_date = datetime.now() - relativedelta(years=1) result = get(chart_name ='Test Empty Dashboard Chart', refresh=1) - self.assertEqual(result.get('labels')[0], formatdate(cur_date.strftime('%Y-%m-%d'))) - - if formatdate(cur_date.strftime('%Y-%m-%d')) == formatdate(get_last_day(cur_date).strftime('%Y-%m-%d')): - cur_date += relativedelta(months=1) + self.assertEqual(result.get('labels')[0], get_period(cur_date)) for idx in range(1, 13): - month = get_last_day(cur_date) month = formatdate(month.strftime('%Y-%m-%d')) - self.assertEqual(result.get('labels')[idx], month) + self.assertEqual(result.get('labels')[idx], get_period(month)) cur_date += relativedelta(months=1) frappe.db.rollback() @@ -124,15 +116,11 @@ class TestDashboardChart(unittest.TestCase): cur_date = datetime.now() - relativedelta(years=1) result = get(chart_name ='Test Empty Dashboard Chart 2', refresh = 1) - self.assertEqual(result.get('labels')[0], formatdate(cur_date.strftime('%Y-%m-%d'))) - - if formatdate(cur_date.strftime('%Y-%m-%d')) == formatdate(get_last_day(cur_date).strftime('%Y-%m-%d')): - cur_date += relativedelta(months=1) + self.assertEqual(result.get('labels')[0], get_period(cur_date)) for idx in range(1, 13): - month = get_last_day(cur_date) month = formatdate(month.strftime('%Y-%m-%d')) - self.assertEqual(result.get('labels')[idx], month) + self.assertEqual(result.get('labels')[idx], get_period(month)) cur_date += relativedelta(months=1) # only 1 data point with value diff --git a/frappe/utils/dateutils.py b/frappe/utils/dateutils.py index 2895eb0568..06b434a512 100644 --- a/frappe/utils/dateutils.py +++ b/frappe/utils/dateutils.py @@ -5,8 +5,7 @@ from __future__ import unicode_literals import frappe import frappe.defaults import datetime -from frappe.utils import get_datetime -from frappe.utils import add_to_date, getdate +from frappe.utils import get_datetime, add_to_date, getdate from frappe.utils.data import get_first_day, get_first_day_of_week, get_quarter_start, get_year_start,\ get_last_day, get_last_day_of_week, get_quarter_ending, get_year_ending from six import string_types @@ -130,32 +129,24 @@ def get_period(date, interval='Monthly'): 'Yearly': str(date.year) }[interval] -def get_period_beginning(date, timegrain): - as_str = True - if timegrain == 'Daily': - pass - elif timegrain == 'Weekly': - date = get_first_day_of_week(date, as_str=as_str) - elif timegrain == 'Monthly': - date = get_first_day(date, as_str=as_str) - elif timegrain == 'Quarterly': - date = get_quarter_start(date, as_str=as_str) - elif timegrain == 'Yearly': - date = get_year_start(date, as_str=as_str) - - return date +def get_period_beginning(date, timegrain, as_str=True): + return getdate({ + 'Daily': date, + 'Weekly': get_first_day_of_week(date), + 'Monthly': get_first_day(date), + 'Quarterly': get_quarter_start(date), + 'Yearly': get_year_start(date) + }[timegrain]) def get_period_ending(date, timegrain): date = getdate(date) if timegrain == 'Daily': - pass - elif timegrain == 'Weekly': - date = get_last_day_of_week(date) - elif timegrain == 'Monthly': - date = get_last_day(date) - elif timegrain == 'Quarterly': - date = get_quarter_ending(date) - elif timegrain == 'Yearly': - date = get_year_ending(date) - - return getdate(date) + return date + else: + return getdate({ + 'Daily': date, + 'Weekly': get_last_day_of_week(date), + 'Monthly': get_last_day(date), + 'Quarterly': get_quarter_ending(date), + 'Yearly': get_year_ending(date) + }[timegrain])