diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bf4dcdd27d..e1bea7cc91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,6 +19,7 @@ repos: - id: check-toml - id: check-yaml - id: debug-statements + exclude: ^frappe/tests/utils\.py$ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.0 diff --git a/frappe/commands/utils.py b/frappe/commands/utils.py index 37fbf89492..b2d9a54532 100644 --- a/frappe/commands/utils.py +++ b/frappe/commands/utils.py @@ -752,6 +752,7 @@ def transform_database(context, table, engine, row_format, failfast): ) @click.option("--test", multiple=True, help="Specific test") @click.option("--module", help="Run tests in a module") +@click.option("--pdb", is_flag=True, default=False, help="Open pdb on AssertionError") @click.option("--profile", is_flag=True, default=False) @click.option("--coverage", is_flag=True, default=False) @click.option("--skip-test-records", is_flag=True, default=False, help="Don't create test records") @@ -776,9 +777,14 @@ def run_tests( skip_before_tests=False, failfast=False, case=None, + pdb=False, ): """Run python unit-tests""" + pdb_on_exceptions = None + if pdb: + pdb_on_exceptions = (AssertionError,) + with CodeCoverage(coverage, app): import frappe import frappe.test_runner @@ -810,6 +816,7 @@ def run_tests( case=case, skip_test_records=skip_test_records, skip_before_tests=skip_before_tests, + pdb_on_exceptions=pdb_on_exceptions, ) if len(ret.failures) == 0 and len(ret.errors) == 0: diff --git a/frappe/test_runner.py b/frappe/test_runner.py index 6f20de3d87..1c40d48232 100644 --- a/frappe/test_runner.py +++ b/frappe/test_runner.py @@ -53,6 +53,7 @@ def main( case=None, skip_test_records=False, skip_before_tests=False, + pdb_on_exceptions=False, ): global unittest_runner @@ -78,6 +79,7 @@ def main( try: frappe.flags.print_messages = verbose frappe.flags.in_test = True + frappe.flags.pdb_on_exceptions = pdb_on_exceptions # workaround! since there is no separate test db frappe.clear_cache() @@ -266,23 +268,34 @@ def _run_unittest( ): frappe.db.begin() - test_suite = unittest.TestSuite() + final_test_suite = unittest.TestSuite() if not isinstance(modules, list | tuple): modules = [modules] + def iterate_suite(suite): + for test in suite: + if isinstance(test, unittest.TestSuite): + yield from iterate_suite(test) + elif isinstance(test, unittest.TestCase): + yield test + for module in modules: if case: - module_test_cases = unittest.TestLoader().loadTestsFromTestCase(getattr(module, case)) + test_suite = unittest.TestLoader().loadTestsFromTestCase(getattr(module, case)) else: - module_test_cases = unittest.TestLoader().loadTestsFromModule(module) + test_suite = unittest.TestLoader().loadTestsFromModule(module) if tests: - for each in module_test_cases: - for test_case in each.__dict__["_tests"]: - if test_case.__dict__["_testMethodName"] in tests: - test_suite.addTest(test_case) + for test_case in iterate_suite(test_suite): + if test_case._testMethodName in tests: + final_test_suite.addTest(test_case) else: - test_suite.addTest(module_test_cases) + final_test_suite.addTest(test_suite) + + if frappe.flags.pdb_on_exceptions: + for test_case in iterate_suite(final_test_suite): + if hasattr(test_case, "_apply_debug_decorator"): + test_case._apply_debug_decorator(frappe.flags.pdb_on_exceptions) if junit_xml_output: runner = unittest_runner(verbosity=1 + cint(verbose), failfast=failfast) @@ -300,7 +313,7 @@ def _run_unittest( frappe.flags.tests_verbose = verbose - out = runner.run(test_suite) + out = runner.run(final_test_suite) if profile: pr.disable() diff --git a/frappe/tests/utils.py b/frappe/tests/utils.py index 2b90b7fc08..853399c841 100644 --- a/frappe/tests/utils.py +++ b/frappe/tests/utils.py @@ -1,7 +1,11 @@ import copy import datetime +import functools import os +import pdb import signal +import sys +import traceback import unittest from collections.abc import Sequence from contextlib import contextmanager @@ -17,6 +21,62 @@ from frappe.utils.data import convert_utc_to_timezone, get_datetime, get_system_ datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta) +def debug_on(*exceptions): + """ + A decorator to automatically start the debugger when specified exceptions occur. + + This decorator allows you to automatically invoke the debugger (pdb) when certain + exceptions are raised in the decorated function. If no exceptions are specified, + it defaults to catching AssertionError. + + Args: + *exceptions: Variable length argument list of exception classes to catch. + If none provided, defaults to (AssertionError,). + + Returns: + function: A decorator function. + + Usage: + 1. Basic usage (catches AssertionError): + @debug_on() + def test_assertion_error(): + assert False, "This will start the debugger" + + 2. Catching specific exceptions: + @debug_on(ValueError, TypeError) + def test_specific_exceptions(): + raise ValueError("This will start the debugger") + + 3. Using on a method in a test class: + class TestMyFunctionality(unittest.TestCase): + @debug_on(ZeroDivisionError) + def test_division_by_zero(self): + result = 1 / 0 + + Note: + When an exception is caught, this decorator will print the exception traceback + and then start the post-mortem debugger, allowing you to inspect the state of + the program at the point where the exception was raised. + """ + if not exceptions: + exceptions = (AssertionError,) + + def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except exceptions as e: + info = sys.exc_info() + traceback.print_exception(*info) + pdb.post_mortem(info[2]) + raise e + + return wrapper + + return decorator + + class FrappeTestCase(unittest.TestCase): """Base test class for Frappe tests. @@ -47,6 +107,9 @@ class FrappeTestCase(unittest.TestCase): return super().setUpClass() + def _apply_debug_decorator(self, exceptions=()): + setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName))) + def assertSequenceSubset(self, larger: Sequence, smaller: Sequence, msg=None): """Assert that `expected` is a subset of `actual`.""" self.assertTrue(set(smaller).issubset(set(larger)), msg=msg)