This is a middle ground between caching the config completely (which requires a restart/signal to reload it) and always reloading it. I don't know of any use cases that could break from this; nothing in the code should expect configs to reload instantly. This change is only applied to requests for now.
277 lines
8.1 KiB
Python
277 lines
8.1 KiB
Python
import logging
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import frappe
|
|
|
|
from .integration_test_case import IntegrationTestCase
|
|
from .unit_test_case import UnitTestCase
|
|
|
|
if TYPE_CHECKING:
|
|
from frappe.model import Document
|
|
|
|
# NOTE: please lazily import any further namespaces within the contextmanager below
|
|
|
|
# Module-level logger obtained through logging.getLogger so it is registered in the
# logging hierarchy under this module's dotted name; handlers and levels configured
# on parent loggers (and the root logger) therefore apply to it. A directly
# instantiated logging.Logger has no parent and would bypass that configuration.
logger = logging.getLogger(__name__)
|
|
|
|
###############################################################
|
|
# Decorators and Context Managers Implementation
|
|
# (each cm is automatically a decorator)
|
|
# NOTE: Keep all imports local to the decorator (!)
|
|
###############################################################
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def freeze_time(time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any) -> None:
    """Temporarily: freeze time with freezegun.

    Args:
        time_to_freeze: The moment to freeze at. Unless ``is_utc`` is True it is
            passed through ``frappe.utils.data.get_datetime`` first.
        is_utc: When True, ``time_to_freeze`` is handed to freezegun untouched.
            When False, a naive value is interpreted as wall-clock time in the
            system timezone and converted to UTC; a timezone-aware value keeps
            its own offset and is simply converted to UTC.
        *args, **kwargs: Forwarded to ``freezegun.freeze_time``.
    """
    from datetime import UTC
    from zoneinfo import ZoneInfo

    from freezegun import freeze_time as freezegun_freeze_time

    from frappe.utils.data import get_datetime, get_system_timezone

    if not is_utc:
        # freezegun expects UTC or tz-aware values. Naive values are assumed to be
        # in the system timezone; aware values must not have their offset
        # overwritten by replace(tzinfo=...), so they are only converted.
        moment = get_datetime(time_to_freeze)
        if moment.utcoffset() is None:
            moment = moment.replace(tzinfo=ZoneInfo(get_system_timezone()))
        time_to_freeze = moment.astimezone(UTC)

    with freezegun_freeze_time(time_to_freeze, *args, **kwargs):
        yield
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def set_user(user: str) -> None:
    """Temporarily: set the user.

    Switches the session to ``user`` via ``frappe.set_user`` and switches back
    to the previously active session user on exit, even if the wrapped code
    raises (otherwise a failing test would leak its user into later tests).
    """
    old_user = frappe.session.user
    frappe.set_user(user)
    try:
        yield
    finally:
        frappe.set_user(old_user)
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def patch_hooks(overridden_hooks: dict) -> None:
    """Temporarily: patch a hook.

    While active, ``frappe.get_hooks`` returns ``overridden_hooks[hook]`` for any
    hook name present in ``overridden_hooks`` and defers to the real
    ``frappe.get_hooks`` (captured on entry) for everything else. The original
    attribute is put back on exit by ``unittest.mock.patch.object``.
    """
    from unittest.mock import patch

    real_get_hooks = frappe.get_hooks

    def _get_hooks_with_overrides(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
        # Anything not explicitly overridden falls through to the real lookup.
        if hook not in overridden_hooks:
            return real_get_hooks(hook, default, app_name)
        return overridden_hooks[hook]

    patcher = patch.object(frappe, "get_hooks", _get_hooks_with_overrides)
    with patcher:
        yield
|
|
|
|
|
|
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings) -> None:
    """Temporarily: change settings in a settings doctype.

    The new values are applied and saved on entry. The previous values are
    written back and saved on exit, even if the wrapped code raises.

    Args:
        doctype: Name of the settings doctype to modify (presumably a Single,
            since it is loaded with ``frappe.get_doc(doctype)`` alone).
        settings_dict: Mapping of field name -> temporary value. When omitted,
            the keyword arguments collected in ``settings`` are used instead.
        commit: When True, commit the database after applying and after
            restoring the settings.
        **settings: Field values to set; only used if ``settings_dict`` is None.
    """
    if settings_dict is None:
        settings_dict = settings

    settings_doc = frappe.get_doc(doctype)
    # Snapshot the current value of every field that is about to change, before
    # anything is modified, so the finally-block can always restore them.
    previous_settings = {key: getattr(settings_doc, key) for key in settings_dict}

    try:
        for key, value in settings_dict.items():
            setattr(settings_doc, key, value)
        settings_doc.save(ignore_permissions=True)
        # singles are cached by default, clear to avoid flake.
        # NOTE(review): assumes frappe.db.value_cache is keyed by doctype name
        # (as used by frappe.db.get_single_value) — verify.
        frappe.db.value_cache[doctype] = {}
        if commit:
            frappe.db.commit()

        yield
    finally:
        settings_doc = frappe.get_doc(doctype)
        for key, value in previous_settings.items():
            setattr(settings_doc, key, value)
        settings_doc.save(ignore_permissions=True)
        # Drop cached values again so the restored settings are read back.
        frappe.db.value_cache[doctype] = {}
        if commit:
            frappe.db.commit()
