fix: minor improvements to extended class logic
This commit is contained in:
parent
b4bd4f756d
commit
c5bc3875df
2 changed files with 32 additions and 47 deletions
|
|
@ -70,23 +70,22 @@ UNPICKLABLE_KEYS = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reconstruct_extended_instance(doctype, state):
|
def _reduce_extended_instance(doc):
|
||||||
"""Helper function to reconstruct an extended class instance during unpickling.
|
"""Make extended class instances pickle-able.
|
||||||
|
|
||||||
This function is called during unpickling to recreate the extended class
|
When unpickling, this will use get_controller() to recreate the extended class.
|
||||||
based on current hooks and restore the instance state.
|
Respects the __getstate__ method for proper state handling.
|
||||||
|
"""
|
||||||
|
return (_reconstruct_extended_instance, (doc.doctype,), doc.__getstate__())
|
||||||
|
|
||||||
|
|
||||||
|
def _reconstruct_extended_instance(doctype):
|
||||||
|
"""
|
||||||
|
Helper function to reconstruct an extended class instance during unpickling.
|
||||||
"""
|
"""
|
||||||
# Get the current extended class (uses caching from get_controller)
|
# Get the current extended class (uses caching from get_controller)
|
||||||
extended_class = get_controller(doctype)
|
extended_class = get_controller(doctype)
|
||||||
instance = extended_class.__new__(extended_class)
|
return extended_class.__new__(extended_class)
|
||||||
|
|
||||||
# Use __setstate__ if available, otherwise directly update __dict__
|
|
||||||
if hasattr(instance, "__setstate__"):
|
|
||||||
instance.__setstate__(state)
|
|
||||||
else:
|
|
||||||
instance.__dict__.update(state)
|
|
||||||
|
|
||||||
return instance
|
|
||||||
|
|
||||||
|
|
||||||
def get_controller(doctype):
|
def get_controller(doctype):
|
||||||
|
|
@ -146,10 +145,10 @@ def import_controller(doctype):
|
||||||
if not issubclass(class_, BaseDocument):
|
if not issubclass(class_, BaseDocument):
|
||||||
raise ImportError(f"{doctype}: {classname} is not a subclass of BaseDocument")
|
raise ImportError(f"{doctype}: {classname} is not a subclass of BaseDocument")
|
||||||
|
|
||||||
return get_extended_class(class_, doctype)
|
return _get_extended_class(class_, doctype)
|
||||||
|
|
||||||
|
|
||||||
def get_extended_class(base_class, doctype):
|
def _get_extended_class(base_class, doctype):
|
||||||
"""Create an extended class by mixing extension classes with the base class.
|
"""Create an extended class by mixing extension classes with the base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -169,38 +168,24 @@ def get_extended_class(base_class, doctype):
|
||||||
for extension_path in reversed(extensions):
|
for extension_path in reversed(extensions):
|
||||||
try:
|
try:
|
||||||
extension_class = frappe.get_attr(extension_path)
|
extension_class = frappe.get_attr(extension_path)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
frappe.throw(
|
raise ImportError(
|
||||||
_("Error retrieving extension class from path:<br><code>{0}</code>").format(extension_path)
|
"Error retrieving extension class from path:\n{0}".format(extension_path)
|
||||||
)
|
) from e
|
||||||
|
|
||||||
extension_classes.append(extension_class)
|
extension_classes.append(extension_class)
|
||||||
|
|
||||||
# Create the extended class by combining extension classes with base class
|
# Create the extended class by combining extension classes with base class
|
||||||
# Extension classes come first in MRO, then base class
|
# Extension classes come first in MRO, then base class
|
||||||
class_name = f"Extended{base_class.__name__}"
|
return type(
|
||||||
|
f"Extended{base_class.__name__}",
|
||||||
def __reduce__(self):
|
|
||||||
"""Make extended class instances pickle-able.
|
|
||||||
|
|
||||||
When unpickling, this will use get_controller() to recreate the extended class
|
|
||||||
based on current hooks, ensuring the instance respects the current environment.
|
|
||||||
Respects the BaseDocument's __getstate__ method for proper state handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return (_reconstruct_extended_instance, (self.doctype, self.__getstate__()))
|
|
||||||
|
|
||||||
extended_class = type(
|
|
||||||
class_name,
|
|
||||||
(*extension_classes, base_class),
|
(*extension_classes, base_class),
|
||||||
{
|
{
|
||||||
"__reduce__": __reduce__,
|
"__reduce__": _reduce_extended_instance,
|
||||||
"__module__": base_class.__module__,
|
"__module__": base_class.__module__,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return extended_class
|
|
||||||
|
|
||||||
|
|
||||||
RESERVED_KEYWORDS = frozenset(
|
RESERVED_KEYWORDS = frozenset(
|
||||||
(
|
(
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import pickle
|
||||||
|
|
||||||
import frappe
|
import frappe
|
||||||
from frappe.desk.doctype.todo.todo import ToDo
|
from frappe.desk.doctype.todo.todo import ToDo
|
||||||
from frappe.model.base_document import BaseDocument, get_extended_class
|
from frappe.model.base_document import BaseDocument, _get_extended_class
|
||||||
from frappe.tests import IntegrationTestCase
|
from frappe.tests import IntegrationTestCase
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,18 +42,18 @@ class TestBaseDocument(IntegrationTestCase):
|
||||||
self.assertEqual(doc.docstatus, 2)
|
self.assertEqual(doc.docstatus, 2)
|
||||||
|
|
||||||
def test_get_extended_class_with_no_extensions(self):
|
def test_get_extended_class_with_no_extensions(self):
|
||||||
"""Test that get_extended_class returns the base class when no extensions are provided."""
|
"""Test that _get_extended_class returns the base class when no extensions are provided."""
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {}}):
|
with self.patch_hooks({"extend_doctype_class": {}}):
|
||||||
result = get_extended_class(ToDo, "ToDo")
|
result = _get_extended_class(ToDo, "ToDo")
|
||||||
self.assertEqual(result, ToDo)
|
self.assertEqual(result, ToDo)
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {"ToDo": []}}):
|
with self.patch_hooks({"extend_doctype_class": {"ToDo": []}}):
|
||||||
result = get_extended_class(ToDo, "ToDo")
|
result = _get_extended_class(ToDo, "ToDo")
|
||||||
self.assertEqual(result, ToDo)
|
self.assertEqual(result, ToDo)
|
||||||
|
|
||||||
def test_get_extended_class_with_extensions(self):
|
def test_get_extended_class_with_extensions(self):
|
||||||
"""Test that get_extended_class properly combines extension classes with base class."""
|
"""Test that _get_extended_class properly combines extension classes with base class."""
|
||||||
# Mock frappe.get_hooks to return extension paths
|
# Mock frappe.get_hooks to return extension paths
|
||||||
extensions = [
|
extensions = [
|
||||||
"frappe.tests.test_base_document.TestExtensionA",
|
"frappe.tests.test_base_document.TestExtensionA",
|
||||||
|
|
@ -61,7 +61,7 @@ class TestBaseDocument(IntegrationTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
||||||
extended_class = get_extended_class(ToDo, "ToDo")
|
extended_class = _get_extended_class(ToDo, "ToDo")
|
||||||
|
|
||||||
# Test that the extended class is different from base class
|
# Test that the extended class is different from base class
|
||||||
self.assertNotEqual(extended_class, ToDo)
|
self.assertNotEqual(extended_class, ToDo)
|
||||||
|
|
@ -96,7 +96,7 @@ class TestBaseDocument(IntegrationTestCase):
|
||||||
extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
|
extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
||||||
extended_class = get_extended_class(ToDo, "ToDo")
|
extended_class = _get_extended_class(ToDo, "ToDo")
|
||||||
|
|
||||||
# Test that the extended class is different from base ToDo
|
# Test that the extended class is different from base ToDo
|
||||||
self.assertNotEqual(extended_class, ToDo)
|
self.assertNotEqual(extended_class, ToDo)
|
||||||
|
|
@ -136,9 +136,9 @@ class TestBaseDocument(IntegrationTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
||||||
# Test that frappe.ValidationError is raised for invalid extension path
|
# Test that ImportError is raised for invalid extension path
|
||||||
with self.assertRaises(frappe.ValidationError) as context:
|
with self.assertRaises(ImportError) as context:
|
||||||
get_extended_class(ToDo, "ToDo")
|
_get_extended_class(ToDo, "ToDo")
|
||||||
|
|
||||||
# Check that the error message mentions the invalid path
|
# Check that the error message mentions the invalid path
|
||||||
error_message = str(context.exception)
|
error_message = str(context.exception)
|
||||||
|
|
@ -152,7 +152,7 @@ class TestBaseDocument(IntegrationTestCase):
|
||||||
extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
|
extensions = ["frappe.tests.test_base_document.TestToDoExtension"]
|
||||||
|
|
||||||
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
with self.patch_hooks({"extend_doctype_class": {"ToDo": extensions}}):
|
||||||
extended_class = get_extended_class(ToDo, "ToDo")
|
extended_class = _get_extended_class(ToDo, "ToDo")
|
||||||
|
|
||||||
# Create an instance with some data
|
# Create an instance with some data
|
||||||
original_instance = extended_class(
|
original_instance = extended_class(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue