diff --git a/frappe/tests/test_utils.py b/frappe/tests/test_utils.py index d31bbdb7b7..11e9653372 100644 --- a/frappe/tests/test_utils.py +++ b/frappe/tests/test_utils.py @@ -1263,6 +1263,8 @@ class TestRounding(FrappeTestCase): class TestArgumentTypingValidations(FrappeTestCase): def test_validate_argument_types(self): + from unittest.mock import AsyncMock, MagicMock, Mock + from frappe.core.doctype.doctype.doctype import DocType from frappe.utils.typing_validations import ( FrappeTypeError, @@ -1281,6 +1283,10 @@ class TestArgumentTypingValidations(FrappeTestCase): def test_doctypes(a: DocType | dict): return a + @validate_argument_types + def test_mocks(a: str): + return a + self.assertEqual(test_simple_types(True, 2.0, True), (1, 2.0, True)) self.assertEqual(test_simple_types(1, 2, 1), (1, 2.0, True)) self.assertEqual(test_simple_types(1.0, 2, 1), (1, 2.0, True)) @@ -1304,6 +1310,13 @@ class TestArgumentTypingValidations(FrappeTestCase): with self.assertRaises(FrappeTypeError): test_doctypes("a") + self.assertEqual(test_mocks("Hello World"), "Hello World") + for obj in (AsyncMock, MagicMock, Mock): + obj_instance = obj() + self.assertEqual(test_mocks(obj_instance), obj_instance) + with self.assertRaises(FrappeTypeError): + test_mocks(1) + class TestChangeLog(FrappeTestCase): def test_check_release_on_github(self): diff --git a/frappe/utils/typing_validations.py b/frappe/utils/typing_validations.py index c11507f0c9..2b23a129cc 100644 --- a/frappe/utils/typing_validations.py +++ b/frappe/utils/typing_validations.py @@ -3,6 +3,7 @@ from functools import lru_cache, wraps from inspect import _empty, isclass, signature from types import EllipsisType from typing import ForwardRef, TypeVar, Union +from unittest import mock from pydantic import ConfigDict @@ -77,8 +78,8 @@ def transform_parameter_types(func: Callable, args: tuple, kwargs: dict): """ Validate the types of the arguments passed to a function with the type annotations defined on the function. - """ + if not (args or kwargs) or not func.__annotations__: return args, kwargs @@ -117,6 +118,9 @@ def transform_parameter_types(func: Callable, args: tuple, kwargs: dict): continue elif any(isinstance(x, ForwardRef | str) for x in getattr(current_arg_type, "__args__", [])): continue + # ignore unittest.mock objects + elif isinstance(current_arg_value, mock.Mock): + continue # allow slack for Frappe types if current_arg_type in SLACK_DICT: