From 1ff85611ff6d04ea61553d39d32e7ff7e62d547a Mon Sep 17 00:00:00 2001 From: Sagar Vora <16315650+sagarvora@users.noreply.github.com> Date: Mon, 8 Sep 2025 16:35:23 +0530 Subject: [PATCH] feat: new `extend_doctype_class` hook --- frappe/model/base_document.py | 37 ++++++++- frappe/tests/test_base_document.py | 128 ++++++++++++++++++++++++++++- 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/frappe/model/base_document.py b/frappe/model/base_document.py index aa22772b9b..f02f9ac23e 100644 --- a/frappe/model/base_document.py +++ b/frappe/model/base_document.py @@ -127,7 +127,42 @@ def import_controller(doctype): if not issubclass(class_, BaseDocument): raise ImportError(f"{doctype}: {classname} is not a subclass of BaseDocument") - return class_ + return get_extended_class(class_, doctype) + + +def get_extended_class(base_class, doctype): + """Create an extended class by mixing extension classes with the base class. + + Args: + base_class: The base document class + doctype: The doctype name + + Returns: + Extended class that combines all extension classes with the base class + """ + + extensions = frappe.get_hooks("extend_doctype_class", {}).get(doctype) + if not extensions: + return base_class + + # Get extension classes in reverse order using frappe.get_attr + extension_classes = [] + for extension_path in reversed(extensions): + try: + extension_class = frappe.get_attr(extension_path) + except Exception: + frappe.throw( + _("Error retrieving extension class from path:
{0}").format(extension_path) + ) + + extension_classes.append(extension_class) + + # Create the extended class by combining extension classes with base class + # Extension classes come first in MRO, then base class + class_name = f"Extended{base_class.__name__}" + extended_class = type(class_name, (*extension_classes, base_class), {}) + + return extended_class RESERVED_KEYWORDS = frozenset( diff --git a/frappe/tests/test_base_document.py b/frappe/tests/test_base_document.py index 59f0dcf81b..e08e9def5b 100644 --- a/frappe/tests/test_base_document.py +++ b/frappe/tests/test_base_document.py @@ -1,7 +1,30 @@ -from frappe.model.base_document import BaseDocument +import frappe +from frappe.desk.doctype.todo.todo import ToDo +from frappe.model.base_document import BaseDocument, get_extended_class from frappe.tests import IntegrationTestCase +class TestExtensionA(BaseDocument): + def extension_method_a(self): + return "method_a" + + +class TestExtensionB(BaseDocument): + def extension_method_b(self): + return "method_b" + + +class TestToDoExtension(BaseDocument): + """Extension class that overrides ToDo's validate method""" + + def validate(self): + # Add our custom logic + self.custom_validation_called = True + + def extension_method(self): + return "extension_method_called" + + class TestBaseDocument(IntegrationTestCase): def test_docstatus(self): doc = BaseDocument({"docstatus": 0, "doctype": "ToDo"}) @@ -15,3 +38,106 @@ class TestBaseDocument(IntegrationTestCase): doc.docstatus = 2 self.assertTrue(doc.docstatus.is_cancelled()) self.assertEqual(doc.docstatus, 2) + + def test_get_extended_class_with_no_extensions(self): + """Test that get_extended_class returns the base class when no extensions are provided.""" + + with self.patch_hooks({"extend_doctype_class": {}}): + result = get_extended_class(ToDo, "ToDo") + self.assertEqual(result, ToDo) + + with self.patch_hooks({"extend_doctype_class": {"ToDo": []}}): + result = get_extended_class(ToDo, "ToDo") + self.assertEqual(result, ToDo) + + def test_get_extended_class_with_extensions(self): + """Test that get_extended_class properly combines extension classes with base class.""" + # Mock frappe.get_hooks to return extension paths + extensions = [ + "frappe.tests.test_base_document.TestExtensionA", + "frappe.tests.test_base_document.TestExtensionB", + ] + + with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): + extended_class = get_extended_class(ToDo, "ToDo") + + # Test that the extended class is different from base class + self.assertNotEqual(extended_class, ToDo) + + # Test that the extended class has all methods from extensions and base + instance = extended_class({"doctype": "ToDo"}) + self.assertTrue(hasattr(instance, "extension_method_a")) + self.assertTrue(hasattr(instance, "extension_method_b")) + + # Test that methods work correctly + self.assertEqual(instance.extension_method_a(), "method_a") + self.assertEqual(instance.extension_method_b(), "method_b") + + # Test MRO (Method Resolution Order) - extensions should come first in reverse order + mro_classes = [cls.__name__ for cls in extended_class.__mro__] + self.assertIn("TestExtensionB", mro_classes) + self.assertIn("TestExtensionA", mro_classes) + self.assertIn("ToDo", mro_classes) + + # TestExtensionB should come before TestExtensionA (reverse order) + idx_b = mro_classes.index("TestExtensionB") + idx_a = mro_classes.index("TestExtensionA") + idx_base = mro_classes.index("ToDo") + self.assertLess(idx_b, idx_a) + self.assertLess(idx_a, idx_base) + + def test_extension_overrides_todo_method(self): + """Test that an extension can override methods from the actual ToDo class""" + from frappe.desk.doctype.todo.todo import ToDo + + # Mock the hooks to include our ToDo extension + extensions = ["frappe.tests.test_base_document.TestToDoExtension"] + + with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): + extended_class = get_extended_class(ToDo, "ToDo") + + # Test that the extended class is different from base ToDo + self.assertNotEqual(extended_class, ToDo) + + # Create an instance of the extended ToDo + instance = extended_class({"doctype": "ToDo"}) + + # Test that extension method is available + self.assertTrue(hasattr(instance, "extension_method")) + self.assertEqual(instance.extension_method(), "extension_method_called") + + # Test that the validate method is overridden + # The extension's validate method should set custom_validation_called = True + instance.validate() + self.assertTrue(getattr(instance, "custom_validation_called", False)) + + # Test MRO - extension should come before ToDo class + mro_classes = [cls.__name__ for cls in extended_class.__mro__] + self.assertIn("TestToDoExtension", mro_classes) + self.assertIn("ToDo", mro_classes) + + # TestToDoExtension should come before ToDo + idx_extension = mro_classes.index("TestToDoExtension") + idx_todo = mro_classes.index("ToDo") + self.assertLess(idx_extension, idx_todo) + + def test_extension_invalid_path_raises_exception(self): + """Test that an invalid extension path raises an appropriate exception""" + from frappe.desk.doctype.todo.todo import ToDo + + # Mock the hooks to include an invalid extension path + path_to_invalid_extension = "invalid.module.path.NonExistentClass" + + extensions = [ + "frappe.tests.test_base_document.TestExtensionA", # valid + path_to_invalid_extension, # invalid + ] + + with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): + # Test that frappe.ValidationError is raised for invalid extension path + with self.assertRaises(frappe.ValidationError) as context: + get_extended_class(ToDo, "ToDo") + + # Check that the error message mentions the invalid path + error_message = str(context.exception) + self.assertIn(path_to_invalid_extension, error_message)