Merge pull request #25295 from ankush/virtual-doctype-dx

fix(DX)!: virtual doctype APIs
This commit is contained in:
Ankush Menat 2024-03-11 18:37:20 +05:30 committed by GitHub
commit 8a3dd85503
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 75 additions and 79 deletions

View file

@ -60,15 +60,15 @@ class PermissionInspector(Document):
...
@staticmethod
def get_list(args):
def get_list():
...
@staticmethod
def get_count(args):
def get_count():
...
@staticmethod
def get_stats(args):
def get_stats():
...
def delete(self):

View file

@ -39,12 +39,10 @@ class Recorder(Document):
super(Document, self).__init__(request)
@staticmethod
def get_list(args):
start = cint(args.get("start"))
page_length = cint(args.get("page_length")) or 20
requests = Recorder.get_filtered_requests(args)[start : start + page_length]
def get_list(filters=None, start=0, page_length=20, order_by="duration desc"):
requests = Recorder.get_filtered_requests(filters)[start : start + page_length]
if order_by_statment := args.get("order_by"):
if order_by_statment := order_by:
if "." in order_by_statment:
order_by_statment = order_by_statment.split(".")[1]
@ -60,12 +58,11 @@ class Recorder(Document):
return sorted(requests, key=lambda r: r.duration, reverse=1)
@staticmethod
def get_count(args):
return len(Recorder.get_filtered_requests(args))
def get_count(filters=None):
return len(Recorder.get_filtered_requests(filters))
@staticmethod
def get_filtered_requests(args):
filters = args.get("filters")
def get_filtered_requests(filters):
requests = [serialize_request(request) for request in get_recorder_data()]
return [req for req in requests if evaluate_filters(req, filters)]

View file

@ -39,15 +39,15 @@ class RecorderQuery(Document):
pass
@staticmethod
def get_list(args):
def get_list():
pass
@staticmethod
def get_count(args):
def get_count():
pass
@staticmethod
def get_stats(args):
def get_stats():
pass
def delete(self):

View file

@ -76,22 +76,18 @@ class RQJob(Document):
return self._job_obj
@staticmethod
def get_list(args):
start = cint(args.get("start"))
page_length = cint(args.get("page_length")) or 20
order_desc = "desc" in args.get("order_by", "")
matched_job_ids = RQJob.get_matching_job_ids(args)[start : start + page_length]
def get_list(filters=None, start=0, page_length=20, order_by="modified desc"):
matched_job_ids = RQJob.get_matching_job_ids(filters=filters)[start : start + page_length]
conn = get_redis_conn()
jobs = [serialize_job(job) for job in Job.fetch_many(job_ids=matched_job_ids, connection=conn) if job]
order_desc = "desc" in order_by
return sorted(jobs, key=lambda j: j.modified, reverse=order_desc)
@staticmethod
def get_matching_job_ids(args) -> list[str]:
filters = make_filter_dict(args.get("filters"))
def get_matching_job_ids(filters) -> list[str]:
filters = make_filter_dict(filters or [])
queues = _eval_filters(filters.get("queue"), QUEUES)
statuses = _eval_filters(filters.get("status"), JOB_STATUSES)
@ -117,12 +113,12 @@ class RQJob(Document):
frappe.msgprint(_("Job is not running."), title=_("Invalid Operation"))
@staticmethod
def get_count(args) -> int:
return len(RQJob.get_matching_job_ids(args))
def get_count(filters=None) -> int:
return len(RQJob.get_matching_job_ids(filters))
# None of these methods apply to virtual job doctype, overriden for sanity.
@staticmethod
def get_stats(args):
def get_stats():
return {}
def db_insert(self, *args, **kwargs):

View file

@ -61,18 +61,22 @@ class TestRQJob(FrappeTestCase):
def test_get_list_filtering(self):
# Check failed job clearning and filtering
remove_failed_jobs()
jobs = RQJob.get_list({"filters": [["RQ Job", "status", "=", "failed"]]})
jobs = frappe.get_all("RQ Job", {"status": "failed"})
self.assertEqual(jobs, [])
# Pass a job
job = frappe.enqueue(method=self.BG_JOB, queue="short")
self.check_status(job, "finished")
# Fail a job
job = frappe.enqueue(method=self.BG_JOB, queue="short", fail=True)
self.check_status(job, "failed")
jobs = RQJob.get_list({"filters": [["RQ Job", "status", "=", "failed"]]})
jobs = frappe.get_all("RQ Job", {"status": "failed"})
self.assertEqual(len(jobs), 1)
self.assertTrue(jobs[0].exc_info)
# Assert that non-failed job still exists
non_failed_jobs = RQJob.get_list({"filters": [["RQ Job", "status", "!=", "failed"]]})
non_failed_jobs = frappe.get_all("RQ Job", {"status": ("!=", "failed")})
self.assertGreaterEqual(len(non_failed_jobs), 1)
# Create a slow job and check if it's stuck in "Started"
@ -174,7 +178,7 @@ class TestRQJob(FrappeTestCase):
jobs = [frappe.enqueue(method=self.BG_JOB, queue="short", fail=True) for _ in range(limit * 2)]
self.check_status(jobs[-1], "failed")
self.assertLessEqual(RQJob.get_count({"filters": [["RQ Job", "status", "=", "failed"]]}), limit * 1.1)
self.assertLessEqual(RQJob.get_count(filters=[["RQ Job", "status", "=", "failed"]]), limit * 1.1)
def test_func(fail=False, sleep=0):

