From fe73c9f1aa14d9a135405dfc56eb500e8cdc00db Mon Sep 17 00:00:00 2001 From: David Date: Thu, 5 Sep 2024 21:02:50 +0200 Subject: [PATCH] feat: add value tracer for test debugging --- frappe/model/trace.py | 304 +++++++++++++++++++++++++++++++++++++ frappe/tests/test_trace.py | 163 ++++++++++++++++++++ 2 files changed, 467 insertions(+) create mode 100644 frappe/model/trace.py create mode 100644 frappe/tests/test_trace.py diff --git a/frappe/model/trace.py b/frappe/model/trace.py new file mode 100644 index 0000000000..31ed20e5ae --- /dev/null +++ b/frappe/model/trace.py @@ -0,0 +1,304 @@ +""" +Traced Fields for Frappe + +This module provides utilities for creating traced fields in Frappe documents, +which is particularly useful for instrumenting or debugging test cases and +enforcing strict validation rules. + +Key features: +- Create fields that can be monitored for specific value changes +- Enforce forbidden values on fields +- Apply custom validation logic to fields +- Seamlessly integrate with Frappe's document model + +Usage in test cases: +1. Subclass your DocType from TracedDocument alongside Document +2. Use traced_field to define fields you want to monitor +3. Specify forbidden values or custom validation functions +4. In your tests, attempt to set values and check for raised exceptions + +Example of standard usage: + from frappe.model.trace import TracedDocument, traced_field + + class CustomSalesInvoice(SalesInvoice, TracedDocument): + ... + def validate_amount(self, value): + if value < 0: + raise AssertionError("Amount cannot be negative") + + loyalty_program = traced_field("Loyalty Program", forbidden_values = ["FORBIDDEN_PROGRAM"]) + amount = traced_field("Amount", custom_validation = validate_amount) + ... + + class TestCustomInvoice(unittest.TestCase): + def setUp(self): + self.invoice = CustomSalesInvoice() + + def test_forbidden_loyalty_program(self): + with self.assertRaises(AssertionError): + self.invoice.loyalty_program = "FORBIDDEN_PROGRAM" + + def test_negative_amount(self): + with self.assertRaises(AssertionError): + self.invoice.amount = -100 + +Benefits for testing: +- Easily catch unauthorized value changes +- Enforce business rules at the field level +- Improve test coverage by explicitly checking field-level validations +- Simulate and test error conditions more effectively + +Monkey Patching for Debugging: +For temporary tracing of fields in existing DocTypes, use the traced_field_context +context manager. This allows you to add tracing to any field without modifying +the original DocType class. + +Example of monkey patching with context manager: + import unittest + from frappe.model.document import Document + from frappe.model.trace import traced_field_context + + class TestExistingDocType(unittest.TestCase): + def test_debug_value(self): + def validate_some_field(obj, value): + if value == 'debug_value': + raise AssertionError("Debug value detected") + + doc = frappe.get_doc("My Doc Type") + + with traced_field_context( + doc.__class__, + 'some_field', + custom_validation=validate_some_field + ): + with self.assertRaises(AssertionError): + doc.some_field = 'debug_value' + + # Outside the context, the original behavior is restored + doc.some_field = 'debug_value' # This will not raise an error + +This approach allows you to: +- Easily add temporary tracing to any field in any DocType +- Debug issues by catching specific value changes +- Add custom validation logic for debugging purposes +- Automatically reverts changes after the context, ensuring no side effects +- Cleaner and more Pythonic approach to temporary monkey patching + +Note: While primarily designed for testing, this can also be used in +production code to enforce strict data integrity rules. However, be +mindful of potential performance implications in high-traffic scenarios. +""" + +import contextlib + +import frappe +from frappe.model.document import Document + + +class TracedValue: + """ + A descriptor class for creating traced fields in Frappe documents. + + This class allows for monitoring and validating changes to specific fields + in a Frappe document. It can enforce forbidden values and apply custom + validation logic. + + Attributes: + field_name (str): The name of the field being traced. + forbidden_values (list): A list of values that are not allowed for this field. + custom_validation (callable): A function for custom validation logic. + """ + + def __init__(self, field_name, forbidden_values=None, custom_validation=None): + """ + Initialize a TracedValue instance. + + Args: + field_name (str): The name of the field to be traced. + forbidden_values (list, optional): A list of values that should not be allowed. + custom_validation (callable, optional): A function for additional validation. + """ + self.field_name = field_name + self.forbidden_values = forbidden_values or [] + self.custom_validation = custom_validation + + def __get__(self, obj, objtype=None): + """ + Get the value of the traced field. + + Args: + obj (object): The instance that this descriptor is accessed from. + objtype (type, optional): The type of the instance. + + Returns: + The value of the traced field, or self if accessed from the class. + """ + if obj is None: + return self + + return getattr(obj, f"_{self.field_name}", None) + + def __set__(self, obj, value): + """ + Set the value of the traced field with validation. + + This method checks against forbidden values and applies custom validation + before setting the value. + + Args: + obj (object): The instance that this descriptor is accessed from. + value: The value to set for the traced field. + + Raises: + ValueError: If the value is forbidden or fails custom validation. + Note: returns AssertionError in test mode to debug with the `--pdb` flag. + + """ + if value in self.forbidden_values: + if frappe.flags.in_test: + frappe.throw(f"{self.field_name} cannot be set to {value}", AssertionError) + else: + frappe.throw(f"{self.field_name} cannot be set to {value}") + + if self.custom_validation: + try: + self.custom_validation(obj, value) + except Exception as e: + if frappe.flags.in_test: + frappe.throw(str(e), AssertionError) + else: + frappe.throw(str(e)) + + setattr(obj, f"_{self.field_name}", value) + + +def traced_field(*args, **kwargs): + """ + A convenience function for creating TracedValue instances. + + This function simplifies the creation of traced fields in Frappe documents. + + Args: + *args: Positional arguments to pass to TracedValue constructor. + **kwargs: Keyword arguments to pass to TracedValue constructor. + + Returns: + TracedValue: An instance of the TracedValue descriptor. + """ + return TracedValue(*args, **kwargs) + + +class TracedDocument(Document): + """ + A base class for Frappe documents with traced fields. + + This class extends Frappe's Document class to provide support for + traced fields created with TracedValue. + + Attributes: + Inherits all attributes from frappe.model.document.Document + """ + + def __init__(self, *args, **kwargs): + """ + Initialize a TracedDocument instance. + + This method sets up traced fields and initializes the parent Document. + + Args: + *args: Positional arguments to pass to the parent constructor. + **kwargs: Keyword arguments to pass to the parent constructor. + """ + super().__init__(*args, **kwargs) + for name, attr in self.__class__.__dict__.items(): + if isinstance(attr, TracedValue): + setattr(self, f"_{name}", getattr(self, name)) + + def get_valid_dict(self, *args, **kwargs): + """ + Get a valid dictionary representation of the document. + + This method extends the parent method to properly handle traced fields. + + Args: + *args: Positional arguments to pass to the parent method. + **kwargs: Keyword arguments to pass to the parent method. + + Returns: + dict: A dictionary representation of the document, including traced fields. + """ + d = super().get_valid_dict(*args, **kwargs) + for name, attr in self.__class__.__dict__.items(): + if isinstance(attr, TracedValue): + d[name] = getattr(self, name) + return d + + +@contextlib.contextmanager +def traced_field_context(doc_class, field_name, forbidden_values=None, custom_validation=None): + """ + A context manager for temporarily tracing a field in a DocType. + + Args: + doc_class (type): The DocType class to modify. + field_name (str): The name of the field to trace. + forbidden_values (list, optional): A list of forbidden values for the field. + custom_validation (callable, optional): A custom validation function. + + Yields: + None + """ + original_attr = getattr(doc_class, field_name, None) + original_init = doc_class.__init__ + + try: + setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation)) + + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + setattr(self, f"_{field_name}", getattr(self, field_name, None)) + + doc_class.__init__ = new_init + + yield + + finally: + if original_attr is not None: + setattr(doc_class, field_name, original_attr) + else: + delattr(doc_class, field_name) + + doc_class.__init__ = original_init + + +def trace_fields(**field_configs): + """ + A class decorator to permanently trace fields in a DocType. + + Args: + **field_configs: Keyword arguments where each key is a field name and + the value is a dict containing 'forbidden_values' and/or + 'custom_validation'. + + Returns: + callable: A decorator function that modifies the DocType class. + """ + + def decorator(doc_class): + original_init = doc_class.__init__ + + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + for field_name in field_configs: + setattr(self, f"_{field_name}", getattr(self, field_name, None)) + + doc_class.__init__ = new_init + + for field_name, config in field_configs.items(): + forbidden_values = config.get("forbidden_values") + custom_validation = config.get("custom_validation") + setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation)) + + return doc_class + + return decorator diff --git a/frappe/tests/test_trace.py b/frappe/tests/test_trace.py new file mode 100644 index 0000000000..343fbb0acb --- /dev/null +++ b/frappe/tests/test_trace.py @@ -0,0 +1,163 @@ +import unittest +from unittest.mock import MagicMock, patch + +import frappe +from frappe.model.document import Document +from frappe.model.trace import TracedDocument, trace_fields, traced_field, traced_field_context + + +def create_mock_meta(doctype): + mock_meta = MagicMock() + mock_meta.get_table_fields.return_value = [] + return mock_meta + + +class TestDocument(Document): + def __init__(self, *args, **kwargs): + kwargs["doctype"] = "TestDocument" + with patch("frappe.get_meta", return_value=create_mock_meta("TestDocument")): + super().__init__(*args, **kwargs) + + +class TestTracedDocument(TracedDocument): + def __init__(self, *args, **kwargs): + kwargs["doctype"] = "TestTracedDocument" + with patch("frappe.get_meta", return_value=create_mock_meta("TestTracedDocument")): + super().__init__(*args, **kwargs) + + test_field = traced_field("test_field", forbidden_values=["forbidden"]) + + def validate_positive(self, value): + if value <= 0: + raise ValueError("Value must be positive") + + positive_field = traced_field("positive_field", custom_validation=validate_positive) + + +class TestTrace(unittest.TestCase): + def setUp(self): + self.traced_doc = TestTracedDocument() + + def test_traced_field_get(self): + self.traced_doc._test_field = "test_value" + self.assertEqual(self.traced_doc.test_field, "test_value") + + def test_traced_field_set(self): + self.traced_doc.test_field = "new_value" + self.assertEqual(self.traced_doc._test_field, "new_value") + + def test_traced_field_forbidden_value(self): + with self.assertRaises(AssertionError): + self.traced_doc.test_field = "forbidden" + + def test_traced_field_custom_validation(self): + self.traced_doc.positive_field = 10 + self.assertEqual(self.traced_doc._positive_field, 10) + + with self.assertRaises(AssertionError): + self.traced_doc.positive_field = -5 + + def test_get_valid_dict(self): + self.traced_doc.test_field = "valid_value" + self.traced_doc.positive_field = 15 + valid_dict = self.traced_doc.get_valid_dict() + self.assertEqual(valid_dict["test_field"], "valid_value") + self.assertEqual(valid_dict["positive_field"], 15) + + +class TestTracedFieldContext(unittest.TestCase): + def test_traced_field_context(self): + doc = TestDocument() + + # Before context + doc.test_field = "forbidden" + self.assertEqual(doc.test_field, "forbidden") + + with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]): + # Inside context + with self.assertRaises(AssertionError): + doc.test_field = "forbidden" + + doc.test_field = "allowed" + self.assertEqual(doc.test_field, "allowed") + + # After context + doc.test_field = "forbidden" + self.assertEqual(doc.test_field, "forbidden") + + def test_traced_field_context_custom_validation(self): + doc = TestDocument() + + def validate_even(obj, value): + if value % 2 != 0: + raise ValueError("Value must be even") + + with traced_field_context(TestDocument, "number_field", custom_validation=validate_even): + doc.number_field = 2 + self.assertEqual(doc.number_field, 2) + + with self.assertRaises(AssertionError): + doc.number_field = 3 + + # After context, validation should not apply + doc.number_field = 3 + self.assertEqual(doc.number_field, 3) + + def test_traced_field_context_not_in_test_mode(self): + doc = TestDocument() + + # Temporarily set frappe.flags.in_test to False + original_in_test = frappe.flags.in_test + frappe.flags.in_test = False + + try: + with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]): + with self.assertRaises(frappe.exceptions.ValidationError): + doc.test_field = "forbidden" + + doc.test_field = "allowed" + self.assertEqual(doc.test_field, "allowed") + finally: + # Restore the original in_test flag + frappe.flags.in_test = original_in_test + + # After context + doc.test_field = "forbidden" + self.assertEqual(doc.test_field, "forbidden") + + +def validate_positive(obj, value): + if value <= 0: + raise ValueError("Value must be positive") + + +class TestTraceFieldDecorator(unittest.TestCase): + @trace_fields(decorated_field={"forbidden_values": ["bad"]}) + class DecoratedTestDocument(TestDocument): + pass + + def test_trace_field_decorator(self): + doc = self.DecoratedTestDocument() + + with self.assertRaises(AssertionError): + doc.decorated_field = "bad" + + doc.decorated_field = "good" + self.assertEqual(doc.decorated_field, "good") + + @trace_fields(positive_field={"custom_validation": validate_positive}) + class PositiveFieldDocument(TestDocument): + pass + + def test_trace_field_decorator_custom_validation(self): + doc = self.PositiveFieldDocument() + + with self.assertRaises(AssertionError): + doc.positive_field = -1 + + doc.positive_field = 1 + self.assertEqual(doc.positive_field, 1) + + +if __name__ == "__main__": + unittest.main()