|
|
|
|
|
|
@IntegrationTestCase.registerAs(staticmethod)
@contextmanager
def switch_site(site: str) -> None:
    """Temporarily: drop current connection and switch to a different site.

    Re-initialises frappe for ``site`` and connects to it. On exit, including
    when the wrapped code raises, that connection is destroyed and the
    previously active site is re-initialised and reconnected.
    """
    old_site = frappe.local.site
    try:
        frappe.init(site, force=True)
        frappe.connect()
        yield
    finally:
        frappe.destroy()
        frappe.init(old_site, force=True)
        frappe.connect()
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def enable_safe_exec() -> None:
    """Temporarily: enable safe exec (server scripts).

    Writes ``SAFE_EXEC_CONFIG_KEY = 1`` to ``common_site_config.json`` in the
    sites directory on entry and writes ``0`` back on exit, even if the wrapped
    code raises.

    NOTE(review): exit always writes 0 rather than restoring whatever value was
    present before entry.
    """
    import os

    from frappe.installer import update_site_config
    from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY

    conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
    update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=conf)
    try:
        yield
    finally:
        update_site_config(SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=conf)
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def debug_on(*exceptions) -> None:
    """Temporarily: enter an interactive debugger on specified exceptions, default: (AssertionError,).

    If an exception matching ``exceptions`` escapes the wrapped code, its type,
    message and formatted traceback are printed with ANSI colors, a
    ``pdb.post_mortem`` session is opened on that traceback, and the exception
    is re-raised once the debugger returns. Other exceptions are not caught.
    """
    import pdb
    import traceback

    if not exceptions:
        exceptions = (AssertionError,)

    try:
        yield
    except exceptions as e:
        exc_type, exc_value, exc_traceback = type(e), e, e.__traceback__
        # Pretty print the exception
        print("\n\033[91m" + "=" * 60 + "\033[0m")  # Red line
        print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
        print("\033[91m" + "=" * 60 + "\033[0m")  # Red line

        # Print the formatted traceback
        traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        for line in traceback_lines:
            print("\033[96m" + line.rstrip() + "\033[0m")  # Cyan color

        print("\033[91m" + "=" * 60 + "\033[0m")  # Red line
        print("\033[92mEntering post-mortem debugging\033[0m")
        print("\033[91m" + "=" * 60 + "\033[0m")  # Red line
        pdb.post_mortem(exc_traceback)

        # Bare ``raise`` re-raises the handled exception unchanged; ``raise e``
        # would append an extra frame for this line to its traceback.
        raise
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def timeout_context(seconds=30, error_message="Operation timed out.") -> None:
    """Temporarily: timeout an operation.

    Installs a SIGALRM handler that raises ``Exception(error_message)`` and arms
    ``signal.alarm`` for ``seconds``. If ``seconds`` is callable (the bare
    ``@timeout`` form passes the decorated function through), 30 is used.

    On exit, even if the wrapped code raises, the alarm is cancelled and the
    previously installed Python-level SIGALRM handler is restored, so a stale
    alarm cannot fire later in unrelated code.

    Note: ``signal.alarm`` is Unix-only and ``signal.signal`` may only be called
    from the main thread.
    """
    import signal

    def _handle_timeout(signum, frame):
        raise Exception(error_message)

    previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
    signal.alarm(30 if callable(seconds) else seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        # signal.signal returns None when the previous handler was not installed
        # from Python; it cannot be reinstalled in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
|
def timeout(seconds=30, error_message="Operation timed out."):
    """Timeout decorator to ensure a test doesn't run for too long.

    Works both bare (``@timeout``) and parameterised
    (``@timeout(seconds, error_message)``). Each call of the decorated function
    runs inside ``timeout_context(seconds, error_message)``.
    """

    def decorator(func=None):
        def wrapper(*args, **kwargs):
            with timeout_context(seconds, error_message):
                return func(*args, **kwargs)

        return wraps(func)(wrapper)

    # Bare ``@timeout``: the "seconds" argument is really the decorated function.
    if not callable(seconds):
        return decorator
    return decorator(seconds)
|
|
|
|
|
|
@UnitTestCase.registerAs(staticmethod)
@contextmanager
def trace_fields(
    doc_class: type,
    field_name: str | None = None,
    forbidden_values: list | None = None,
    custom_validation: Callable | None = None,
    **field_configs: dict[str, list | Callable | None],
) -> "Document":
    """
    A context manager for temporarily tracing fields in a DocType.

    Can be used in two ways:
    1. Tracing a single field:
       trace_fields(DocType, "field_name", forbidden_values=[...], custom_validation=...)
    2. Tracing multiple fields:
       trace_fields(DocType, field1={"forbidden_values": [...], "custom_validation": ...}, ...)

    The class is restored on exit even if the wrapped code raises. Attributes
    that were defined directly on ``doc_class`` are put back exactly as they
    were (including ``None`` values and unbound descriptors). Attributes that
    were not, including inherited ones, are removed again. ``__init__`` is
    restored the same way.

    Args:
        doc_class (Document): The DocType class to modify.
        field_name (str, optional): The name of the field to trace (for single field tracing).
        forbidden_values (list, optional): A list of forbidden values for the field (for single field tracing).
        custom_validation (callable, optional): A custom validation function (for single field tracing).
        **field_configs: Keyword arguments for multiple field tracing, where each key is a field name and
            the value is a dict containing 'forbidden_values' and/or 'custom_validation'
            (``None`` is treated as an empty dict).

    Yields:
        Document class
    """
    from frappe.model.trace import traced_field

    missing = object()  # sentinel: "not defined directly on doc_class"
    original_attrs = {}
    # Resolved initializer (possibly inherited); new_init delegates to it.
    original_init = doc_class.__init__
    own_init = vars(doc_class).get("__init__", missing)
    init_patched = False

    # Prepare configurations
    if field_name:
        field_configs = {
            field_name: {"forbidden_values": forbidden_values, "custom_validation": custom_validation}
        }

    try:
        # Apply traced fields
        for f_name, config in field_configs.items():
            config = config or {}
            traced = traced_field(f_name, config.get("forbidden_values"), config.get("custom_validation"))
            # Read the class __dict__ rather than getattr(): inherited attributes
            # must not be copied onto doc_class on restore, and descriptors must
            # be saved unbound. Recorded only right before the attribute is set.
            original_attrs[f_name] = vars(doc_class).get(f_name, missing)
            setattr(doc_class, f_name, traced)

        # Modify init method
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            for f_name in field_configs:
                setattr(self, f"_{f_name}", getattr(self, f_name, None))

        doc_class.__init__ = new_init
        init_patched = True

        yield doc_class
    finally:
        # Restore original attributes and init method
        for f_name, original_attr in original_attrs.items():
            if original_attr is missing:
                delattr(doc_class, f_name)
            else:
                setattr(doc_class, f_name, original_attr)
        if init_patched:
            if own_init is missing:
                del doc_class.__init__
            else:
                doc_class.__init__ = own_init
|
|
|
|
|
|
# NOTE: declare here the names that should also be made available directly in the
# frappe.tests.* namespace. These should be general-purpose helpers that do NOT
# depend on a particular test class setup (for example, the IntegrationTestCase's
# connection to a site).
# NOTE(review): change_settings and switch_site are registered on
# IntegrationTestCase and use the site connection (frappe.get_doc / frappe.db /
# frappe.connect), which seems at odds with the guideline above; confirm they belong here.
__all__ = [
    "change_settings", "debug_on", "enable_safe_exec", "freeze_time", "patch_hooks",
    "set_user", "switch_site", "timeout", "timeout_context", "trace_fields",
]
|