View file

@ -46,22 +46,19 @@ class RQWorker(Document):
super(Document, self).__init__(d)
@staticmethod
def get_list(args):
start = cint(args.get("start"))
page_length = cint(args.get("page_length")) or 20
def get_list(start=0, page_length=20):
workers = get_workers()
valid_workers = [w for w in workers if w.pid][start : start + page_length]
return [serialize_worker(worker) for worker in valid_workers]
@staticmethod
def get_count(args) -> int:
def get_count() -> int:
return len(get_workers())
# None of these methods apply to virtual workers, overriden for sanity.
@staticmethod
def get_stats(args):
def get_stats():
return {}
def db_insert(self, *args, **kwargs):

View file

@ -8,10 +8,10 @@ from frappe.tests.utils import FrappeTestCase
class TestRQWorker(FrappeTestCase):
def test_get_worker_list(self):
workers = RQWorker.get_list({})
workers = RQWorker.get_list()
self.assertGreaterEqual(len(workers), 1)
self.assertTrue(any("short" in w.queue_type for w in workers))
def test_worker_serialization(self):
workers = RQWorker.get_list({})
workers = RQWorker.get_list()
frappe.get_doc("RQ Worker", workers[0].name)

View file

@ -13,7 +13,7 @@ from frappe.model import child_table_fields, default_fields, get_permitted_field
from frappe.model.base_document import get_controller
from frappe.model.db_query import DatabaseQuery
from frappe.model.utils import is_virtual_doctype
from frappe.utils import add_user_info, format_duration
from frappe.utils import add_user_info, cint, format_duration
@frappe.whitelist()
@ -23,7 +23,7 @@ def get():
# If virtual doctype, get data from controller get_list method
if is_virtual_doctype(args.doctype):
controller = get_controller(args.doctype)
data = compress(controller.get_list(args))
data = compress(frappe.call(controller.get_list, args=args, **args))
else:
data = compress(execute(**args), args=args)
return data
@ -36,7 +36,7 @@ def get_list():
if is_virtual_doctype(args.doctype):
controller = get_controller(args.doctype)
data = controller.get_list(args)
data = frappe.call(controller.get_list, args=args, **args)
else:
# uncompressed (refactored from frappe.model.db_query.get_list)
data = execute(**args)
@ -51,7 +51,7 @@ def get_count() -> int:
if is_virtual_doctype(args.doctype):
controller = get_controller(args.doctype)
data = controller.get_count(args)
data = frappe.call(controller.get_count, args=args, **args)
else:
distinct = "distinct " if args.distinct == "true" else ""
args.fields = [f"count({distinct}`tab{args.doctype}`.name) as total_count"]
@ -227,6 +227,10 @@ def parse_json(data):
data["save_user_settings"] = json.loads(data["save_user_settings"])
else:
data["save_user_settings"] = True
if isinstance(data.get("start"), str):
data["start"] = cint(data.get("start"))
if isinstance(data.get("page_length"), str):
data["page_length"] = cint(data.get("page_length"))
def get_parenttype_and_fieldname(field, data):
@ -509,7 +513,7 @@ def get_sidebar_stats(stats, doctype, filters=None):
if is_virtual_doctype(doctype):
controller = get_controller(doctype)
args = {"stats": stats, "filters": filters}
data = controller.get_stats(args)
data = frappe.call(controller.get_stats, args=args, **args)
else:
data = get_stats(stats, doctype, filters)

View file

@ -180,7 +180,7 @@ class DatabaseQuery:
"pluck": pluck,
"parent_doctype": parent_doctype,
} | self.__dict__
return controller.get_list(kwargs)
return frappe.call(controller.get_list, args=kwargs, **kwargs)
self.columns = self.get_table_columns()

View file

