fix: minor improvements to extended class logic

This commit is contained in:
Sagar Vora 2025-09-09 11:07:00 +05:30
parent b4bd4f756d
commit c5bc3875df
2 changed files with 32 additions and 47 deletions

View file

@ -70,23 +70,22 @@ UNPICKLABLE_KEYS = (
) )
def _reconstruct_extended_instance(doctype, state): def _reduce_extended_instance(doc):
"""Helper function to reconstruct an extended class instance during unpickling. """Make extended class instances pickle-able.
This function is called during unpickling to recreate the extended class When unpickling, this will use get_controller() to recreate the extended class.
based on current hooks and restore the instance state. Respects the __getstate__ method for proper state handling.
"""
return (_reconstruct_extended_instance, (doc.doctype,), doc.__getstate__())
def _reconstruct_extended_instance(doctype):
"""
Helper function to reconstruct an extended class instance during unpickling.
""" """
# Get the current extended class (uses caching from get_controller) # Get the current extended class (uses caching from get_controller)
extended_class = get_controller(doctype) extended_class = get_controller(doctype)
instance = extended_class.__new__(extended_class) return extended_class.__new__(extended_class)
# Use __setstate__ if available, otherwise directly update __dict__
if hasattr(instance, "__setstate__"):
instance.__setstate__(state)
else:
instance.__dict__.update(state)
return instance
def get_controller(doctype): def get_controller(doctype):
@ -146,10 +145,10 @@ def import_controller(doctype):
if not issubclass(class_, BaseDocument): if not issubclass(class_, BaseDocument):
raise ImportError(f"{doctype}: {classname} is not a subclass of BaseDocument") raise ImportError(f"{doctype}: {classname} is not a subclass of BaseDocument")
return get_extended_class(class_, doctype) return _get_extended_class(class_, doctype)
def get_extended_class(base_class, doctype): def _get_extended_class(base_class, doctype):
"""Create an extended class by mixing extension classes with the base class. """Create an extended class by mixing extension classes with the base class.
Args: Args:
@ -169,38 +168,24 @@ def get_extended_class(base_class, doctype):
for extension_path in reversed(extensions): for extension_path in reversed(extensions):
try: try:
extension_class = frappe.get_attr(extension_path) extension_class = frappe.get_attr(extension_path)
except Exception: except Exception as e:
frappe.throw( raise ImportError(
_("Error retrieving extension class from path:<br><code>{0}</code>").format(extension_path) "Error retrieving extension class from path:\n{0}".format(extension_path)
) ) from e
extension_classes.append(extension_class) extension_classes.append(extension_class)
# Create the extended class by combining extension classes with base class # Create the extended class by combining extension classes with base class
# Extension classes come first in MRO, then base class # Extension classes come first in MRO, then base class
class_name = f"Extended{base_class.__name__}" return type(
f"Extended{base_class.__name__}",
def __reduce__(self):
"""Make extended class instances pickle-able.
When unpickling, this will use get_controller() to recreate the extended class
based on current hooks, ensuring the instance respects the current environment.
Respects the BaseDocument's __getstate__ method for proper state handling.
"""
return (_reconstruct_extended_instance, (self.doctype, self.__getstate__()))
extended_class = type(
class_name,
(*extension_classes, base_class), (*extension_classes, base_class),
{ {
"__reduce__": __reduce__, "__reduce__": _reduce_extended_instance,
"__module__": base_class.__module__, "__module__": base_class.__module__,
}, },
) )
return extended_class
RESERVED_KEYWORDS = frozenset( RESERVED_KEYWORDS = frozenset(
( (

View file

@ -2,7 +2,7 @@ import pickle
import frappe import frappe
from frappe.desk.doctype.todo.todo import ToDo from frappe.desk.doctype.todo.todo import ToDo
from frappe.model.base_document import BaseDocument, get_extended_class from frappe.model.base_document import BaseDocument, _get_extended_class
from frappe.tests import IntegrationTestCase from frappe.tests import IntegrationTestCase
@ -42,18 +42,18 @@ class TestBaseDocument(IntegrationTestCase):
self.assertEqual(doc.docstatus, 2) self.assertEqual(doc.docstatus, 2)
def test_get_extended_class_with_no_extensions(self): def test_get_extended_class_with_no_extensions(self):
"""Test that get_extended_class returns the base class when no extensions are provided.""" """Test that _get_extended_class returns the base class when no extensions are provided."""
with self.patch_hooks({"extend_doctype_class": {}}): with self.patch_hooks({"extend_doctype_class": {}}):
result = get_extended_class(ToDo, "ToDo") result = _get_extended_class(ToDo, "ToDo")
self.assertEqual(result, ToDo) self.assertEqual(result, ToDo)
with self.patch_hooks({"extend_doctype_class": {"ToDo": []}}): with self.patch_hooks({"extend_doctype_class": {"ToDo": []}}):
result = get_extended_class(ToDo, "ToDo") result = _get_extended_class(ToDo, "ToDo")
self.assertEqual(result, ToDo) self.assertEqual(result, ToDo)
def test_get_extended_class_with_extensions(self): def test_get_extended_class_with_extensions(self):
"""Test that get_extended_class properly combines extension classes with base class.""" """Test that _get_extended_class properly combines extension classes with base class."""
# Mock frappe.get_hooks to return extension paths # Mock frappe.get_hooks to return extension paths
extensions = [ extensions = [
"frappe.tests.test_base_document.TestExtensionA", "frappe.tests.test_base_document.TestExtensionA",
@ -61,7 +61,7 @@ class TestBaseDocument(IntegrationTestCase):
] ]
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
extended_class = get_extended_class(ToDo, "ToDo") extended_class = _get_extended_class(ToDo, "ToDo")
# Test that the extended class is different from base class # Test that the extended class is different from base class
self.assertNotEqual(extended_class, ToDo) self.assertNotEqual(extended_class, ToDo)
@ -96,7 +96,7 @@ class TestBaseDocument(IntegrationTestCase):
extensions = ["frappe.tests.test_base_document.TestToDoExtension"] extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
extended_class = get_extended_class(ToDo, "ToDo") extended_class = _get_extended_class(ToDo, "ToDo")
# Test that the extended class is different from base ToDo # Test that the extended class is different from base ToDo
self.assertNotEqual(extended_class, ToDo) self.assertNotEqual(extended_class, ToDo)
@ -136,9 +136,9 @@ class TestBaseDocument(IntegrationTestCase):
] ]
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
# Test that frappe.ValidationError is raised for invalid extension path # Test that ImportError is raised for invalid extension path
with self.assertRaises(frappe.ValidationError) as context: with self.assertRaises(ImportError) as context:
get_extended_class(ToDo, "ToDo") _get_extended_class(ToDo, "ToDo")
# Check that the error message mentions the invalid path # Check that the error message mentions the invalid path
error_message = str(context.exception) error_message = str(context.exception)
@ -152,7 +152,7 @@ class TestBaseDocument(IntegrationTestCase):
extensions = ["frappe.tests.test_base_document.TestToDoExtension"] extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}): with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
extended_class = get_extended_class(ToDo, "ToDo") extended_class = _get_extended_class(ToDo, "ToDo")
# Create an instance with some data # Create an instance with some data
original_instance = extended_class( original_instance = extended_class(