274 lines
8.1 KiB
Python
274 lines
8.1 KiB
Python
import logging
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
from inspect import isfunction, ismethod
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import frappe
|
|
|
|
from .integration_test_case import IntegrationTestCase
|
|
from .unit_test_case import UnitTestCase
|
|
|
|
if TYPE_CHECKING:
|
|
from frappe.model import Document
|
|
|
|
# NOTE: please lazily import any further namespaces within the contextmanager below
|
|
|
|
logger = logging.Logger(__file__)
|
|
|
|
###############################################################
|
|
# Decorators and Context Managers Implementation
|
|
# (each cm is automatically a decorator)
|
|
# NOTE: Keep all imports local to the decorator (!)
|
|
###############################################################
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def freeze_time(time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any) -> None:
|
|
"""Temporarily: freeze time with freezegun."""
|
|
import pytz
|
|
from freezegun import freeze_time as freezegun_freeze_time
|
|
|
|
from frappe.utils.data import get_datetime, get_system_timezone
|
|
|
|
if not is_utc:
|
|
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
|
|
timezone = pytz.timezone(get_system_timezone())
|
|
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
|
|
|
|
with freezegun_freeze_time(time_to_freeze, *args, **kwargs):
|
|
yield
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def set_user(user: str) -> None:
|
|
"""Temporarily: set the user."""
|
|
old_user = frappe.session.user
|
|
frappe.set_user(user)
|
|
yield
|
|
frappe.set_user(old_user)
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def patch_hooks(overridden_hooks: dict) -> None:
|
|
"""Temporarily: patch a hook."""
|
|
from unittest.mock import patch
|
|
|
|
get_hooks = frappe.get_hooks
|
|
|
|
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
|
|
if hook in overridden_hooks:
|
|
return overridden_hooks[hook]
|
|
return get_hooks(hook, default, app_name)
|
|
|
|
with patch.object(frappe, "get_hooks", patched_hooks):
|
|
yield
|
|
|
|
|
|
@IntegrationTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def change_settings(doctype, settings_dict=None, /, commit=False, **settings) -> None:
|
|
"""Temporarily: change settings in a settings doctype."""
|
|
import copy
|
|
|
|
if settings_dict is None:
|
|
settings_dict = settings
|
|
|
|
settings = frappe.get_doc(doctype)
|
|
previous_settings = copy.deepcopy(settings_dict)
|
|
for key in previous_settings:
|
|
previous_settings[key] = getattr(settings, key)
|
|
|
|
for key, value in settings_dict.items():
|
|
setattr(settings, key, value)
|
|
settings.save(ignore_permissions=True)
|
|
# singles are cached by default, clear to avoid flake
|
|
frappe.db.value_cache[settings] = {}
|
|
if commit:
|
|
frappe.db.commit()
|
|
yield
|
|
settings = frappe.get_doc(doctype)
|
|
for key, value in previous_settings.items():
|
|
setattr(settings, key, value)
|
|
settings.save(ignore_permissions=True)
|
|
if commit:
|
|
frappe.db.commit()
|
|
|
|
|
|
@IntegrationTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def switch_site(site: str) -> None:
|
|
"""Temporarily: drop current connection and switch to a different site."""
|
|
old_site = frappe.local.site
|
|
frappe.init(site, force=True)
|
|
frappe.connect()
|
|
yield
|
|
frappe.init(old_site, force=True)
|
|
frappe.connect()
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def enable_safe_exec() -> None:
|
|
"""Temporarily: enable safe exec (server scripts)."""
|
|
import os
|
|
|
|
from frappe.installer import update_site_config
|
|
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
|
|
|
|
conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
|
|
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=conf)
|
|
yield
|
|
update_site_config(SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=conf)
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def debug_on(*exceptions) -> None:
|
|
"""Temporarily: enter an interactive debugger on specified exceptions, default: (AssertionError,)."""
|
|
import pdb
|
|
import sys
|
|
import traceback
|
|
|
|
if not exceptions:
|
|
exceptions = (AssertionError,)
|
|
|
|
try:
|
|
yield
|
|
except exceptions as e:
|
|
exc_type, exc_value, exc_traceback = sys.exc_info()
|
|
# Pretty print the exception
|
|
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
|
|
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
|
|
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
|
|
|
|
# Print the formatted traceback
|
|
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
|
|
for line in traceback_lines:
|
|
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
|
|
|
|
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
|
|
print("\033[92mEntering post-mortem debugging\033[0m")
|
|
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
|
|
pdb.post_mortem()
|
|
|
|
raise e
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def timeout_context(seconds=30, error_message="Operation timed out.") -> None:
|
|
"""Temporarily: timeout an operation."""
|
|
import signal
|
|
|
|
def _handle_timeout(signum, frame):
|
|
raise Exception(error_message)
|
|
|
|
signal.signal(signal.SIGALRM, _handle_timeout)
|
|
signal.alarm(30 if callable(seconds) else seconds)
|
|
yield
|
|
signal.alarm(0)
|
|
|
|
|
|
def timeout(seconds=30, error_message="Operation timed out."):
|
|
"""Timeout decorator to ensure a test doesn't run for too long."""
|
|
|
|
def decorator(func=None):
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
with timeout_context(seconds, error_message):
|
|
return func(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
# Support bare @timeout
|
|
if callable(seconds):
|
|
return decorator(seconds)
|
|
return decorator
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
|
|
@contextmanager
|
|
def trace_fields(
|
|
doc_class: type,
|
|
field_name: str | None = None,
|
|
forbidden_values: list | None = None,
|
|
custom_validation: Callable | None = None,
|
|
**field_configs: dict[str, list | Callable | None],
|
|
) -> "Document":
|
|
"""
|
|
A context manager for temporarily tracing fields in a DocType.
|
|
|
|
Can be used in two ways:
|
|
1. Tracing a single field:
|
|
trace_fields(DocType, "field_name", forbidden_values=[...], custom_validation=...)
|
|
2. Tracing multiple fields:
|
|
trace_fields(DocType, field1={"forbidden_values": [...], "custom_validation": ...}, ...)
|
|
|
|
Args:
|
|
doc_class (Document): The DocType class to modify.
|
|
field_name (str, optional): The name of the field to trace (for single field tracing).
|
|
forbidden_values (list, optional): A list of forbidden values for the field (for single field tracing).
|
|
custom_validation (callable, optional): A custom validation function (for single field tracing).
|
|
**field_configs: Keyword arguments for multiple field tracing, where each key is a field name and
|
|
the value is a dict containing 'forbidden_values' and/or 'custom_validation'.
|
|
|
|
Yields:
|
|
Document class
|
|
"""
|
|
from frappe.model.trace import traced_field
|
|
|
|
original_attrs = {}
|
|
original_init = doc_class.__init__
|
|
|
|
# Prepare configurations
|
|
if field_name:
|
|
field_configs = {
|
|
field_name: {"forbidden_values": forbidden_values, "custom_validation": custom_validation}
|
|
}
|
|
|
|
# Apply traced fields
|
|
for f_name, config in field_configs.items():
|
|
original_attrs[f_name] = getattr(doc_class, f_name, None)
|
|
f_forbidden_values = config.get("forbidden_values")
|
|
f_custom_validation = config.get("custom_validation")
|
|
setattr(doc_class, f_name, traced_field(f_name, f_forbidden_values, f_custom_validation))
|
|
|
|
# Modify init method
|
|
def new_init(self, *args, **kwargs):
|
|
original_init(self, *args, **kwargs)
|
|
for f_name in field_configs:
|
|
setattr(self, f"_{f_name}", getattr(self, f_name, None))
|
|
|
|
doc_class.__init__ = new_init
|
|
|
|
yield doc_class
|
|
|
|
# Restore original attributes and init method
|
|
for f_name, original_attr in original_attrs.items():
|
|
if original_attr is not None:
|
|
setattr(doc_class, f_name, original_attr)
|
|
else:
|
|
delattr(doc_class, f_name)
|
|
doc_class.__init__ = original_init
|
|
|
|
|
|
# NOTE: declare those who should also be made available directly frappe.tests.* namespace
|
|
# these can be general purpose context managers who do NOT depend on a particular
|
|
# test class setup, such as for example the IntegrationTestCase's connection to site
|
|
__all__ = [
|
|
"freeze_time",
|
|
"set_user",
|
|
"patch_hooks",
|
|
"change_settings",
|
|
"switch_site",
|
|
"enable_safe_exec",
|
|
"debug_on",
|
|
"timeout_context",
|
|
"timeout",
|
|
"trace_fields",
|
|
]
|