refactor: organize test contextmanagers (#28041)

* refactor: prefer staticmethod decorator

* refactor: add cm register utility and keep cms in one file

* refactor: enter safe_exec enabled context (treewide)

* refactor: move trace fields to the other test context managers

* chore: marke all test_runner functions for deprecation

* chore: mark some tests.utils functions for deprecation (moved)

* chore: mark traced_field_conext for deprecation (moved)

* chore: placate semgrep in dumpster

* fix: show deprecation warnings per module in tests (incl. from dumpster)

* chore: remove use of deprecated functions from tests
This commit is contained in:
David Arnold 2024-10-09 02:09:19 +02:00 committed by GitHub
parent c8f42fe15d
commit 95950c8d81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 638 additions and 514 deletions

View file

@ -29,7 +29,7 @@ class UnitTestReport(UnitTestCase):
class TestReport(IntegrationTestCase): class TestReport(IntegrationTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass() return super().setUpClass()
def test_report_builder(self): def test_report_builder(self):

View file

@ -127,7 +127,7 @@ class TestServerScript(IntegrationTestCase):
script_doc = frappe.get_doc(doctype="Server Script") script_doc = frappe.get_doc(doctype="Server Script")
script_doc.update(script) script_doc.update(script)
script_doc.insert() script_doc.insert()
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
frappe.db.commit() frappe.db.commit()
return super().setUpClass() return super().setUpClass()

View file

@ -288,3 +288,245 @@ def get_translated_dict():
) )
def validate_roles(self): def validate_roles(self):
self.populate_role_profile_roles() self.populate_role_profile_roles()
@deprecated(
"frappe.tests_runner.get_dependencies", "2024-20-08", "v17", "use frappe.tests.utils.get_dependencies"
)
def test_runner_get_dependencies(doctype):
from frappe.tests.utils import get_dependencies
return get_dependencies(doctype)
@deprecated("frappe.tests_runner.get_modules", "2024-20-08", "v17", "use frappe.tests.utils.get_modules")
def test_runner_get_modules(doctype):
from frappe.tests.utils import get_modules
return get_modules(doctype)
@deprecated(
"frappe.tests_runner.make_test_records", "2024-20-08", "v17", "use frappe.tests.utils.make_test_records"
)
def test_runner_make_test_records(*args, **kwargs):
from frappe.tests.utils import make_test_records
return make_test_records(*args, **kwargs)
@deprecated(
"frappe.tests_runner.make_test_objects", "2024-20-08", "v17", "use frappe.tests.utils.make_test_objects"
)
def test_runner_make_test_objects(*args, **kwargs):
from frappe.tests.utils import make_test_objects
return make_test_objects(*args, **kwargs)
@deprecated(
"frappe.tests_runner.make_test_records_for_doctype",
"2024-20-08",
"v17",
"use frappe.tests.utils.make_test_records_for_doctype",
)
def test_runner_make_test_records_for_doctype(*args, **kwargs):
from frappe.tests.utils import make_test_records_for_doctype
return make_test_records_for_doctype(*args, **kwargs)
@deprecated(
"frappe.tests_runner.print_mandatory_fields",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_print_mandatory_fields(*args, **kwargs):
from frappe.tests.utils.generators import print_mandatory_fields
return print_mandatory_fields(*args, **kwargs)
@deprecated(
"frappe.tests_runner.get_test_record_log",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_get_test_record_log(doctype):
from frappe.tests.utils.generators import TestRecordLog
return TestRecordLog().get(doctype)
@deprecated(
"frappe.tests_runner.add_to_test_record_log",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_add_to_test_record_log(doctype):
from frappe.tests.utils.generators import TestRecordLog
return TestRecordLog().add(doctype)
@deprecated(
"frappe.tests_runner.main",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_main(*args, **kwargs):
from frappe.commands.testing import main
return main(*args, **kwargs)
@deprecated(
"frappe.tests_runner.xmlrunner_wrapper",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_xmlrunner_wrapper(output):
"""Convenience wrapper to keep method signature unchanged for XMLTestRunner and TextTestRunner"""
try:
import xmlrunner
except ImportError:
print("Development dependencies are required to execute this command. To install run:")
print("$ bench setup requirements --dev")
raise
def _runner(*args, **kwargs):
kwargs["output"] = output
return xmlrunner.XMLTestRunner(*args, **kwargs)
return _runner
@deprecated(
"frappe.tests.upate_system_settings",
"2024-20-08",
"v17",
"use with `self.change_settings(...):` context manager",
)
def tests_update_system_settings(args, commit=False):
import frappe
doc = frappe.get_doc("System Settings")
doc.update(args)
doc.flags.ignore_mandatory = 1
doc.save()
if commit:
# moved here
frappe.db.commit() # nosemgrep
@deprecated(
"frappe.tests.get_system_setting",
"2024-20-08",
"v17",
"use `frappe.db.get_single_value('System Settings', key)`",
)
def tests_get_system_setting(key):
import frappe
return frappe.db.get_single_value("System Settings", key)
@deprecated(
"frappe.tests.utils.change_settings",
"2024-20-08",
"v17",
"use `frappe.tests.change_settings` or the cls.change_settings",
)
def tests_change_settings(*args, **kwargs):
from frappe.tests.classes.context_managers import change_settings
return change_settings(*args, **kwargs)
@deprecated(
"frappe.tests.utils.patch_hooks",
"2024-20-08",
"v17",
"use `frappe.tests.patch_hooks` or the cls.patch_hooks",
)
def tests_patch_hooks(*args, **kwargs):
from frappe.tests.classes.context_managers import patch_hooks
return patch_hooks(*args, **kwargs)
@deprecated(
"frappe.tests.utils.debug_on",
"2024-20-08",
"v17",
"use `frappe.tests.debug_on` or the cls.debug_on",
)
def tests_debug_on(*args, **kwargs):
from frappe.tests.classes.context_managers import debug_on
return debug_on(*args, **kwargs)
@deprecated(
"frappe.tests.utils.timeout",
"2024-20-08",
"v17",
"use `frappe.tests.timeout` or the cls.timeout",
)
def tests_timeout(*args, **kwargs):
from frappe.tests.classes.context_managers import timeout
return timeout(*args, **kwargs)
@deprecated(
"frappe.tests.utils.FrappeTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.UnitTestCase` or `frappe.tests.IntegrationTestCase` respectively",
)
def tests_FrappeTestCase(*args, **kwargs):
from frappe.tests import IntegrationTestCase
return IntegrationTestCase(*args, **kwargs)
@deprecated(
"frappe.tests.utils.IntegrationTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.IntegrationTestCase`",
)
def tests_IntegrationTestCase(*args, **kwargs):
from frappe.tests import IntegrationTestCase
return IntegrationTestCase(*args, **kwargs)
@deprecated(
"frappe.tests.utils.UnitTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.UnitTestCase`",
)
def tests_UnitTestCase(*args, **kwargs):
from frappe.tests import UnitTestCase
return UnitTestCase(*args, **kwargs)
@deprecated(
"frappe.model.trace.traced_field_context",
"2024-20-08",
"v17",
"use `cls.trace_fields`",
)
def model_trace_traced_field_context(*args, **kwargs):
from frappe.tests.classes.context_managers import trace_fields
return trace_fields(*args, **kwargs)

View file

@ -16,7 +16,7 @@ class UnitTestSystemConsole(UnitTestCase):
class TestSystemConsole(IntegrationTestCase): class TestSystemConsole(IntegrationTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass() return super().setUpClass()
def test_system_console(self): def test_system_console(self):

View file

@ -2,8 +2,7 @@
Traced Fields for Frappe Traced Fields for Frappe
This module provides utilities for creating traced fields in Frappe documents, This module provides utilities for creating traced fields in Frappe documents,
which is particularly useful for instrumenting or debugging test cases and which is particularly useful for enforcing strict value lifetime validation rules.
enforcing strict validation rules.
Key features: Key features:
- Create fields that can be monitored for specific value changes - Create fields that can be monitored for specific value changes
@ -11,12 +10,6 @@ Key features:
- Apply custom validation logic to fields - Apply custom validation logic to fields
- Seamlessly integrate with Frappe's document model - Seamlessly integrate with Frappe's document model
Usage in test cases:
1. Subclass your DocType from TracedDocument alongside Document
2. Use traced_field to define fields you want to monitor
3. Specify forbidden values or custom validation functions
4. In your tests, attempt to set values and check for raised exceptions
Example of standard usage: Example of standard usage:
from frappe.model.trace import TracedDocument, traced_field from frappe.model.trace import TracedDocument, traced_field
@ -30,67 +23,9 @@ Example of standard usage:
amount = traced_field("Amount", custom_validation = validate_amount) amount = traced_field("Amount", custom_validation = validate_amount)
... ...
class TestCustomInvoice(unittest.TestCase): See frappe.tests.classes.context_managers for a context manager built into test classes.
def setUp(self):
self.invoice = CustomSalesInvoice()
def test_forbidden_loyalty_program(self):
with self.assertRaises(AssertionError):
self.invoice.loyalty_program = "FORBIDDEN_PROGRAM"
def test_negative_amount(self):
with self.assertRaises(AssertionError):
self.invoice.amount = -100
Benefits for testing:
- Easily catch unauthorized value changes
- Enforce business rules at the field level
- Improve test coverage by explicitly checking field-level validations
- Simulate and test error conditions more effectively
Monkey Patching for Debugging:
For temporary tracing of fields in existing DocTypes, use the traced_field_context
context manager. This allows you to add tracing to any field without modifying
the original DocType class.
Example of monkey patching with context manager:
import unittest
from frappe.model.document import Document
from frappe.model.trace import traced_field_context
class TestExistingDocType(unittest.TestCase):
def test_debug_value(self):
def validate_some_field(obj, value):
if value == 'debug_value':
raise AssertionError("Debug value detected")
doc = frappe.get_doc("My Doc Type")
with traced_field_context(
doc.__class__,
'some_field',
custom_validation=validate_some_field
):
with self.assertRaises(AssertionError):
doc.some_field = 'debug_value'
# Outside the context, the original behavior is restored
doc.some_field = 'debug_value' # This will not raise an error
This approach allows you to:
- Easily add temporary tracing to any field in any DocType
- Debug issues by catching specific value changes
- Add custom validation logic for debugging purposes
- Automatically reverts changes after the context, ensuring no side effects
- Cleaner and more Pythonic approach to temporary monkey patching
Note: While primarily designed for testing, this can also be used in
production code to enforce strict data integrity rules. However, be
mindful of potential performance implications in high-traffic scenarios.
""" """
import contextlib
import frappe import frappe
from frappe.model.document import Document from frappe.model.document import Document
@ -172,22 +107,6 @@ class TracedValue:
setattr(obj, f"_{self.field_name}", value) setattr(obj, f"_{self.field_name}", value)
def traced_field(*args, **kwargs):
"""
A convenience function for creating TracedValue instances.
This function simplifies the creation of traced fields in Frappe documents.
Args:
*args: Positional arguments to pass to TracedValue constructor.
**kwargs: Keyword arguments to pass to TracedValue constructor.
Returns:
TracedValue: An instance of the TracedValue descriptor.
"""
return TracedValue(*args, **kwargs)
class TracedDocument(Document): class TracedDocument(Document):
""" """
A base class for Frappe documents with traced fields. A base class for Frappe documents with traced fields.
@ -234,71 +153,20 @@ class TracedDocument(Document):
return d return d
@contextlib.contextmanager def traced_field(*args, **kwargs):
def traced_field_context(doc_class, field_name, forbidden_values=None, custom_validation=None):
""" """
A context manager for temporarily tracing a field in a DocType. A convenience function for creating TracedValue instances.
This function simplifies the creation of traced fields in Frappe documents.
Args: Args:
doc_class (type): The DocType class to modify. *args: Positional arguments to pass to TracedValue constructor.
field_name (str): The name of the field to trace. **kwargs: Keyword arguments to pass to TracedValue constructor.
forbidden_values (list, optional): A list of forbidden values for the field.
custom_validation (callable, optional): A custom validation function.
Yields:
None
"""
original_attr = getattr(doc_class, field_name, None)
original_init = doc_class.__init__
try:
setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation))
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
setattr(self, f"_{field_name}", getattr(self, field_name, None))
doc_class.__init__ = new_init
yield
finally:
if original_attr is not None:
setattr(doc_class, field_name, original_attr)
else:
delattr(doc_class, field_name)
doc_class.__init__ = original_init
def trace_fields(**field_configs):
"""
A class decorator to permanently trace fields in a DocType.
Args:
**field_configs: Keyword arguments where each key is a field name and
the value is a dict containing 'forbidden_values' and/or
'custom_validation'.
Returns: Returns:
callable: A decorator function that modifies the DocType class. TracedValue: An instance of the TracedValue descriptor.
""" """
return TracedValue(*args, **kwargs)
def decorator(doc_class):
original_init = doc_class.__init__
def new_init(self, *args, **kwargs): from frappe.deprecation_dumpster import model_trace_traced_field_context as traced_field_context
original_init(self, *args, **kwargs)
for field_name in field_configs:
setattr(self, f"_{field_name}", getattr(self, field_name, None))
doc_class.__init__ = new_init
for field_name, config in field_configs.items():
forbidden_values = config.get("forbidden_values")
custom_validation = config.get("custom_validation")
setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation))
return doc_class
return decorator

View file

@ -6,6 +6,7 @@ import signal
import sys import sys
import time import time
import unittest import unittest
import warnings
import click import click
import requests import requests
@ -31,6 +32,8 @@ class ParallelTestRunner:
self.total_tests = 0 self.total_tests = 0
self.test_result = None self.test_result = None
self.setup_test_file_list() self.setup_test_file_list()
warnings.simplefilter("module", DeprecationWarning)
warnings.simplefilter("module", PendingDeprecationWarning)
def setup_and_run(self): def setup_and_run(self):
self.setup_test_site() self.setup_test_site()

View file

@ -8,43 +8,34 @@ This entire file is deprecated and will be removed in v17.
DO NOT ADD ANYTHING! DO NOT ADD ANYTHING!
""" """
from frappe.commands.testing import main from frappe.deprecation_dumpster import (
from frappe.testing.result import SLOW_TEST_THRESHOLD test_runner_add_to_test_record_log as add_to_test_record_log,
def xmlrunner_wrapper(output):
"""Convenience wrapper to keep method signature unchanged for XMLTestRunner and TextTestRunner"""
try:
import xmlrunner
except ImportError:
print("Development dependencies are required to execute this command. To install run:")
print("$ bench setup requirements --dev")
raise
def _runner(*args, **kwargs):
kwargs["output"] = output
return xmlrunner.XMLTestRunner(*args, **kwargs)
return _runner
# TODO: move to deprecation dumpster
from frappe.tests.utils import (
TestRecordLog,
get_dependencies,
get_modules,
make_test_objects,
make_test_records,
make_test_records_for_doctype,
print_mandatory_fields,
) )
from frappe.deprecation_dumpster import (
test_runner_get_dependencies as get_dependencies,
# TODO: move to deprecation dumpster )
# Compatibility functions from frappe.deprecation_dumpster import (
def add_to_test_record_log(doctype): test_runner_get_modules as get_modules,
TestRecordLog().add(doctype) )
from frappe.deprecation_dumpster import (
test_runner_get_test_record_log as get_test_record_log,
def get_test_record_log(): )
return TestRecordLog().get() from frappe.deprecation_dumpster import (
test_runner_main as main,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_objects as make_test_objects,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_records as make_test_records,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_records_for_doctype as make_test_records_for_doctype,
)
from frappe.deprecation_dumpster import (
test_runner_print_mandatory_fields as print_mandatory_fields,
)
from frappe.deprecation_dumpster import (
test_xmlrunner_wrapper as xml_runner_wrapper,
)
from frappe.testing.result import SLOW_TEST_THRESHOLD

View file

@ -32,6 +32,7 @@ from pathlib import Path
import click import click
import frappe import frappe
from frappe.tests.classes.context_managers import debug_on
from .config import TestConfig from .config import TestConfig
from .discovery import TestRunnerError from .discovery import TestRunnerError
@ -112,8 +113,11 @@ class TestRunner(unittest.TextTestRunner):
def _apply_debug_decorators(self, suite): def _apply_debug_decorators(self, suite):
if self.cfg.pdb_on_exceptions: if self.cfg.pdb_on_exceptions:
for test in self._iterate_suite(suite): for test in self._iterate_suite(suite):
if hasattr(test, "_apply_debug_decorator"): setattr(
test._apply_debug_decorator(self.cfg.pdb_on_exceptions) test,
test._testMethodName,
debug_on(*self.cfg.pdb_on_exceptions)(getattr(test, test._testMethodName)),
)
@contextlib.contextmanager @contextlib.contextmanager
def _profile(self): def _profile(self):

View file

@ -5,28 +5,9 @@ from .classes.context_managers import *
global_test_dependencies = ["User"] global_test_dependencies = ["User"]
# TODO: move to dumpster - not meant to be a public interface anymore from frappe.deprecation_dumpster import (
import frappe.tests.utils as utils tests_get_system_setting as get_system_setting,
)
utils.IntegrationTestCase = IntegrationTestCase from frappe.deprecation_dumpster import (
utils.UnitTestCase = UnitTestCase tests_update_system_settings as update_system_settings,
utils.FrappeTestCase = IntegrationTestCase )
utils.change_settings = IntegrationTestCase.change_settings
utils.patch_hooks = UnitTestCase.patch_hooks
utils.debug_on = debug_on
utils.timeout = timeout
# TODO: move to dumpster
def update_system_settings(args, commit=False):
doc = frappe.get_doc("System Settings")
doc.update(args)
doc.flags.ignore_mandatory = 1
doc.save()
if commit:
frappe.db.commit()
# TODO: move to dumpster
def get_system_setting(key):
return frappe.db.get_single_value("System Settings", key)

View file

@ -1,116 +1,274 @@
import functools
import logging import logging
import pdb from collections.abc import Callable
import signal from contextlib import contextmanager
import sys from functools import wraps
import traceback from inspect import isfunction, ismethod
from typing import TYPE_CHECKING, Any
import frappe
from .integration_test_case import IntegrationTestCase
from .unit_test_case import UnitTestCase
if TYPE_CHECKING:
from frappe.model import Document
# NOTE: please lazily import any further namespaces within the contextmanager below
logger = logging.Logger(__file__) logger = logging.Logger(__file__)
###############################################################
# Decorators and Context Managers Implementation
# (each cm is automatically a decorator)
# NOTE: Keep all imports local to the decorator (!)
###############################################################
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def freeze_time(time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any) -> None:
"""Temporarily: freeze time with freezegun."""
import pytz
from freezegun import freeze_time as freezegun_freeze_time
from frappe.utils.data import get_datetime, get_system_timezone
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freezegun_freeze_time(time_to_freeze, *args, **kwargs):
yield
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def set_user(user: str) -> None:
"""Temporarily: set the user."""
old_user = frappe.session.user
frappe.set_user(user)
yield
frappe.set_user(old_user)
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def patch_hooks(overridden_hooks: dict) -> None:
"""Temporarily: patch a hook."""
from unittest.mock import patch
get_hooks = frappe.get_hooks
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings) -> None:
"""Temporarily: change settings in a settings doctype."""
import copy
if settings_dict is None:
settings_dict = settings
settings = frappe.get_doc(doctype)
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def switch_site(site: str) -> None:
"""Temporarily: drop current connection and switch to a different site."""
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
frappe.init(old_site, force=True)
frappe.connect()
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def enable_safe_exec() -> None:
"""Temporarily: enable safe exec (server scripts)."""
import os
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=conf)
yield
update_site_config(SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=conf)
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def debug_on(*exceptions) -> None:
"""Temporarily: enter an interactive debugger on specified exceptions, default: (AssertionError,)."""
import pdb
import sys
import traceback
if not exceptions:
exceptions = (AssertionError,)
try:
yield
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def timeout_context(seconds=30, error_message="Operation timed out.") -> None:
"""Temporarily: timeout an operation."""
import signal
def _handle_timeout(signum, frame):
raise Exception(error_message)
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(30 if callable(seconds) else seconds)
yield
signal.alarm(0)
def timeout(seconds=30, error_message="Operation timed out."):
"""Timeout decorator to ensure a test doesn't run for too long."""
def decorator(func=None):
@wraps(func)
def wrapper(*args, **kwargs):
with timeout_context(seconds, error_message):
return func(*args, **kwargs)
return wrapper
# Support bare @timeout
if callable(seconds):
return decorator(seconds)
return decorator
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def trace_fields(
doc_class: type,
field_name: str | None = None,
forbidden_values: list | None = None,
custom_validation: Callable | None = None,
**field_configs: dict[str, dict[str, list | Callable]],
) -> "Document":
"""
A context manager for temporarily tracing fields in a DocType.
Can be used in two ways:
1. Tracing a single field:
trace_fields(DocType, "field_name", forbidden_values=[...], custom_validation=...)
2. Tracing multiple fields:
trace_fields(DocType, field1={"forbidden_values": [...], "custom_validation": ...}, ...)
Args:
doc_class (Document): The DocType class to modify.
field_name (str, optional): The name of the field to trace (for single field tracing).
forbidden_values (list, optional): A list of forbidden values for the field (for single field tracing).
custom_validation (callable, optional): A custom validation function (for single field tracing).
**field_configs: Keyword arguments for multiple field tracing, where each key is a field name and
the value is a dict containing 'forbidden_values' and/or 'custom_validation'.
Yields:
Document class
"""
from frappe.model.trace import traced_field
original_attrs = {}
original_init = doc_class.__init__
# Prepare configurations
if field_name:
field_configs = {
field_name: {"forbidden_values": forbidden_values, "custom_validation": custom_validation}
}
# Apply traced fields
for f_name, config in field_configs.items():
original_attrs[f_name] = getattr(doc_class, f_name, None)
f_forbidden_values = config.get("forbidden_values")
f_custom_validation = config.get("custom_validation")
setattr(doc_class, f_name, traced_field(f_name, f_forbidden_values, f_custom_validation))
# Modify init method
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
for f_name in field_configs:
setattr(self, f"_{f_name}", getattr(self, f_name, None))
doc_class.__init__ = new_init
yield doc_class
# Restore original attributes and init method
for f_name, original_attr in original_attrs.items():
if original_attr is not None:
setattr(doc_class, f_name, original_attr)
else:
delattr(doc_class, f_name)
doc_class.__init__ = original_init
# NOTE: declare those who should also be made available directly frappe.tests.* namespace # NOTE: declare those who should also be made available directly frappe.tests.* namespace
# these can be general purpose context managers who do NOT depend on a particular # these can be general purpose context managers who do NOT depend on a particular
# test class setup, such as for example the IntegrationTestCase's connection to site # test class setup, such as for example the IntegrationTestCase's connection to site
__all__ = [ __all__ = [
"freeze_time",
"set_user",
"patch_hooks",
"change_settings",
"switch_site",
"enable_safe_exec",
"debug_on", "debug_on",
"timeout_context",
"timeout", "timeout",
"trace_fields",
] ]
def debug_on(*exceptions):
"""
A decorator to automatically start the debugger when specified exceptions occur.
This decorator allows you to automatically invoke the debugger (pdb) when certain
exceptions are raised in the decorated function. If no exceptions are specified,
it defaults to catching AssertionError.
Args:
*exceptions: Variable length argument list of exception classes to catch.
If none provided, defaults to (AssertionError,).
Returns:
function: A decorator function.
Usage:
1. Basic usage (catches AssertionError):
@debug_on()
def test_assertion_error():
assert False, "This will start the debugger"
2. Catching specific exceptions:
@debug_on(ValueError, TypeError)
def test_specific_exceptions():
raise ValueError("This will start the debugger")
3. Using on a method in a test class:
class TestMyFunctionality(unittest.TestCase):
@debug_on(ZeroDivisionError)
def test_division_by_zero(self):
result = 1 / 0
Note:
When an exception is caught, this decorator will print the exception traceback
and then start the post-mortem debugger, allowing you to inspect the state of
the program at the point where the exception was raised.
"""
if not exceptions:
exceptions = (AssertionError,)
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
return wrapper
return decorator
def timeout(seconds=30, error_message="Test timed out."):
"""Timeout decorator to ensure a test doesn't run for too long.
adapted from https://stackoverflow.com/a/2282656"""
# Support @timeout (without function call)
no_args = bool(callable(seconds))
actual_timeout = 30 if no_args else seconds
actual_error_message = "Test timed out" if no_args else error_message
def decorator(func):
def _handle_timeout(signum, frame):
raise Exception(actual_error_message)
def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(actual_timeout)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
if no_args:
return decorator(seconds)
return decorator

View file

@ -31,6 +31,8 @@ class IntegrationTestCase(UnitTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
if getattr(cls, "_integration_test_case_class_setup_done", None):
return
super().setUpClass() super().setUpClass()
# Site initialization # Site initialization
@ -57,6 +59,7 @@ class IntegrationTestCase(UnitTestCase):
# enqueue teardown actions (executed in LIFO order) # enqueue teardown actions (executed in LIFO order)
cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags)) cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags))
cls.addClassCleanup(_rollback_db) cls.addClassCleanup(_rollback_db)
cls._integration_test_case_class_setup_done = True
@classmethod @classmethod
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
@ -161,68 +164,6 @@ class IntegrationTestCase(UnitTestCase):
finally: finally:
frappe.db.sql = orig_sql frappe.db.sql = orig_sql
@contextmanager
def switch_site(self, site: str) -> AbstractContextManager[None]:
"""Switch connection to different site.
Note: Drops current site connection completely."""
try:
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
finally:
frappe.init(old_site, force=True)
frappe.connect()
@staticmethod
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings):
"""A context manager to ensure that settings are changed before running
function and restored after running it regardless of exceptions occurred.
This is useful in tests where you want to make changes in a function but
don't retain those changes.
import and use as decorator to cover full function or using `with` statement.
example:
@change_settings("Print Settings", {"send_print_as_pdf": 1})
def test_case(self):
...
@change_settings("Print Settings", send_print_as_pdf=1)
def test_case(self):
...
"""
if settings_dict is None:
settings_dict = settings
try:
settings = frappe.get_doc(doctype)
# remember setting
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
# change setting
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield # yield control to calling function
finally:
# restore settings
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
def _commit_watcher(): def _commit_watcher():
import traceback import traceback

View file

@ -4,26 +4,29 @@ import logging
import os import os
import unittest import unittest
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import patch
import pytz
import frappe import frappe
from frappe.model.base_document import BaseDocument from frappe.model.base_document import BaseDocument
from frappe.utils import cint from frappe.utils import cint
from frappe.utils.data import get_datetime, get_system_timezone
from .context_managers import debug_on
logger = logging.Logger(__file__) logger = logging.Logger(__file__)
datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta) datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
class UnitTestCase(unittest.TestCase): class BaseTestCase:
@classmethod
def registerAs(cls, _as):
def decorator(cm_func):
setattr(cls, cm_func.__name__, _as(cm_func))
return cm_func
return decorator
class UnitTestCase(unittest.TestCase, BaseTestCase):
"""Unit test class for Frappe tests. """Unit test class for Frappe tests.
This class extends unittest.TestCase and provides additional utilities This class extends unittest.TestCase and provides additional utilities
@ -41,27 +44,12 @@ class UnitTestCase(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
if getattr(cls, "_unit_test_case_class_setup_done", None):
return
super().setUpClass() super().setUpClass()
cls.doctype = cls._get_doctype_from_module() cls.doctype = _get_doctype_from_module(cls)
cls.module = frappe.get_module(cls.__module__) cls.module = frappe.get_module(cls.__module__)
cls._unit_test_case_class_setup_done = True
@classmethod
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None
def _apply_debug_decorator(self, exceptions=()):
setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName)))
def assertQueryEqual(self, first: str, second: str) -> None: def assertQueryEqual(self, first: str, second: str) -> None:
self.assertEqual(self.normalize_sql(first), self.normalize_sql(second)) self.assertEqual(self.normalize_sql(first), self.normalize_sql(second))
@ -101,65 +89,31 @@ class UnitTestCase(unittest.TestCase):
else: else:
self.assertEqual(expected, actual, msg=msg) self.assertEqual(expected, actual, msg=msg)
def normalize_html(self, code: str) -> str: @staticmethod
def normalize_html(code: str) -> str:
"""Formats HTML consistently so simple string comparisons can work on them.""" """Formats HTML consistently so simple string comparisons can work on them."""
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
return BeautifulSoup(code, "html.parser").prettify(formatter=None) return BeautifulSoup(code, "html.parser").prettify(formatter=None)
@contextmanager @staticmethod
def set_user(self, user: str) -> AbstractContextManager[None]: def normalize_sql(query: str) -> str:
try:
old_user = frappe.session.user
frappe.set_user(user)
yield
finally:
frappe.set_user(old_user)
def normalize_sql(self, query: str) -> str:
"""Formats SQL consistently so simple string comparisons can work on them.""" """Formats SQL consistently so simple string comparisons can work on them."""
import sqlparse import sqlparse
return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True) return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True)
@classmethod
def enable_safe_exec(cls) -> None:
"""Enable safe exec and disable them after test case is completed."""
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
cls._common_conf = os.path.join(frappe.local.sites_path, "common_site_config.json") def _get_doctype_from_module(cls):
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=cls._common_conf) module_path = cls.__module__.split(".")
try:
cls.addClassCleanup( doctype_index = module_path.index("doctype")
lambda: update_site_config( doctype_snake_case = module_path[doctype_index + 1]
SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=cls._common_conf json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
) if json_file_path.is_file():
) doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
@staticmethod except (ValueError, IndexError):
@contextmanager # 'doctype' not found in module_path
def patch_hooks(overridden_hooks: dict) -> AbstractContextManager[None]: pass
get_hooks = frappe.get_hooks return None
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@contextmanager
def freeze_time(
self, time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any
) -> AbstractContextManager[None]:
from freezegun import freeze_time
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freeze_time(time_to_freeze, *args, **kwargs):
yield

