refactor: organize test contextmanagers (#28041)

* refactor: prefer staticmethod decorator

* refactor: add cm register utility and keep cms in one file

* refactor: enter safe_exec enabled context (treewide)

* refactor: move trace fields to the other test context managers

* chore: marke all test_runner functions for deprecation

* chore: mark some tests.utils functions for deprecation (moved)

* chore: mark traced_field_conext for deprecation (moved)

* chore: placate semgrep in dumpster

* fix: show deprecation warnings per module in tests (incl. from dumpster)

* chore: remove use of deprecated functions from tests
This commit is contained in:
David Arnold 2024-10-09 02:09:19 +02:00 committed by GitHub
parent c8f42fe15d
commit 95950c8d81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 638 additions and 514 deletions

View file

@ -29,7 +29,7 @@ class UnitTestReport(UnitTestCase):
class TestReport(IntegrationTestCase):
@classmethod
def setUpClass(cls) -> None:
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass()
def test_report_builder(self):

View file

@ -127,7 +127,7 @@ class TestServerScript(IntegrationTestCase):
script_doc = frappe.get_doc(doctype="Server Script")
script_doc.update(script)
script_doc.insert()
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
frappe.db.commit()
return super().setUpClass()

View file

@ -288,3 +288,245 @@ def get_translated_dict():
)
def validate_roles(self):
self.populate_role_profile_roles()
@deprecated(
"frappe.tests_runner.get_dependencies", "2024-20-08", "v17", "use frappe.tests.utils.get_dependencies"
)
def test_runner_get_dependencies(doctype):
from frappe.tests.utils import get_dependencies
return get_dependencies(doctype)
@deprecated("frappe.tests_runner.get_modules", "2024-20-08", "v17", "use frappe.tests.utils.get_modules")
def test_runner_get_modules(doctype):
from frappe.tests.utils import get_modules
return get_modules(doctype)
@deprecated(
"frappe.tests_runner.make_test_records", "2024-20-08", "v17", "use frappe.tests.utils.make_test_records"
)
def test_runner_make_test_records(*args, **kwargs):
from frappe.tests.utils import make_test_records
return make_test_records(*args, **kwargs)
@deprecated(
"frappe.tests_runner.make_test_objects", "2024-20-08", "v17", "use frappe.tests.utils.make_test_objects"
)
def test_runner_make_test_objects(*args, **kwargs):
from frappe.tests.utils import make_test_objects
return make_test_objects(*args, **kwargs)
@deprecated(
"frappe.tests_runner.make_test_records_for_doctype",
"2024-20-08",
"v17",
"use frappe.tests.utils.make_test_records_for_doctype",
)
def test_runner_make_test_records_for_doctype(*args, **kwargs):
from frappe.tests.utils import make_test_records_for_doctype
return make_test_records_for_doctype(*args, **kwargs)
@deprecated(
"frappe.tests_runner.print_mandatory_fields",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_print_mandatory_fields(*args, **kwargs):
from frappe.tests.utils.generators import print_mandatory_fields
return print_mandatory_fields(*args, **kwargs)
@deprecated(
"frappe.tests_runner.get_test_record_log",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_get_test_record_log(doctype):
from frappe.tests.utils.generators import TestRecordLog
return TestRecordLog().get(doctype)
@deprecated(
"frappe.tests_runner.add_to_test_record_log",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_add_to_test_record_log(doctype):
from frappe.tests.utils.generators import TestRecordLog
return TestRecordLog().add(doctype)
@deprecated(
"frappe.tests_runner.main",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_runner_main(*args, **kwargs):
from frappe.commands.testing import main
return main(*args, **kwargs)
@deprecated(
"frappe.tests_runner.xmlrunner_wrapper",
"2024-20-08",
"v17",
"no public api anymore",
)
def test_xmlrunner_wrapper(output):
"""Convenience wrapper to keep method signature unchanged for XMLTestRunner and TextTestRunner"""
try:
import xmlrunner
except ImportError:
print("Development dependencies are required to execute this command. To install run:")
print("$ bench setup requirements --dev")
raise
def _runner(*args, **kwargs):
kwargs["output"] = output
return xmlrunner.XMLTestRunner(*args, **kwargs)
return _runner
@deprecated(
"frappe.tests.upate_system_settings",
"2024-20-08",
"v17",
"use with `self.change_settings(...):` context manager",
)
def tests_update_system_settings(args, commit=False):
import frappe
doc = frappe.get_doc("System Settings")
doc.update(args)
doc.flags.ignore_mandatory = 1
doc.save()
if commit:
# moved here
frappe.db.commit() # nosemgrep
@deprecated(
"frappe.tests.get_system_setting",
"2024-20-08",
"v17",
"use `frappe.db.get_single_value('System Settings', key)`",
)
def tests_get_system_setting(key):
import frappe
return frappe.db.get_single_value("System Settings", key)
@deprecated(
"frappe.tests.utils.change_settings",
"2024-20-08",
"v17",
"use `frappe.tests.change_settings` or the cls.change_settings",
)
def tests_change_settings(*args, **kwargs):
from frappe.tests.classes.context_managers import change_settings
return change_settings(*args, **kwargs)
@deprecated(
"frappe.tests.utils.patch_hooks",
"2024-20-08",
"v17",
"use `frappe.tests.patch_hooks` or the cls.patch_hooks",
)
def tests_patch_hooks(*args, **kwargs):
from frappe.tests.classes.context_managers import patch_hooks
return patch_hooks(*args, **kwargs)
@deprecated(
"frappe.tests.utils.debug_on",
"2024-20-08",
"v17",
"use `frappe.tests.debug_on` or the cls.debug_on",
)
def tests_debug_on(*args, **kwargs):
from frappe.tests.classes.context_managers import debug_on
return debug_on(*args, **kwargs)
@deprecated(
"frappe.tests.utils.timeout",
"2024-20-08",
"v17",
"use `frappe.tests.timeout` or the cls.timeout",
)
def tests_timeout(*args, **kwargs):
from frappe.tests.classes.context_managers import timeout
return timeout(*args, **kwargs)
@deprecated(
"frappe.tests.utils.FrappeTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.UnitTestCase` or `frappe.tests.IntegrationTestCase` respectively",
)
def tests_FrappeTestCase(*args, **kwargs):
from frappe.tests import IntegrationTestCase
return IntegrationTestCase(*args, **kwargs)
@deprecated(
"frappe.tests.utils.IntegrationTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.IntegrationTestCase`",
)
def tests_IntegrationTestCase(*args, **kwargs):
from frappe.tests import IntegrationTestCase
return IntegrationTestCase(*args, **kwargs)
@deprecated(
"frappe.tests.utils.UnitTestCase",
"2024-20-08",
"v17",
"use `frappe.tests.UnitTestCase`",
)
def tests_UnitTestCase(*args, **kwargs):
from frappe.tests import UnitTestCase
return UnitTestCase(*args, **kwargs)
@deprecated(
"frappe.model.trace.traced_field_context",
"2024-20-08",
"v17",
"use `cls.trace_fields`",
)
def model_trace_traced_field_context(*args, **kwargs):
from frappe.tests.classes.context_managers import trace_fields
return trace_fields(*args, **kwargs)

View file

@ -16,7 +16,7 @@ class UnitTestSystemConsole(UnitTestCase):
class TestSystemConsole(IntegrationTestCase):
@classmethod
def setUpClass(cls) -> None:
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass()
def test_system_console(self):

View file

@ -2,8 +2,7 @@
Traced Fields for Frappe
This module provides utilities for creating traced fields in Frappe documents,
which is particularly useful for instrumenting or debugging test cases and
enforcing strict validation rules.
which is particularly useful for enforcing strict value lifetime validation rules.
Key features:
- Create fields that can be monitored for specific value changes
@ -11,12 +10,6 @@ Key features:
- Apply custom validation logic to fields
- Seamlessly integrate with Frappe's document model
Usage in test cases:
1. Subclass your DocType from TracedDocument alongside Document
2. Use traced_field to define fields you want to monitor
3. Specify forbidden values or custom validation functions
4. In your tests, attempt to set values and check for raised exceptions
Example of standard usage:
from frappe.model.trace import TracedDocument, traced_field
@ -30,67 +23,9 @@ Example of standard usage:
amount = traced_field("Amount", custom_validation = validate_amount)
...
class TestCustomInvoice(unittest.TestCase):
def setUp(self):
self.invoice = CustomSalesInvoice()
def test_forbidden_loyalty_program(self):
with self.assertRaises(AssertionError):
self.invoice.loyalty_program = "FORBIDDEN_PROGRAM"
def test_negative_amount(self):
with self.assertRaises(AssertionError):
self.invoice.amount = -100
Benefits for testing:
- Easily catch unauthorized value changes
- Enforce business rules at the field level
- Improve test coverage by explicitly checking field-level validations
- Simulate and test error conditions more effectively
Monkey Patching for Debugging:
For temporary tracing of fields in existing DocTypes, use the traced_field_context
context manager. This allows you to add tracing to any field without modifying
the original DocType class.
Example of monkey patching with context manager:
import unittest
from frappe.model.document import Document
from frappe.model.trace import traced_field_context
class TestExistingDocType(unittest.TestCase):
def test_debug_value(self):
def validate_some_field(obj, value):
if value == 'debug_value':
raise AssertionError("Debug value detected")
doc = frappe.get_doc("My Doc Type")
with traced_field_context(
doc.__class__,
'some_field',
custom_validation=validate_some_field
):
with self.assertRaises(AssertionError):
doc.some_field = 'debug_value'
# Outside the context, the original behavior is restored
doc.some_field = 'debug_value' # This will not raise an error
This approach allows you to:
- Easily add temporary tracing to any field in any DocType
- Debug issues by catching specific value changes
- Add custom validation logic for debugging purposes
- Automatically reverts changes after the context, ensuring no side effects
- Cleaner and more Pythonic approach to temporary monkey patching
Note: While primarily designed for testing, this can also be used in
production code to enforce strict data integrity rules. However, be
mindful of potential performance implications in high-traffic scenarios.
See frappe.tests.classes.context_managers for a context manager built into test classes.
"""
import contextlib
import frappe
from frappe.model.document import Document
@ -172,22 +107,6 @@ class TracedValue:
setattr(obj, f"_{self.field_name}", value)
def traced_field(*args, **kwargs):
"""
A convenience function for creating TracedValue instances.
This function simplifies the creation of traced fields in Frappe documents.
Args:
*args: Positional arguments to pass to TracedValue constructor.
**kwargs: Keyword arguments to pass to TracedValue constructor.
Returns:
TracedValue: An instance of the TracedValue descriptor.
"""
return TracedValue(*args, **kwargs)
class TracedDocument(Document):
"""
A base class for Frappe documents with traced fields.
@ -234,71 +153,20 @@ class TracedDocument(Document):
return d
@contextlib.contextmanager
def traced_field_context(doc_class, field_name, forbidden_values=None, custom_validation=None):
def traced_field(*args, **kwargs):
"""
A context manager for temporarily tracing a field in a DocType.
A convenience function for creating TracedValue instances.
This function simplifies the creation of traced fields in Frappe documents.
Args:
doc_class (type): The DocType class to modify.
field_name (str): The name of the field to trace.
forbidden_values (list, optional): A list of forbidden values for the field.
custom_validation (callable, optional): A custom validation function.
Yields:
None
"""
original_attr = getattr(doc_class, field_name, None)
original_init = doc_class.__init__
try:
setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation))
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
setattr(self, f"_{field_name}", getattr(self, field_name, None))
doc_class.__init__ = new_init
yield
finally:
if original_attr is not None:
setattr(doc_class, field_name, original_attr)
else:
delattr(doc_class, field_name)
doc_class.__init__ = original_init
def trace_fields(**field_configs):
"""
A class decorator to permanently trace fields in a DocType.
Args:
**field_configs: Keyword arguments where each key is a field name and
the value is a dict containing 'forbidden_values' and/or
'custom_validation'.
*args: Positional arguments to pass to TracedValue constructor.
**kwargs: Keyword arguments to pass to TracedValue constructor.
Returns:
callable: A decorator function that modifies the DocType class.
TracedValue: An instance of the TracedValue descriptor.
"""
return TracedValue(*args, **kwargs)
def decorator(doc_class):
original_init = doc_class.__init__
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
for field_name in field_configs:
setattr(self, f"_{field_name}", getattr(self, field_name, None))
doc_class.__init__ = new_init
for field_name, config in field_configs.items():
forbidden_values = config.get("forbidden_values")
custom_validation = config.get("custom_validation")
setattr(doc_class, field_name, traced_field(field_name, forbidden_values, custom_validation))
return doc_class
return decorator
from frappe.deprecation_dumpster import model_trace_traced_field_context as traced_field_context

View file

@ -6,6 +6,7 @@ import signal
import sys
import time
import unittest
import warnings
import click
import requests
@ -31,6 +32,8 @@ class ParallelTestRunner:
self.total_tests = 0
self.test_result = None
self.setup_test_file_list()
warnings.simplefilter("module", DeprecationWarning)
warnings.simplefilter("module", PendingDeprecationWarning)
def setup_and_run(self):
self.setup_test_site()

View file

@ -8,43 +8,34 @@ This entire file is deprecated and will be removed in v17.
DO NOT ADD ANYTHING!
"""
from frappe.commands.testing import main
from frappe.testing.result import SLOW_TEST_THRESHOLD
def xmlrunner_wrapper(output):
"""Convenience wrapper to keep method signature unchanged for XMLTestRunner and TextTestRunner"""
try:
import xmlrunner
except ImportError:
print("Development dependencies are required to execute this command. To install run:")
print("$ bench setup requirements --dev")
raise
def _runner(*args, **kwargs):
kwargs["output"] = output
return xmlrunner.XMLTestRunner(*args, **kwargs)
return _runner
# TODO: move to deprecation dumpster
from frappe.tests.utils import (
TestRecordLog,
get_dependencies,
get_modules,
make_test_objects,
make_test_records,
make_test_records_for_doctype,
print_mandatory_fields,
from frappe.deprecation_dumpster import (
test_runner_add_to_test_record_log as add_to_test_record_log,
)
# TODO: move to deprecation dumpster
# Compatibility functions
def add_to_test_record_log(doctype):
TestRecordLog().add(doctype)
def get_test_record_log():
return TestRecordLog().get()
from frappe.deprecation_dumpster import (
test_runner_get_dependencies as get_dependencies,
)
from frappe.deprecation_dumpster import (
test_runner_get_modules as get_modules,
)
from frappe.deprecation_dumpster import (
test_runner_get_test_record_log as get_test_record_log,
)
from frappe.deprecation_dumpster import (
test_runner_main as main,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_objects as make_test_objects,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_records as make_test_records,
)
from frappe.deprecation_dumpster import (
test_runner_make_test_records_for_doctype as make_test_records_for_doctype,
)
from frappe.deprecation_dumpster import (
test_runner_print_mandatory_fields as print_mandatory_fields,
)
from frappe.deprecation_dumpster import (
test_xmlrunner_wrapper as xml_runner_wrapper,
)
from frappe.testing.result import SLOW_TEST_THRESHOLD

View file

@ -32,6 +32,7 @@ from pathlib import Path
import click
import frappe
from frappe.tests.classes.context_managers import debug_on
from .config import TestConfig
from .discovery import TestRunnerError
@ -112,8 +113,11 @@ class TestRunner(unittest.TextTestRunner):
def _apply_debug_decorators(self, suite):
if self.cfg.pdb_on_exceptions:
for test in self._iterate_suite(suite):
if hasattr(test, "_apply_debug_decorator"):
test._apply_debug_decorator(self.cfg.pdb_on_exceptions)
setattr(
test,
test._testMethodName,
debug_on(*self.cfg.pdb_on_exceptions)(getattr(test, test._testMethodName)),
)
@contextlib.contextmanager
def _profile(self):

View file

@ -5,28 +5,9 @@ from .classes.context_managers import *
global_test_dependencies = ["User"]
# TODO: move to dumpster - not meant to be a public interface anymore
import frappe.tests.utils as utils
utils.IntegrationTestCase = IntegrationTestCase
utils.UnitTestCase = UnitTestCase
utils.FrappeTestCase = IntegrationTestCase
utils.change_settings = IntegrationTestCase.change_settings
utils.patch_hooks = UnitTestCase.patch_hooks
utils.debug_on = debug_on
utils.timeout = timeout
# TODO: move to dumpster
def update_system_settings(args, commit=False):
doc = frappe.get_doc("System Settings")
doc.update(args)
doc.flags.ignore_mandatory = 1
doc.save()
if commit:
frappe.db.commit()
# TODO: move to dumpster
def get_system_setting(key):
return frappe.db.get_single_value("System Settings", key)
from frappe.deprecation_dumpster import (
tests_get_system_setting as get_system_setting,
)
from frappe.deprecation_dumpster import (
tests_update_system_settings as update_system_settings,
)

View file

@ -1,116 +1,274 @@
import functools
import logging
import pdb
import signal
import sys
import traceback
from collections.abc import Callable
from contextlib import contextmanager
from functools import wraps
from inspect import isfunction, ismethod
from typing import TYPE_CHECKING, Any
import frappe
from .integration_test_case import IntegrationTestCase
from .unit_test_case import UnitTestCase
if TYPE_CHECKING:
from frappe.model import Document
# NOTE: please lazily import any further namespaces within the contextmanager below
logger = logging.Logger(__file__)
###############################################################
# Decorators and Context Managers Implementation
# (each cm is automatically a decorator)
# NOTE: Keep all imports local to the decorator (!)
###############################################################
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def freeze_time(time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any) -> None:
"""Temporarily: freeze time with freezegun."""
import pytz
from freezegun import freeze_time as freezegun_freeze_time
from frappe.utils.data import get_datetime, get_system_timezone
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freezegun_freeze_time(time_to_freeze, *args, **kwargs):
yield
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def set_user(user: str) -> None:
"""Temporarily: set the user."""
old_user = frappe.session.user
frappe.set_user(user)
yield
frappe.set_user(old_user)
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def patch_hooks(overridden_hooks: dict) -> None:
"""Temporarily: patch a hook."""
from unittest.mock import patch
get_hooks = frappe.get_hooks
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings) -> None:
"""Temporarily: change settings in a settings doctype."""
import copy
if settings_dict is None:
settings_dict = settings
settings = frappe.get_doc(doctype)
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def switch_site(site: str) -> None:
"""Temporarily: drop current connection and switch to a different site."""
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
frappe.init(old_site, force=True)
frappe.connect()
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def enable_safe_exec() -> None:
"""Temporarily: enable safe exec (server scripts)."""
import os
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=conf)
yield
update_site_config(SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=conf)
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def debug_on(*exceptions) -> None:
"""Temporarily: enter an interactive debugger on specified exceptions, default: (AssertionError,)."""
import pdb
import sys
import traceback
if not exceptions:
exceptions = (AssertionError,)
try:
yield
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def timeout_context(seconds=30, error_message="Operation timed out.") -> None:
"""Temporarily: timeout an operation."""
import signal
def _handle_timeout(signum, frame):
raise Exception(error_message)
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(30 if callable(seconds) else seconds)
yield
signal.alarm(0)
def timeout(seconds=30, error_message="Operation timed out."):
"""Timeout decorator to ensure a test doesn't run for too long."""
def decorator(func=None):
@wraps(func)
def wrapper(*args, **kwargs):
with timeout_context(seconds, error_message):
return func(*args, **kwargs)
return wrapper
# Support bare @timeout
if callable(seconds):
return decorator(seconds)
return decorator
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def trace_fields(
doc_class: type,
field_name: str | None = None,
forbidden_values: list | None = None,
custom_validation: Callable | None = None,
**field_configs: dict[str, dict[str, list | Callable]],
) -> "Document":
"""
A context manager for temporarily tracing fields in a DocType.
Can be used in two ways:
1. Tracing a single field:
trace_fields(DocType, "field_name", forbidden_values=[...], custom_validation=...)
2. Tracing multiple fields:
trace_fields(DocType, field1={"forbidden_values": [...], "custom_validation": ...}, ...)
Args:
doc_class (Document): The DocType class to modify.
field_name (str, optional): The name of the field to trace (for single field tracing).
forbidden_values (list, optional): A list of forbidden values for the field (for single field tracing).
custom_validation (callable, optional): A custom validation function (for single field tracing).
**field_configs: Keyword arguments for multiple field tracing, where each key is a field name and
the value is a dict containing 'forbidden_values' and/or 'custom_validation'.
Yields:
Document class
"""
from frappe.model.trace import traced_field
original_attrs = {}
original_init = doc_class.__init__
# Prepare configurations
if field_name:
field_configs = {
field_name: {"forbidden_values": forbidden_values, "custom_validation": custom_validation}
}
# Apply traced fields
for f_name, config in field_configs.items():
original_attrs[f_name] = getattr(doc_class, f_name, None)
f_forbidden_values = config.get("forbidden_values")
f_custom_validation = config.get("custom_validation")
setattr(doc_class, f_name, traced_field(f_name, f_forbidden_values, f_custom_validation))
# Modify init method
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
for f_name in field_configs:
setattr(self, f"_{f_name}", getattr(self, f_name, None))
doc_class.__init__ = new_init
yield doc_class
# Restore original attributes and init method
for f_name, original_attr in original_attrs.items():
if original_attr is not None:
setattr(doc_class, f_name, original_attr)
else:
delattr(doc_class, f_name)
doc_class.__init__ = original_init
# NOTE: declare those who should also be made available directly frappe.tests.* namespace
# these can be general purpose context managers who do NOT depend on a particular
# test class setup, such as for example the IntegrationTestCase's connection to site
__all__ = [
"freeze_time",
"set_user",
"patch_hooks",
"change_settings",
"switch_site",
"enable_safe_exec",
"debug_on",
"timeout_context",
"timeout",
"trace_fields",
]
def debug_on(*exceptions):
"""
A decorator to automatically start the debugger when specified exceptions occur.
This decorator allows you to automatically invoke the debugger (pdb) when certain
exceptions are raised in the decorated function. If no exceptions are specified,
it defaults to catching AssertionError.
Args:
*exceptions: Variable length argument list of exception classes to catch.
If none provided, defaults to (AssertionError,).
Returns:
function: A decorator function.
Usage:
1. Basic usage (catches AssertionError):
@debug_on()
def test_assertion_error():
assert False, "This will start the debugger"
2. Catching specific exceptions:
@debug_on(ValueError, TypeError)
def test_specific_exceptions():
raise ValueError("This will start the debugger")
3. Using on a method in a test class:
class TestMyFunctionality(unittest.TestCase):
@debug_on(ZeroDivisionError)
def test_division_by_zero(self):
result = 1 / 0
Note:
When an exception is caught, this decorator will print the exception traceback
and then start the post-mortem debugger, allowing you to inspect the state of
the program at the point where the exception was raised.
"""
if not exceptions:
exceptions = (AssertionError,)
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
return wrapper
return decorator
def timeout(seconds=30, error_message="Test timed out."):
"""Timeout decorator to ensure a test doesn't run for too long.
adapted from https://stackoverflow.com/a/2282656"""
# Support @timeout (without function call)
no_args = bool(callable(seconds))
actual_timeout = 30 if no_args else seconds
actual_error_message = "Test timed out" if no_args else error_message
def decorator(func):
def _handle_timeout(signum, frame):
raise Exception(actual_error_message)
def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(actual_timeout)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
if no_args:
return decorator(seconds)
return decorator

View file

@ -31,6 +31,8 @@ class IntegrationTestCase(UnitTestCase):
@classmethod
def setUpClass(cls) -> None:
if getattr(cls, "_integration_test_case_class_setup_done", None):
return
super().setUpClass()
# Site initialization
@ -57,6 +59,7 @@ class IntegrationTestCase(UnitTestCase):
# enqueue teardown actions (executed in LIFO order)
cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags))
cls.addClassCleanup(_rollback_db)
cls._integration_test_case_class_setup_done = True
@classmethod
def tearDownClass(cls) -> None:
@ -161,68 +164,6 @@ class IntegrationTestCase(UnitTestCase):
finally:
frappe.db.sql = orig_sql
@contextmanager
def switch_site(self, site: str) -> AbstractContextManager[None]:
"""Switch connection to different site.
Note: Drops current site connection completely."""
try:
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
finally:
frappe.init(old_site, force=True)
frappe.connect()
@staticmethod
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings):
"""A context manager to ensure that settings are changed before running
function and restored after running it regardless of exceptions occurred.
This is useful in tests where you want to make changes in a function but
don't retain those changes.
import and use as decorator to cover full function or using `with` statement.
example:
@change_settings("Print Settings", {"send_print_as_pdf": 1})
def test_case(self):
...
@change_settings("Print Settings", send_print_as_pdf=1)
def test_case(self):
...
"""
if settings_dict is None:
settings_dict = settings
try:
settings = frappe.get_doc(doctype)
# remember setting
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
# change setting
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield # yield control to calling function
finally:
# restore settings
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
def _commit_watcher():
import traceback

View file

@ -4,26 +4,29 @@ import logging
import os
import unittest
from collections.abc import Sequence
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytz
import frappe
from frappe.model.base_document import BaseDocument
from frappe.utils import cint
from frappe.utils.data import get_datetime, get_system_timezone
from .context_managers import debug_on
logger = logging.Logger(__file__)
datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
class UnitTestCase(unittest.TestCase):
class BaseTestCase:
@classmethod
def registerAs(cls, _as):
def decorator(cm_func):
setattr(cls, cm_func.__name__, _as(cm_func))
return cm_func
return decorator
class UnitTestCase(unittest.TestCase, BaseTestCase):
"""Unit test class for Frappe tests.
This class extends unittest.TestCase and provides additional utilities
@ -41,27 +44,12 @@ class UnitTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
if getattr(cls, "_unit_test_case_class_setup_done", None):
return
super().setUpClass()
cls.doctype = cls._get_doctype_from_module()
cls.doctype = _get_doctype_from_module(cls)
cls.module = frappe.get_module(cls.__module__)
@classmethod
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None
def _apply_debug_decorator(self, exceptions=()):
setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName)))
cls._unit_test_case_class_setup_done = True
def assertQueryEqual(self, first: str, second: str) -> None:
self.assertEqual(self.normalize_sql(first), self.normalize_sql(second))
@ -101,65 +89,31 @@ class UnitTestCase(unittest.TestCase):
else:
self.assertEqual(expected, actual, msg=msg)
def normalize_html(self, code: str) -> str:
@staticmethod
def normalize_html(code: str) -> str:
"""Formats HTML consistently so simple string comparisons can work on them."""
from bs4 import BeautifulSoup
return BeautifulSoup(code, "html.parser").prettify(formatter=None)
@contextmanager
def set_user(self, user: str) -> AbstractContextManager[None]:
try:
old_user = frappe.session.user
frappe.set_user(user)
yield
finally:
frappe.set_user(old_user)
def normalize_sql(self, query: str) -> str:
@staticmethod
def normalize_sql(query: str) -> str:
"""Formats SQL consistently so simple string comparisons can work on them."""
import sqlparse
return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True)
@classmethod
def enable_safe_exec(cls) -> None:
"""Enable safe exec and disable them after test case is completed."""
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
cls._common_conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=cls._common_conf)
cls.addClassCleanup(
lambda: update_site_config(
SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=cls._common_conf
)
)
@staticmethod
@contextmanager
def patch_hooks(overridden_hooks: dict) -> AbstractContextManager[None]:
get_hooks = frappe.get_hooks
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@contextmanager
def freeze_time(
self, time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any
) -> AbstractContextManager[None]:
from freezegun import freeze_time
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freeze_time(time_to_freeze, *args, **kwargs):
yield
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None

View file

@ -31,7 +31,7 @@ class TestBootData(IntegrationTestCase):
class TestPermissionQueries(IntegrationTestCase):
@classmethod
def setUpClass(cls) -> None:
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass()
def test_get_user_pages_or_reports_with_permission_query(self):

View file

@ -13,8 +13,6 @@ from frappe.tests import IntegrationTestCase
from frappe.utils import cint, now_datetime, set_request
from frappe.website.serve import get_response
from . import update_system_settings
class CustomTestNote(Note):
@property
@ -538,13 +536,13 @@ class TestDocumentWebView(IntegrationTestCase):
document_key = todo.get_document_share_key()
# with old-style signature key
update_system_settings({"allow_older_web_view_links": True}, True)
old_document_key = todo.get_signature()
url = f"/ToDo/{todo.name}?key={old_document_key}"
self.assertEqual(self.get(url).status, "200 OK")
with self.change_settings("System Settings", {"allow_older_web_view_links": True}):
old_document_key = todo.get_signature()
url = f"/ToDo/{todo.name}?key={old_document_key}"
self.assertEqual(self.get(url).status, "200 OK")
update_system_settings({"allow_older_web_view_links": False}, True)
self.assertEqual(self.get(url).status, "401 UNAUTHORIZED")
with self.change_settings("System Settings", {"allow_older_web_view_links": False}):
self.assertEqual(self.get(url).status, "401 UNAUTHORIZED")
# with valid key
url = f"/ToDo/{todo.name}?key={document_key}"

View file

@ -11,7 +11,7 @@ from frappe.utils.xlsxutils import make_xlsx
class TestQueryReport(IntegrationTestCase):
@classmethod
def setUpClass(cls) -> None:
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass()
def tearDown(self):

View file

@ -8,7 +8,7 @@ from frappe.utils.safe_exec import ServerScriptNotEnabled, get_safe_globals, saf
class TestSafeExec(IntegrationTestCase):
@classmethod
def setUpClass(cls) -> None:
cls.enable_safe_exec()
cls.enterClassContext(cls.enable_safe_exec())
return super().setUpClass()
def test_import_fails(self):

View file

@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch
import frappe
from frappe.model.document import Document
from frappe.model.trace import TracedDocument, trace_fields, traced_field, traced_field_context
from frappe.model.trace import TracedDocument, traced_field
from frappe.tests import UnitTestCase
def create_mock_meta(doctype):
@ -65,7 +66,7 @@ class TestTrace(unittest.TestCase):
self.assertEqual(valid_dict["positive_field"], 15)
class TestTracedFieldContext(unittest.TestCase):
class TestTracedFieldContext(UnitTestCase):
def test_traced_field_context(self):
doc = TestDocument()
@ -73,7 +74,7 @@ class TestTracedFieldContext(unittest.TestCase):
doc.test_field = "forbidden"
self.assertEqual(doc.test_field, "forbidden")
with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]):
with self.trace_fields(TestDocument, "test_field", forbidden_values=["forbidden"]):
# Inside context
with self.assertRaises(AssertionError):
doc.test_field = "forbidden"
@ -92,7 +93,7 @@ class TestTracedFieldContext(unittest.TestCase):
if value % 2 != 0:
raise ValueError("Value must be even")
with traced_field_context(TestDocument, "number_field", custom_validation=validate_even):
with self.trace_fields(TestDocument, "number_field", custom_validation=validate_even):
doc.number_field = 2
self.assertEqual(doc.number_field, 2)
@ -111,7 +112,7 @@ class TestTracedFieldContext(unittest.TestCase):
frappe.flags.in_test = False
try:
with traced_field_context(TestDocument, "test_field", forbidden_values=["forbidden"]):
with self.trace_fields(TestDocument, test_field={"forbidden_values": ["forbidden"]}):
with self.assertRaises(frappe.exceptions.ValidationError):
doc.test_field = "forbidden"
@ -126,38 +127,5 @@ class TestTracedFieldContext(unittest.TestCase):
self.assertEqual(doc.test_field, "forbidden")
def validate_positive(obj, value):
if value <= 0:
raise ValueError("Value must be positive")
class TestTraceFieldDecorator(unittest.TestCase):
@trace_fields(decorated_field={"forbidden_values": ["bad"]})
class DecoratedTestDocument(TestDocument):
pass
def test_trace_field_decorator(self):
doc = self.DecoratedTestDocument()
with self.assertRaises(AssertionError):
doc.decorated_field = "bad"
doc.decorated_field = "good"
self.assertEqual(doc.decorated_field, "good")
@trace_fields(positive_field={"custom_validation": validate_positive})
class PositiveFieldDocument(TestDocument):
pass
def test_trace_field_decorator_custom_validation(self):
doc = self.PositiveFieldDocument()
with self.assertRaises(AssertionError):
doc.positive_field = -1
doc.positive_field = 1
self.assertEqual(doc.positive_field, 1)
if __name__ == "__main__":
unittest.main()

View file

@ -20,26 +20,19 @@ from frappe.twofactor import (
)
from frappe.utils import cint, set_request
from . import get_system_setting, update_system_settings
class TestTwoFactor(IntegrationTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.default_allowed_login_attempts = get_system_setting("allow_consecutive_login_attempts")
def setUp(self):
self.http_requests = create_http_request()
self.login_manager = frappe.local.login_manager
self.user = self.login_manager.user
update_system_settings({"allow_consecutive_login_attempts": 2})
self.enterContext(self.change_settings("System Settings", {"allow_consecutive_login_attempts": 2}))
def tearDown(self):
frappe.local.response["verification"] = None
frappe.local.response["tmp_id"] = None
disable_2fa()
frappe.clear_cache(user=self.user)
update_system_settings({"allow_consecutive_login_attempts": self.default_allowed_login_attempts})
def test_should_run_2fa(self):
"""Should return true if enabled."""

View file

@ -24,3 +24,26 @@ def check_orpahned_doctypes():
frappe.throw(
"Following doctypes exist in DB without controller.\n {}".format("\n".join(orpahned_doctypes))
)
from frappe.deprecation_dumpster import (
tests_change_settings as change_settings,
)
from frappe.deprecation_dumpster import (
tests_debug_on as debug_on,
)
from frappe.deprecation_dumpster import (
tests_FrappeTestCase as FrappeTestCase,
)
from frappe.deprecation_dumpster import (
tests_IntegrationTestCase as IntegrationTestCase,
)
from frappe.deprecation_dumpster import (
tests_patch_hooks as patch_hooks,
)
from frappe.deprecation_dumpster import (
tests_timeout as timeout,
)
from frappe.deprecation_dumpster import (
tests_UnitTestCase as UnitTestCase,
)