@ -300,8 +300,9 @@ class Document(BaseDocument):
self.db_insert(ignore_if_duplicate=ignore_if_duplicate)
# children
for d in self.get_all_children():
d.db_insert()
if not getattr(self.meta, "is_virtual", False):
for d in self.get_all_children():
d.db_insert()
self.run_method("after_insert")
self.flags.in_insert = True
@ -415,6 +416,9 @@ class Document(BaseDocument):
def update_children(self):
"""update child tables"""
if getattr(self.meta, "is_virtual", False):
# Virtual doctypes manage their own children
return
for df in self.meta.get_table_fields():
self.update_child_table(df.fieldname, df)

View file

@ -21,17 +21,17 @@ class VirtualDoctype(Protocol):
# ============ class/static methods ============
@staticmethod
def get_list(args) -> list[frappe._dict]:
def get_list(**kwargs) -> list[frappe._dict]:
"""Similar to reportview.get_list"""
...
@staticmethod
def get_count(args) -> int:
def get_count(**kwargs) -> int:
"""Similar to reportview.get_count, return total count of documents on listview."""
...
@staticmethod
def get_stats(args):
def get_stats(**kwargs):
"""Similar to reportview.get_stats, return sidebar stats."""
...

View file

@ -295,24 +295,27 @@ def make_boilerplate(
dedent(
"""
def db_insert(self, *args, **kwargs):
pass
raise NotImplementedError
def load_from_db(self):
pass
raise NotImplementedError
def db_update(self):
raise NotImplementedError
def delete(self):
raise NotImplementedError
@staticmethod
def get_list(filters=None, page_length=20, **kwargs):
pass
@staticmethod
def get_list(args):
def get_count(filters=None, **kwargs):
pass
@staticmethod
def get_count(args):
pass
@staticmethod
def get_stats(args):
def get_stats(**kwargs):
pass
"""
),

View file

@ -1075,27 +1075,19 @@ class TestDBQuery(FrappeTestCase):
class VirtualDocType:
@staticmethod
def get_list(args):
...
def get_list(args=None, limit_page_length=0, doctype=None):
# Backward compatibility
self.assertEqual(args["filters"], [["Virtual DocType", "name", "=", "test"]])
self.assertEqual(limit_page_length, 1)
self.assertEqual(doctype, "Virtual DocType")
with patch(
"frappe.controllers",
new={frappe.local.site: {"Virtual DocType": VirtualDocType}},
):
VirtualDocType.get_list = MagicMock()
frappe.get_all("Virtual DocType", filters={"name": "test"}, fields=["name"], limit=1)
call_args = VirtualDocType.get_list.call_args[0][0]
VirtualDocType.get_list.assert_called_once()
self.assertIsInstance(call_args, dict)
self.assertEqual(call_args["doctype"], "Virtual DocType")
self.assertEqual(call_args["filters"], [["Virtual DocType", "name", "=", "test"]])
self.assertEqual(call_args["fields"], ["name"])
self.assertEqual(call_args["limit_page_length"], 1)
self.assertEqual(call_args["limit_start"], 0)
self.assertEqual(call_args["order_by"], DefaultOrderBy)
def test_coalesce_with_in_ops(self):
self.assertNotIn("ifnull", frappe.get_all("User", {"first_name": ("in", ["a", "b"])}, run=0))
self.assertIn("ifnull", frappe.get_all("User", {"first_name": ("in", ["a", None])}, run=0))

View file

@ -68,17 +68,17 @@ class VirtualDoctypeTest(Document):
self.update_data(data)
@staticmethod
def get_list(args):
def get_list():
data = VirtualDoctypeTest.get_current_data()
return [frappe._dict(doc) for name, doc in data.items()]
@staticmethod
def get_count(args):
def get_count():
data = VirtualDoctypeTest.get_current_data()
return len(data)
@staticmethod
def get_stats(args):
def get_stats():
return {}
@ -157,19 +157,18 @@ class TestVirtualDoctypes(FrappeTestCase):
updated_docs = {doc1.name, doc2.name}
self.assertEqual(docs, updated_docs)
listed_docs = {d.name for d in VirtualDoctypeTest.get_list({})}
listed_docs = {d.name for d in VirtualDoctypeTest.get_list()}
self.assertEqual(docs, listed_docs)
def test_get_count(self):
args = {"doctype": TEST_DOCTYPE_NAME, "filters": [], "fields": []}
self.assertIsInstance(VirtualDoctypeTest.get_count(args), int)
self.assertIsInstance(VirtualDoctypeTest.get_count(), int)
def test_delete_doc(self):
doc = frappe.get_doc(doctype=TEST_DOCTYPE_NAME).insert()
frappe.delete_doc(doc.doctype, doc.name)
listed_docs = {d.name for d in VirtualDoctypeTest.get_list({})}
listed_docs = {d.name for d in VirtualDoctypeTest.get_list()}
self.assertNotIn(doc.name, listed_docs)
def test_controller_validity(self):