View file

@ -31,7 +31,7 @@ class TestBootData(IntegrationTestCase):
class TestPermissionQueries(IntegrationTestCase): class TestPermissionQueries(IntegrationTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass() return super().setUpClass()
def test_get_user_pages_or_reports_with_permission_query(self): def test_get_user_pages_or_reports_with_permission_query(self):

View file

@ -13,8 +13,6 @@ from frappe.tests import IntegrationTestCase
from frappe.utils import cint, now_datetime, set_request from frappe.utils import cint, now_datetime, set_request
from frappe.website.serve import get_response from frappe.website.serve import get_response
from . import update_system_settings
class CustomTestNote(Note): class CustomTestNote(Note):
@property @property
@ -538,13 +536,13 @@ class TestDocumentWebView(IntegrationTestCase):
document_key = todo.get_document_share_key() document_key = todo.get_document_share_key()
# with old-style signature key # with old-style signature key
update_system_settings({"allow_older_web_view_links": True}, True) with self.change_settings("System Settings", {"allow_older_web_view_links": True}):
old_document_key = todo.get_signature() old_document_key = todo.get_signature()
url = f"/ToDo/{todo.name}?key={old_document_key}" url = f"/ToDo/{todo.name}?key={old_document_key}"
self.assertEqual(self.get(url).status, "200 OK") self.assertEqual(self.get(url).status, "200 OK")
update_system_settings({"allow_older_web_view_links": False}, True) with self.change_settings("System Settings", {"allow_older_web_view_links": False}):
self.assertEqual(self.get(url).status, "401 UNAUTHORIZED") self.assertEqual(self.get(url).status, "401 UNAUTHORIZED")
# with valid key # with valid key
url = f"/ToDo/{todo.name}?key={document_key}" url = f"/ToDo/{todo.name}?key={document_key}"

View file

@ -11,7 +11,7 @@ from frappe.utils.xlsxutils import make_xlsx
class TestQueryReport(IntegrationTestCase): class TestQueryReport(IntegrationTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass() return super().setUpClass()
def tearDown(self): def tearDown(self):

View file

@ -8,7 +8,7 @@ from frappe.utils.safe_exec import ServerScriptNotEnabled, get_safe_globals, saf
class TestSafeExec(IntegrationTestCase): class TestSafeExec(IntegrationTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
cls.enable_safe_exec() cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass() return super().setUpClass()
def test_import_fails(self): def test_import_fails(self):

View file

@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch
import frappe import frappe
from frappe.model.document import Document from frappe.model.document import Document
from frappe.model.trace import TracedDocument, trace_fields, traced_field, traced_field_context from frappe.model.trace import TracedDocument, traced_field
from frappe.tests import UnitTestCase
def create_mock_meta(doctype): def create_mock_meta(doctype):
@ -65,7 +66,7 @@ class TestTrace(unittest.TestCase):
self.assertEqual(valid_dict["positive_field"], 15) self.assertEqual(valid_dict["positive_field"], 15)
class TestTracedFieldContext(unittest.TestCase): class TestTracedFieldContext(UnitTestCase):
def test_traced_field_context(self): def test_traced_field_context(self):
doc = TestDocument() doc = TestDocument()
@ -73,7 +74,7 @@ class TestTracedFieldContext(unittest.TestCase):
doc.test_field = "forbidden" doc.test_field = "forbidden"
self.assertEqual(doc.test_field, "forbidden") self.assertEqual(doc.test_field, "forbidden")
with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]): with self.trace_fields(TestDocument, "test_field", forbidden_values=["forbidden"]):
# Inside context # Inside context
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
doc.test_field = "forbidden" doc.test_field = "forbidden"
@ -92,7 +93,7 @@ class TestTracedFieldContext(unittest.TestCase):
if value % 2 != 0: if value % 2 != 0:
raise ValueError("Value must be even") raise ValueError("Value must be even")
with traced_field_context(TestDocument, "number_field", custom_validation=validate_even): with self.trace_fields(TestDocument, "number_field", custom_validation=validate_even):
doc.number_field = 2 doc.number_field = 2
self.assertEqual(doc.number_field, 2) self.assertEqual(doc.number_field, 2)
@ -111,7 +112,7 @@ class TestTracedFieldContext(unittest.TestCase):
frappe.flags.in_test = False frappe.flags.in_test = False
try: try:
with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]): with self.trace_fields(TestDocument, test_field={"forbidden_values": ["forbidden"]}):
with self.assertRaises(frappe.exceptions.ValidationError): with self.assertRaises(frappe.exceptions.ValidationError):
doc.test_field = "forbidden" doc.test_field = "forbidden"
@ -126,38 +127,5 @@ class TestTracedFieldContext(unittest.TestCase):
self.assertEqual(doc.test_field, "forbidden") self.assertEqual(doc.test_field, "forbidden")
def validate_positive(obj, value):
if value <= 0:
raise ValueError("Value must be positive")
class TestTraceFieldDecorator(unittest.TestCase):
@trace_fields(decorated_field={"forbidden_values": ["bad"]})
class DecoratedTestDocument(TestDocument):
pass
def test_trace_field_decorator(self):
doc = self.DecoratedTestDocument()
with self.assertRaises(AssertionError):
doc.decorated_field = "bad"
doc.decorated_field = "good"
self.assertEqual(doc.decorated_field, "good")
@trace_fields(positive_field={"custom_validation": validate_positive})
class PositiveFieldDocument(TestDocument):
pass
def test_trace_field_decorator_custom_validation(self):
doc = self.PositiveFieldDocument()
with self.assertRaises(AssertionError):
doc.positive_field = -1
doc.positive_field = 1
self.assertEqual(doc.positive_field, 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -20,26 +20,19 @@ from frappe.twofactor import (
) )
from frappe.utils import cint, set_request from frappe.utils import cint, set_request
from . import get_system_setting, update_system_settings
class TestTwoFactor(IntegrationTestCase): class TestTwoFactor(IntegrationTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.default_allowed_login_attempts = get_system_setting("allow_consecutive_login_attempts")
def setUp(self): def setUp(self):
self.http_requests = create_http_request() self.http_requests = create_http_request()
self.login_manager = frappe.local.login_manager self.login_manager = frappe.local.login_manager
self.user = self.login_manager.user self.user = self.login_manager.user
update_system_settings({"allow_consecutive_login_attempts": 2}) self.enterContext(self.change_settings("System Settings", {"allow_consecutive_login_attempts": 2}))
def tearDown(self): def tearDown(self):
frappe.local.response["verification"] = None frappe.local.response["verification"] = None
frappe.local.response["tmp_id"] = None frappe.local.response["tmp_id"] = None
disable_2fa() disable_2fa()
frappe.clear_cache(user=self.user) frappe.clear_cache(user=self.user)
update_system_settings({"allow_consecutive_login_attempts": self.default_allowed_login_attempts})
def test_should_run_2fa(self): def test_should_run_2fa(self):
"""Should return true if enabled.""" """Should return true if enabled."""

View file

@ -24,3 +24,26 @@ def check_orpahned_doctypes():
frappe.throw( frappe.throw(
"Following doctypes exist in DB without controller.\n {}".format("\n".join(orpahned_doctypes)) "Following doctypes exist in DB without controller.\n {}".format("\n".join(orpahned_doctypes))
) )
from frappe.deprecation_dumpster import (
tests_change_settings as change_settings,
)
from frappe.deprecation_dumpster import (
tests_debug_on as debug_on,
)
from frappe.deprecation_dumpster import (
tests_FrappeTestCase as FrappeTestCase,
)
from frappe.deprecation_dumpster import (
tests_IntegrationTestCase as IntegrationTestCase,
)
from frappe.deprecation_dumpster import (
tests_patch_hooks as patch_hooks,
)
from frappe.deprecation_dumpster import (
tests_timeout as timeout,
)
from frappe.deprecation_dumpster import (
tests_UnitTestCase as UnitTestCase,
)