diff --git a/frappe/__init__.py b/frappe/__init__.py index fff4f475f1..9e919a5c08 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -716,23 +716,6 @@ xss_safe_methods = [] allowed_http_methods_for_whitelisted_func = {} -def apply_validate_argument_types_wrapper(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - """Validate argument types of whitelisted functions. - - :param args: Function arguments. - :param kwargs: Function keyword arguments.""" - from frappe.utils.typing_validations import transform_parameter_types - - if getattr(local, "request", None) or local.flags.in_test: - args, kwargs = transform_parameter_types(func, args, kwargs) - - return func(*args, **kwargs) - - return wrapper - - def whitelist(allow_guest=False, xss_safe=False, methods=None): """ Decorator for whitelisting a function and making it accessible via HTTP. @@ -752,16 +735,21 @@ def whitelist(allow_guest=False, xss_safe=False, methods=None): methods = ["GET", "POST", "PUT", "DELETE"] def innerfn(fn): + from frappe.utils.typing_validations import validate_argument_types + global whitelisted, guest_methods, xss_safe_methods, allowed_http_methods_for_whitelisted_func + # validate argument types only if request is present + in_request_or_test = lambda: getattr(local, "request", None) or local.flags.in_test # noqa: E731 + # get function from the unbound / bound method # this is needed because functions can be compared, but not methods method = None if hasattr(fn, "__func__"): - method = apply_validate_argument_types_wrapper(fn) + method = validate_argument_types(fn, apply_condition=in_request_or_test) fn = method.__func__ else: - fn = apply_validate_argument_types_wrapper(fn) + fn = validate_argument_types(fn, apply_condition=in_request_or_test) whitelisted.append(fn) allowed_http_methods_for_whitelisted_func[fn] = methods diff --git a/frappe/utils/typing_validations.py b/frappe/utils/typing_validations.py index f9773c2c4e..e7ebcfbdff 100644 --- a/frappe/utils/typing_validations.py +++ b/frappe/utils/typing_validations.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from functools import lru_cache, wraps from inspect import _empty, isclass, signature from types import EllipsisType from typing import Any, Callable, ForwardRef, TypeVar, Union @@ -19,6 +19,22 @@ class FrappePydanticConfig: arbitrary_types_allowed = True +def validate_argument_types(func: Callable, apply_condition: Callable = lambda: True): + @wraps(func) + def wrapper(*args, **kwargs): + """Validate argument types of whitelisted functions. + + :param args: Function arguments. + :param kwargs: Function keyword arguments.""" + + if apply_condition(): + args, kwargs = transform_parameter_types(func, args, kwargs) + + return func(*args, **kwargs) + + return wrapper + + def qualified_name(obj) -> str: """ Return the qualified name (e.g. package.module.Type) for the given object.