Testing Improvements 3 (#27995)

* feat: set doctype on test classes

* refactor: Transform `make_test_records` into a generator

* feat: lazy create doctype records on first use

* perf: improve file walker

* fix: submission queue test

* refactor: improve logging a bit

* fix: global records install for app (semifix)
This commit is contained in:
David Arnold 2024-10-06 17:04:47 +02:00 committed by GitHub
parent c114e5fae8
commit c2c9d9062a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 159 additions and 147 deletions

View file

@ -10,6 +10,7 @@ be used to build database driven apps.
Read the documentation: https://frappeframework.com/docs Read the documentation: https://frappeframework.com/docs
""" """
import copy import copy
import faulthandler import faulthandler
import functools import functools
@ -23,6 +24,7 @@ import signal
import sys import sys
import traceback import traceback
import warnings import warnings
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, overload from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, overload
@ -269,7 +271,7 @@ def init(site: str, sites_path: str = ".", new_site: bool = False, force=False)
} }
) )
local.locked_documents = [] local.locked_documents = []
local.test_objects = {} local.test_objects = defaultdict(list)
local.site = site local.site = site
local.sites_path = sites_path local.sites_path = sites_path

View file

@ -763,7 +763,7 @@ def transform_database(context: CliCtxObj, table, engine, row_format, failfast):
@click.option("--pdb", is_flag=True, default=False, help="Open pdb on AssertionError") @click.option("--pdb", is_flag=True, default=False, help="Open pdb on AssertionError")
@click.option("--profile", is_flag=True, default=False) @click.option("--profile", is_flag=True, default=False)
@click.option("--coverage", is_flag=True, default=False) @click.option("--coverage", is_flag=True, default=False)
@click.option("--skip-test-records", is_flag=True, default=False, help="Don't create test records") @click.option("--skip-test-records", is_flag=True, default=False, help="DEPRECATED")
@click.option("--skip-before-tests", is_flag=True, default=False, help="Don't run before tests hook") @click.option("--skip-before-tests", is_flag=True, default=False, help="Don't run before tests hook")
@click.option("--junit-xml-output", help="Destination file path for junit xml report") @click.option("--junit-xml-output", help="Destination file path for junit xml report")
@click.option( @click.option(
@ -816,6 +816,12 @@ def run_tests(
click.secho(f"bench --site {site} set-config allow_tests true", fg="green") click.secho(f"bench --site {site} set-config allow_tests true", fg="green")
return return
if skip_test_records:
click.secho("--skip-test-records is deprecated and without effect!", bold=True)
click.secho("All records are loaded lazily on first use, so the flag is useless, now.")
click.secho("Simply remove the flag.", fg="green")
return
unit_ret, integration_ret = frappe.test_runner.main( unit_ret, integration_ret = frappe.test_runner.main(
site, site,
app, app,
@ -830,7 +836,6 @@ def run_tests(
doctype_list_path=doctype_list_path, doctype_list_path=doctype_list_path,
failfast=failfast, failfast=failfast,
case=case, case=case,
skip_test_records=skip_test_records,
skip_before_tests=skip_before_tests, skip_before_tests=skip_before_tests,
pdb_on_exceptions=pdb_on_exceptions, pdb_on_exceptions=pdb_on_exceptions,
selected_categories=[] if test_category == "all" else test_category, selected_categories=[] if test_category == "all" else test_category,

View file

@ -23,7 +23,9 @@ class UnitTestSubmissionQueue(UnitTestCase):
class TestSubmissionQueue(IntegrationTestCase): class TestSubmissionQueue(IntegrationTestCase):
queue = get_queue(qtype="default") @classmethod
def setUpClass(cls):
cls.queue = get_queue(qtype="default")
@timeout(seconds=20) @timeout(seconds=20)
def check_status(self, job: "Job", status, wait=True): def check_status(self, job: "Job", status, wait=True):

View file

@ -184,7 +184,6 @@ class TestCustomizeForm(IntegrationTestCase):
self.assertEqual(frappe.db.get_value("Custom Field", custom_field.name), None) self.assertEqual(frappe.db.get_value("Custom Field", custom_field.name), None)
frappe.local.test_objects["Custom Field"] = []
make_test_records_for_doctype("Custom Field") make_test_records_for_doctype("Custom Field")
def test_reset_to_defaults(self): def test_reset_to_defaults(self):
@ -194,7 +193,6 @@ class TestCustomizeForm(IntegrationTestCase):
self.assertEqual(d.get("fields", {"fieldname": "repeat_this_event"})[0].in_list_view, 0) self.assertEqual(d.get("fields", {"fieldname": "repeat_this_event"})[0].in_list_view, 0)
frappe.local.test_objects["Property Setter"] = []
make_test_records_for_doctype("Property Setter") make_test_records_for_doctype("Property Setter")
def test_set_allow_on_submit(self): def test_set_allow_on_submit(self):

View file

@ -87,27 +87,11 @@ class ParallelTestRunner:
frappe.set_user("Administrator") frappe.set_user("Administrator")
path, filename = file_info path, filename = file_info
module = self.get_module(path, filename) module = self.get_module(path, filename)
self.create_test_dependency_records(module, path, filename)
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
module_test_cases = unittest.TestLoader().loadTestsFromModule(module) module_test_cases = unittest.TestLoader().loadTestsFromModule(module)
test_suite.addTest(module_test_cases) test_suite.addTest(module_test_cases)
test_suite(self.test_result) test_suite(self.test_result)
def create_test_dependency_records(self, module, path, filename):
if hasattr(module, "test_dependencies"):
for doctype in module.test_dependencies:
make_test_records(doctype, commit=True)
if os.path.basename(os.path.dirname(path)) == "doctype":
# test_data_migration_connector.py > data_migration_connector.json
test_record_filename = re.sub("^test_", "", filename).replace(".py", ".json")
test_record_file_path = os.path.join(path, test_record_filename)
if os.path.exists(test_record_file_path):
with open(test_record_file_path) as f:
doc = json.loads(f.read())
doctype = doc["name"]
make_test_records(doctype, commit=True)
def get_module(self, path, filename): def get_module(self, path, filename):
app_path = frappe.get_app_path(self.app) app_path = frappe.get_app_path(self.app)
relative_path = os.path.relpath(path, app_path) relative_path = os.path.relpath(path, app_path)
@ -179,7 +163,7 @@ def split_by_weight(work, weights, chunk_count):
def get_all_tests(app): def get_all_tests(app):
test_file_list = [] test_file_list = []
for path, folders, files in os.walk(frappe.get_app_path(app)): for path, folders, files in os.walk(frappe.get_app_path(app)):
for dontwalk in ("locals", ".git", "public", "__pycache__"): for dontwalk in ("node_modules", "locals", ".git", "public", "__pycache__"):
if dontwalk in folders: if dontwalk in folders:
folders.remove(dontwalk) folders.remove(dontwalk)

View file

@ -29,7 +29,7 @@ import click
import frappe import frappe
import frappe.utils.scheduler import frappe.utils.scheduler
from frappe.modules import get_module_name from frappe.modules import get_module_name
from frappe.tests.utils import IntegrationTestCase, make_test_records from frappe.tests.utils import IntegrationTestCase
from frappe.utils import cint from frappe.utils import cint
SLOW_TEST_THRESHOLD = 2 SLOW_TEST_THRESHOLD = 2
@ -40,10 +40,10 @@ logger = logging.getLogger(__name__)
def debug_timer(func): def debug_timer(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start_time = time.time() start_time = time.monotonic()
result = func(*args, **kwargs) result = func(*args, **kwargs)
end_time = time.time() end_time = time.monotonic()
logger.debug(f" {func.__name__} took {end_time - start_time:.3f} seconds") logger.debug(f" {func.__name__:<50}{end_time - start_time:>6.3f} seconds")
return result return result
return wrapper return wrapper
@ -76,17 +76,8 @@ class TestRunner(unittest.TextTestRunner):
) )
self.junit_xml_output = junit_xml_output self.junit_xml_output = junit_xml_output
self.profile = profile self.profile = profile
self.test_record_callbacks = []
logger.debug("TestRunner initialized") logger.debug("TestRunner initialized")
def add_test_record_callback(self, callback):
self.test_record_callbacks.append(callback)
def execute_test_record_callbacks(self):
for callback in self.test_record_callbacks:
callback()
self.test_record_callbacks.clear()
def run( def run(
self, test_suites: tuple[unittest.TestSuite, unittest.TestSuite] self, test_suites: tuple[unittest.TestSuite, unittest.TestSuite]
) -> tuple[unittest.TestResult, unittest.TestResult | None]: ) -> tuple[unittest.TestResult, unittest.TestResult | None]:
@ -133,30 +124,26 @@ class TestRunner(unittest.TextTestRunner):
for app in apps: for app in apps:
app_path = Path(frappe.get_app_path(app)) app_path = Path(frappe.get_app_path(app))
for path in app_path.rglob("test_*.py"): for path, folders, files in os.walk(app_path):
if path.parts[-4:-1] == ("doctype", "doctype", "boilerplate"): folders[:] = [f for f in folders if not f.startswith(".")]
continue for dontwalk in ("node_modules", "locals", "public", "__pycache__"):
if path.name == "test_runner.py": if dontwalk in folders:
continue folders.remove(dontwalk)
relative_path = path.relative_to(app_path) if os.path.sep.join(["doctype", "doctype", "boilerplate"]) in path:
if any(part in relative_path.parts for part in ["locals", ".git", "public", "__pycache__"]): # in /doctype/doctype/boilerplate/
continue continue
module_name = ( path = Path(path)
f"{app_path.stem}.{'.'.join(relative_path.parent.parts)}.{path.stem}" for file in [
if str(relative_path.parent) != "." path.joinpath(filename)
else f"{app_path.stem}.{path.stem}" for filename in files
) if filename.startswith("test_")
module = importlib.import_module(module_name) and filename.endswith(".py")
and filename != "test_runner.py"
if path.parent.name == "doctype" and not config.skip_test_records: ]:
json_file = path.with_name(path.stem[5:] + ".json") module_name = f"{'.'.join(file.relative_to(app_path.parent).parent.parts)}.{file.stem}"
if json_file.exists(): module = importlib.import_module(module_name)
with json_file.open() as f: self._add_module_tests(module, unit_test_suite, integration_test_suite, config)
doctype = json.loads(f.read())["name"]
self.add_test_record_callback(lambda: make_test_records(doctype, commit=True))
self._add_module_tests(module, unit_test_suite, integration_test_suite, config)
logger.debug( logger.debug(
f"Discovered {unit_test_suite.countTestCases()} unit tests and {integration_test_suite.countTestCases()} integration tests" f"Discovered {unit_test_suite.countTestCases()} unit tests and {integration_test_suite.countTestCases()} integration tests"
@ -187,9 +174,6 @@ class TestRunner(unittest.TextTestRunner):
except ImportError: except ImportError:
logger.warning(f"No test module found for doctype {doctype}") logger.warning(f"No test module found for doctype {doctype}")
if not config.skip_test_records:
self.add_test_record_callback(lambda: make_test_records(doctype, force=force, commit=True))
return unit_test_suite, integration_test_suite return unit_test_suite, integration_test_suite
def discover_module_tests( def discover_module_tests(
@ -213,11 +197,6 @@ class TestRunner(unittest.TextTestRunner):
integration_test_suite: unittest.TestSuite, integration_test_suite: unittest.TestSuite,
config: TestConfig, config: TestConfig,
): ):
# Handle module test dependencies
if hasattr(module, "test_dependencies") and not config.skip_test_records:
for doctype in module.test_dependencies:
make_test_records(doctype, commit=True)
if config.case: if config.case:
test_suite = unittest.TestLoader().loadTestsFromTestCase(getattr(module, config.case)) test_suite = unittest.TestLoader().loadTestsFromTestCase(getattr(module, config.case))
else: else:
@ -249,13 +228,19 @@ class TestRunner(unittest.TextTestRunner):
class TestResult(unittest.TextTestResult): class TestResult(unittest.TextTestResult):
def startTest(self, test): def startTest(self, test):
logger.debug(f"--- Starting test: {test}")
self.tb_locals = True self.tb_locals = True
self._started_at = time.monotonic() self._started_at = time.monotonic()
super(unittest.TextTestResult, self).startTest(test) super(unittest.TextTestResult, self).startTest(test)
test_class = unittest.util.strclass(test.__class__) test_class = unittest.util.strclass(test.__class__)
if not hasattr(self, "current_test_class") or self.current_test_class != test_class: if getattr(self, "current_test_class", None) != test_class:
click.echo(f"\n{unittest.util.strclass(test.__class__)}") if new_doctypes := getattr(test.__class__, "_newly_created_test_records", None):
click.echo(f"\n{unittest.util.strclass(test.__class__)}")
click.secho(
f" Test Records created: {', '.join([f'{name} ({qty})' for name, qty in reversed(new_doctypes)])}",
fg="bright_black",
)
else:
click.echo(f"\n{unittest.util.strclass(test.__class__)}")
self.current_test_class = test_class self.current_test_class = test_class
def getTestMethodName(self, test): def getTestMethodName(self, test):
@ -265,34 +250,36 @@ class TestResult(unittest.TextTestResult):
super(unittest.TextTestResult, self).addSuccess(test) super(unittest.TextTestResult, self).addSuccess(test)
elapsed = time.monotonic() - self._started_at elapsed = time.monotonic() - self._started_at
threshold_passed = elapsed >= SLOW_TEST_THRESHOLD threshold_passed = elapsed >= SLOW_TEST_THRESHOLD
elapsed = click.style(f" ({elapsed:.03}s)", fg="red") if threshold_passed else "" elapsed_over_threashold = click.style(f" ({elapsed:.03}s)", fg="red") if threshold_passed else ""
click.echo(f" {click.style('', fg='green')} {self.getTestMethodName(test)}{elapsed}") logger.info(
logger.debug(f"=== Test passed: {test}") f" {click.style('', fg='green')} {self.getTestMethodName(test)}{elapsed_over_threashold}"
)
logger.debug(f"=== success === {test} {elapsed}")
def addError(self, test, err): def addError(self, test, err):
super(unittest.TextTestResult, self).addError(test, err) super(unittest.TextTestResult, self).addError(test, err)
click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}") click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}")
logger.debug(f"=== Test error: {test}") logger.debug(f"=== error === {test}")
def addFailure(self, test, err): def addFailure(self, test, err):
super(unittest.TextTestResult, self).addFailure(test, err) super(unittest.TextTestResult, self).addFailure(test, err)
click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}") click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}")
logger.debug(f"=== Test failed: {test}") logger.debug(f"=== failure === {test}")
def addSkip(self, test, reason): def addSkip(self, test, reason):
super(unittest.TextTestResult, self).addSkip(test, reason) super(unittest.TextTestResult, self).addSkip(test, reason)
click.echo(f" {click.style(' = ', fg='white')} {self.getTestMethodName(test)}") click.echo(f" {click.style(' = ', fg='white')} {self.getTestMethodName(test)}")
logger.debug(f"=== Test skipped: {test}") logger.debug(f"=== skipped === {test}")
def addExpectedFailure(self, test, err): def addExpectedFailure(self, test, err):
super(unittest.TextTestResult, self).addExpectedFailure(test, err) super(unittest.TextTestResult, self).addExpectedFailure(test, err)
click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}") click.echo(f" {click.style('', fg='red')} {self.getTestMethodName(test)}")
logger.debug(f"=== Test expected failure: {test}") logger.debug(f"=== expected failure === {test}")
def addUnexpectedSuccess(self, test): def addUnexpectedSuccess(self, test):
super(unittest.TextTestResult, self).addUnexpectedSuccess(test) super(unittest.TextTestResult, self).addUnexpectedSuccess(test)
click.echo(f" {click.style('', fg='green')} {self.getTestMethodName(test)}") click.echo(f" {click.style('', fg='green')} {self.getTestMethodName(test)}")
logger.debug(f"=== Test unexpected success: {test}") logger.debug(f"=== unexpected success === {test}")
def printErrors(self): def printErrors(self):
click.echo("\n") click.echo("\n")
@ -333,7 +320,6 @@ class TestConfig:
categories: dict = field(default_factory=lambda: {"unit": [], "integration": []}) categories: dict = field(default_factory=lambda: {"unit": [], "integration": []})
selected_categories: list[str] = field(default_factory=list) selected_categories: list[str] = field(default_factory=list)
skip_before_tests: bool = False skip_before_tests: bool = False
skip_test_records: bool = False # New attribute
def xmlrunner_wrapper(output): def xmlrunner_wrapper(output):
@ -366,7 +352,6 @@ def main(
doctype_list_path: str | None = None, doctype_list_path: str | None = None,
failfast: bool = False, failfast: bool = False,
case: str | None = None, case: str | None = None,
skip_test_records: bool = False,
skip_before_tests: bool = False, skip_before_tests: bool = False,
pdb_on_exceptions: bool = False, pdb_on_exceptions: bool = False,
selected_categories: list[str] | None = None, selected_categories: list[str] | None = None,
@ -407,7 +392,6 @@ def main(
pdb_on_exceptions=pdb_on_exceptions, pdb_on_exceptions=pdb_on_exceptions,
selected_categories=selected_categories or [], selected_categories=selected_categories or [],
skip_before_tests=skip_before_tests, skip_before_tests=skip_before_tests,
skip_test_records=skip_test_records,
) )
_initialize_test_environment(site, test_config) _initialize_test_environment(site, test_config)
@ -602,17 +586,14 @@ def _run_all_tests(
logger.debug(f"Running tests for apps: {apps}") logger.debug(f"Running tests for apps: {apps}")
try: try:
unit_test_suite, integration_test_suite = runner.discover_tests(apps, config) unit_test_suite, integration_test_suite = runner.discover_tests(apps, config)
logger.debug(
f"Discovered {len(list(runner._iterate_suite(unit_test_suite)))} unit tests and {len(list(runner._iterate_suite(integration_test_suite)))} integration tests"
)
if config.pdb_on_exceptions: if config.pdb_on_exceptions:
for test_suite in (unit_test_suite, integration_test_suite): for test_suite in (unit_test_suite, integration_test_suite):
for test_case in runner._iterate_suite(test_suite): for test_case in runner._iterate_suite(test_suite):
if hasattr(test_case, "_apply_debug_decorator"): if hasattr(test_case, "_apply_debug_decorator"):
test_case._apply_debug_decorator(config.pdb_on_exceptions) test_case._apply_debug_decorator(config.pdb_on_exceptions)
_prepare_integration_tests(runner, integration_test_suite, config, app) for app in apps:
_prepare_integration_tests(runner, integration_test_suite, config, app)
res = runner.run((unit_test_suite, integration_test_suite)) res = runner.run((unit_test_suite, integration_test_suite))
_cleanup_after_tests() _cleanup_after_tests()
return res return res
@ -635,7 +616,6 @@ def _run_doctype_tests(
for test_case in runner._iterate_suite(test_suite): for test_case in runner._iterate_suite(test_suite):
if hasattr(test_case, "_apply_debug_decorator"): if hasattr(test_case, "_apply_debug_decorator"):
test_case._apply_debug_decorator(config.pdb_on_exceptions) test_case._apply_debug_decorator(config.pdb_on_exceptions)
_prepare_integration_tests(runner, integration_test_suite, config, app) _prepare_integration_tests(runner, integration_test_suite, config, app)
res = runner.run((unit_test_suite, integration_test_suite)) res = runner.run((unit_test_suite, integration_test_suite))
_cleanup_after_tests() _cleanup_after_tests()
@ -677,53 +657,43 @@ def _prepare_integration_tests(
""" """
We perform specific setup steps only for integration tests: We perform specific setup steps only for integration tests:
1. Database Connection: 1. Before Tests Hooks:
- Initialized only for integration tests to avoid overhead in unit tests.
- Essential for end-to-end functionality testing in integration tests.
- Maintains separation between unit and integration tests.
2. Before Tests Hooks:
- Executed only for integration tests unless explicitly skipped. - Executed only for integration tests unless explicitly skipped.
- Provides necessary environment setup for integration tests. - Provides necessary environment setup for integration tests.
- Skipped for unit tests to maintain their independence and isolation. - Skipped for unit tests to maintain their independence and isolation.
3. Test Record Creation: 2. Global Test Record Creation:
- Performed only for integration tests unless explicitly skipped. - Performed only for integration tests.
- Creates or modifies database records needed for integration tests. - Creates or modifies global per-app database records needed for integration tests.
- Ensures consistent starting state and allows for complex test scenarios.
- Skipped for unit tests to maintain their isolation and reproducibility. - Skipped for unit tests to maintain their isolation and reproducibility.
These steps are crucial for integration tests but unnecessary or potentially
harmful for unit tests, which should be independent of external state and fast to execute.
By selectively applying these setup steps, we maintain the integrity and purpose
of both unit and integration tests while optimizing performance.
""" """
if not config.skip_before_tests: if not config.skip_before_tests:
_run_before_test_hooks(config, app) _run_before_test_hooks(config, app)
else: else:
logger.debug("Skipping before_tests hooks: Explicitly skipped") logger.debug("Skipping before_tests hooks: Explicitly skipped")
if app:
if not config.skip_test_records: _run_global_test_records_dependencies_install(app)
_execute_test_record_callbacks(runner)
else:
logger.debug("Skipping test record creation: Explicitly skipped")
else: else:
logger.debug("Skipping before_tests hooks and test record creation: No integration tests") logger.debug("Skipping before_tests hooks and global test record creation: No integration tests")
@debug_timer @debug_timer
def _run_before_test_hooks(config: TestConfig, app: str | None): def _run_before_test_hooks(config: TestConfig, app: str | None):
"""Run 'before_tests' hooks""" """Run 'before_tests' hooks"""
logger.debug('Running "before_tests" hooks') logger.debug(f'Running "before_tests" hooks for {app}')
for hook_function in frappe.get_hooks("before_tests", app_name=app): for hook_function in frappe.get_hooks("before_tests", app_name=app):
frappe.get_attr(hook_function)() frappe.get_attr(hook_function)()
@debug_timer @debug_timer
def _execute_test_record_callbacks(runner): def _run_global_test_records_dependencies_install(app: str):
"""Execute test record creation callbacks""" """Run global test records dependencies install"""
logger.debug("Running test record creation callbacks") test_module = frappe.get_module(f"{app}.tests")
runner.execute_test_record_callbacks() logger.debug(f"Loading global tests records from {test_module.__name__}")
if hasattr(test_module, "global_test_dependencies"):
for doctype in test_module.global_test_dependencies:
logger.debug(f" Loading records for {doctype}")
make_test_records(doctype, commit=True)
# Backwards-compatible aliases # Backwards-compatible aliases
@ -732,6 +702,7 @@ from frappe.tests.utils import (
get_dependencies, get_dependencies,
get_modules, get_modules,
make_test_objects, make_test_objects,
make_test_records,
make_test_records_for_doctype, make_test_records_for_doctype,
print_mandatory_fields, print_mandatory_fields,
) )

View file

@ -1,6 +1,7 @@
import copy import copy
import datetime import datetime
import functools import functools
import json
import os import os
import pdb import pdb
import signal import signal
@ -29,22 +30,6 @@ import logging
logger = logging.Logger(__file__) logger = logging.Logger(__file__)
# Moved from test_runner.py
def make_test_records(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
logger.debug(f"Making test records for doctype: {doctype}")
for options in get_dependencies(doctype):
if options == "[Select]":
continue
if options not in frappe.local.test_objects:
frappe.local.test_objects[options] = []
make_test_records(options, force, commit=commit)
make_test_records_for_doctype(options, force, commit=commit)
@cache @cache
def get_modules(doctype): def get_modules(doctype):
@ -70,7 +55,7 @@ def get_dependencies(doctype):
for df in meta.get_table_fields(): for df in meta.get_table_fields():
link_fields.extend(frappe.get_meta(df.options).get_link_fields()) link_fields.extend(frappe.get_meta(df.options).get_link_fields())
options_list = [df.options for df in link_fields] + [doctype] options_list = [df.options for df in link_fields]
if hasattr(test_module, "test_dependencies"): if hasattr(test_module, "test_dependencies"):
options_list += test_module.test_dependencies options_list += test_module.test_dependencies
@ -87,7 +72,48 @@ def get_dependencies(doctype):
return options_list return options_list
# Test record generation
def make_test_records(doctype, force=False, commit=False):
return list(_make_test_records(doctype, force, commit))
def make_test_records_for_doctype(doctype, force=False, commit=False): def make_test_records_for_doctype(doctype, force=False, commit=False):
return list(_make_test_records_for_doctype(doctype, force, commit))
def make_test_objects(doctype, test_records=None, reset=False, commit=False):
return list(_make_test_objects(doctype, test_records, reset, commit))
def _make_test_records(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
loadme = False
if doctype not in frappe.local.test_objects:
loadme = True
frappe.local.test_objects[doctype] = [] # infinite recursion guard, here
# First, create test records for dependencies
for dependency in get_dependencies(doctype):
if dependency != "[Select]" and dependency not in frappe.local.test_objects:
yield from _make_test_records(dependency, force, commit)
# Then, create test records for the doctype itself
if loadme:
# Yield the doctype and record length
yield (
doctype,
len(
# Create all test records
list(_make_test_records_for_doctype(doctype, force, commit))
),
)
def _make_test_records_for_doctype(doctype, force=False, commit=False):
"""Make test records for the specified doctype""" """Make test records for the specified doctype"""
test_record_log_instance = TestRecordLog() test_record_log_instance = TestRecordLog()
@ -95,32 +121,22 @@ def make_test_records_for_doctype(doctype, force=False, commit=False):
return return
module, test_module = get_modules(doctype) module, test_module = get_modules(doctype)
logger.debug(f"Making test records for {doctype}")
if hasattr(test_module, "_make_test_records"): if hasattr(test_module, "_make_test_records"):
frappe.local.test_objects[doctype] = ( yield from test_module._make_test_records()
frappe.local.test_objects.get(doctype, []) + test_module._make_test_records()
)
elif hasattr(test_module, "test_records"): elif hasattr(test_module, "test_records"):
frappe.local.test_objects[doctype] = frappe.local.test_objects.get(doctype, []) + make_test_objects( yield from _make_test_objects(doctype, test_module.test_records, force, commit=commit)
doctype, test_module.test_records, force, commit=commit
)
else: else:
test_records = frappe.get_test_records(doctype) test_records = frappe.get_test_records(doctype)
if test_records: if test_records:
frappe.local.test_objects[doctype] = frappe.local.test_objects.get( yield from _make_test_objects(doctype, test_records, force, commit=commit)
doctype, []
) + make_test_objects(doctype, test_records, force, commit=commit)
elif logger.getEffectiveLevel() < logging.INFO: elif logger.getEffectiveLevel() < logging.INFO:
print_mandatory_fields(doctype) print_mandatory_fields(doctype)
test_record_log_instance.add(doctype) test_record_log_instance.add(doctype)
def make_test_objects(doctype, test_records=None, reset=False, commit=False): def _make_test_objects(doctype, test_records=None, reset=False, commit=False):
"""Make test objects from given list of `test_records` or from `test_records.json`""" """Generator function to make test objects"""
logger.debug(f"Making test objects for doctype: {doctype}")
records = []
def revert_naming(d): def revert_naming(d):
if getattr(d, "naming_series", None): if getattr(d, "naming_series", None):
@ -177,11 +193,11 @@ def make_test_objects(doctype, test_records=None, reset=False, commit=False):
logger.debug(f"Error in making test record for {d.doctype} {d.name}") logger.debug(f"Error in making test record for {d.doctype} {d.name}")
raise raise
records.append(d.name)
if commit: if commit:
frappe.db.commit() frappe.db.commit()
return records
frappe.local.test_objects[doctype] += d.name
yield d.name
def print_mandatory_fields(doctype): def print_mandatory_fields(doctype):
@ -308,6 +324,27 @@ class UnitTestCase(unittest.TestCase):
to maintain the functionality of this base class. to maintain the functionality of this base class.
""" """
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.doctype = cls._get_doctype_from_module()
cls.module = frappe.get_module(cls.__module__)
@classmethod
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None
def _apply_debug_decorator(self, exceptions=()): def _apply_debug_decorator(self, exceptions=()):
setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName))) setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName)))
@ -420,6 +457,7 @@ class IntegrationTestCase(UnitTestCase):
- Automatic database setup and teardown - Automatic database setup and teardown
- Utilities for managing database connections - Utilities for managing database connections
- Context managers for query counting and Redis call monitoring - Context managers for query counting and Redis call monitoring
- Lazy loading of test record dependencies
Note: If you override `setUpClass`, make sure to call `super().setUpClass()` Note: If you override `setUpClass`, make sure to call `super().setUpClass()`
to maintain the functionality of this base class. to maintain the functionality of this base class.
@ -433,11 +471,23 @@ class IntegrationTestCase(UnitTestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
super().setUpClass() super().setUpClass()
# Site initialization
cls.TEST_SITE = getattr(frappe.local, "site", None) or cls.TEST_SITE cls.TEST_SITE = getattr(frappe.local, "site", None) or cls.TEST_SITE
frappe.init(cls.TEST_SITE) frappe.init(cls.TEST_SITE)
cls.ADMIN_PASSWORD = frappe.get_conf(cls.TEST_SITE).admin_password cls.ADMIN_PASSWORD = frappe.get_conf(cls.TEST_SITE).admin_password
cls._primary_connection = frappe.local.db cls._primary_connection = frappe.local.db
cls._secondary_connection = None cls._secondary_connection = None
# Create test record dependencies
cls._newly_created_test_records = []
if cls.doctype and cls.doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(cls.doctype)
for doctype in getattr(cls.module, "test_dependencies", []):
if doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(doctype)
# flush changes done so far to avoid flake # flush changes done so far to avoid flake
frappe.db.commit() frappe.db.commit()
if cls.SHOW_TRANSACTION_COMMIT_WARNINGS: if cls.SHOW_TRANSACTION_COMMIT_WARNINGS: