refactor: Structure frappe.test.utils (green to green) (#28038)

* docs: constitute frappe.test readme

* refactor: move utils to __init__

* refactor: move generators into generators.py

* refactor: move cm into context_managers.py

* refactor: move test classes into submodule

* refactor: reexport general purpose context managers

* refactor: adapt imports (treewide)
This commit is contained in:
David Arnold 2024-10-08 17:10:24 +02:00 committed by GitHub
parent b0b8139233
commit e7776021aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 891 additions and 802 deletions

View file

@ -19,7 +19,7 @@ repos:
- id: check-toml
- id: check-yaml
- id: debug-statements
exclude: ^frappe/tests/utils\.py$
exclude: ^frappe/tests/classes/context_managers\.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0

View file

@ -7,9 +7,8 @@ from contextlib import contextmanager
import frappe
from frappe.desk.query_report import generate_report_result, get_report_doc
from frappe.query_builder.utils import db_type_is
from frappe.tests import IntegrationTestCase, UnitTestCase
from frappe.tests import IntegrationTestCase, UnitTestCase, timeout
from frappe.tests.test_query_builder import run_only_if
from frappe.tests.utils import timeout
class UnitTestPreparedReport(UnitTestCase):

View file

@ -10,8 +10,7 @@ from rq.job import Job
import frappe
from frappe.core.doctype.rq_job.rq_job import RQJob, remove_failed_jobs, stop_job
from frappe.installer import update_site_config
from frappe.tests import IntegrationTestCase, UnitTestCase
from frappe.tests.utils import timeout
from frappe.tests import IntegrationTestCase, UnitTestCase, timeout
from frappe.utils import cstr, execute_in_shell
from frappe.utils.background_jobs import get_job_status, is_job_enqueued

View file

@ -5,8 +5,7 @@ import time
import typing
import frappe
from frappe.tests import IntegrationTestCase, UnitTestCase
from frappe.tests.utils import timeout
from frappe.tests import IntegrationTestCase, UnitTestCase, timeout
from frappe.utils.background_jobs import get_queue
if typing.TYPE_CHECKING:

View file

@ -6,8 +6,7 @@ import time
import frappe
from frappe.core.doctype.doctype.test_doctype import new_doctype
from frappe.desk.doctype.bulk_update.bulk_update import submit_cancel_or_update_docs
from frappe.tests import IntegrationTestCase, UnitTestCase
from frappe.tests.utils import timeout
from frappe.tests import IntegrationTestCase, UnitTestCase, timeout
class UnitTestBulkUpdate(UnitTestCase):

View file

@ -28,6 +28,7 @@ def xmlrunner_wrapper(output):
return _runner
# TODO: move to deprecation dumpster
from frappe.tests.utils import (
TestRecordLog,
get_dependencies,
@ -39,6 +40,7 @@ from frappe.tests.utils import (
)
# TODO: move to deprecation dumpster
# Compatibility functions
def add_to_test_record_log(doctype):
TestRecordLog().add(doctype)

View file

@ -27,7 +27,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
import frappe
from frappe.tests.utils import IntegrationTestCase
from frappe.tests import IntegrationTestCase
from .utils import debug_timer

View file

@ -1,45 +1,76 @@
# Frappe Test Utilities
# Frappe Test Framework
This README provides an overview of the test utilities available in the Frappe framework, particularly focusing on the `frappe/tests/utils.py` file. These utilities are designed to facilitate efficient and effective testing of Frappe applications.
This README provides an overview of the test case framework available in Frappe. These utilities are designed to facilitate efficient and effective testing of Frappe applications.
## Main Functions
This is different from the `frappe.testing` module which houses the discovery and runner infrastructure for CLI and CI.
The `utils.py` file contains several key components:
## Directory Structure
1. Test record generation utilities
2. Test case classes (UnitTestCase and IntegrationTestCase)
3. Context managers for various testing scenarios
4. Utility functions and decorators
The test framework is organized into the following structure:
## Test Case Classes
```
frappe/tests/
├── classes/
│ ├── context_managers.py
│ ├── unit_test_case.py
│ └── ...
├── utils/
│ ├── generators.py
│ └── ...
├── test_api.py
├── test_child_table.py
└── ...
```
### UnitTestCase
## Key Components
This class extends `unittest.TestCase` and provides additional utilities specific to Frappe framework. It's designed for testing individual components or functions in isolation.
1. Test case classes (UnitTestCase and IntegrationTestCase)
3. Framework and class specific context managers
4. Utility functions and generators
5. Specific test modules for various Frappe components
Some key methods and features include:
### Test Case Classes
- Custom assertions (e.g., `assertQueryEqual`, `assertDocumentEqual`)
- HTML and SQL normalization
#### UnitTestCase ([`classes/unit_test_case.py`](./classes/unit_test_case.py))
###### Import convention: `from frappe.tests import UnitTestCase`
This class extends `unittest.TestCase` and provides additional utilities specific to the Frappe framework. It's designed for testing individual components or functions in isolation.
Key features include:
- Custom assertions for Frappe-specific comparisons
- Utilities for HTML and SQL normalization
- Context managers for user switching and time freezing
### IntegrationTestCase
#### IntegrationTestCase ([`classes/integration_test_case.py`](./classes/integration_test_case.py))
###### Import convention: `from frappe.tests import IntegrationTestCase`
This class extends `UnitTestCase` and is designed for integration testing. It provides features for:
- Automatic site and connection setup
- Automatic test records loading
- Automatic reset of thread locals
- Context managers that depend on a site connection
- Asserts that depend on a site connection
- Automatic database setup and teardown
- Database connection management
- Query counting and Redis call monitoring
- Lazy loading of test record dependencies
For a detailed list of context managers, please refer to the code.
For a complete list of methods and their usage, please refer to the actual code in `frappe/tests/utils.py`.
### Utility Functions and Generators ([`utils/generators.py`](./utils/generators.py))
This module contains utility functions for generating test records and managing test data.
### Specific Test Modules
Various test modules (e.g., test_api.py, test_document.py) contain tests for specific Frappe core components and functionalities.
Note that Document tests are collocated alongside each Document module.
## Usage
To use these test utilities in your Frappe application tests, you can inherit from the appropriate test case class:
```python
from frappe.tests.utils import UnitTestCase
from frappe.tests import UnitTestCase
class MyTestCase(UnitTestCase):
def test_something(self):
@ -47,8 +78,11 @@ class MyTestCase(UnitTestCase):
pass
```
Remember that this README provides an overview as of the time of writing. Always refer to the actual code for the most up-to-date and detailed information on available methods and their usage.
## Contributing
If you're adding new test utilities or modifying existing ones, please ensure to update this README accordingly.
When adding new test utilities or modifying existing ones:
1. Place them in the appropriate directory based on their function.
2. Update this README to reflect any significant changes in the framework structure or usage.
3. Ensure that your changes follow the existing coding style and conventions.
Remember to always refer to the actual code for the most up-to-date and detailed information on available methods and their usage.

View file

@ -1,7 +1,20 @@
# TODO: move to dumpster
import frappe
from .utils import IntegrationTestCase, MockedRequestTestCase, UnitTestCase
from .classes import *
from .classes.context_managers import *
global_test_dependencies = ["User"]
# TODO: move to dumpster - not meant to be a public interface anymore
import frappe.tests.utils as utils
utils.IntegrationTestCase = IntegrationTestCase
utils.UnitTestCase = UnitTestCase
utils.FrappeTestCase = IntegrationTestCase
utils.change_settings = IntegrationTestCase.change_settings
utils.patch_hooks = UnitTestCase.patch_hooks
utils.debug_on = debug_on
utils.timeout = timeout
# TODO: move to dumpster
@ -17,6 +30,3 @@ def update_system_settings(args, commit=False):
# TODO: move to dumpster
def get_system_setting(key):
return frappe.db.get_single_value("System Settings", key)
global_test_dependencies = ["User"]

View file

@ -0,0 +1,3 @@
from .integration_test_case import IntegrationTestCase
from .mocked_request_test_case import MockedRequestTestCase
from .unit_test_case import UnitTestCase

View file

@ -0,0 +1,116 @@
import functools
import logging
import pdb
import signal
import sys
import traceback
logger = logging.Logger(__file__)
# NOTE: declare those who should also be made available directly frappe.tests.* namespace
# these can be general purpose context managers who do NOT depend on a particular
# test class setup, such as for example the IntegrationTestCase's connection to site
__all__ = [
"debug_on",
"timeout",
]
def debug_on(*exceptions):
"""
A decorator to automatically start the debugger when specified exceptions occur.
This decorator allows you to automatically invoke the debugger (pdb) when certain
exceptions are raised in the decorated function. If no exceptions are specified,
it defaults to catching AssertionError.
Args:
*exceptions: Variable length argument list of exception classes to catch.
If none provided, defaults to (AssertionError,).
Returns:
function: A decorator function.
Usage:
1. Basic usage (catches AssertionError):
@debug_on()
def test_assertion_error():
assert False, "This will start the debugger"
2. Catching specific exceptions:
@debug_on(ValueError, TypeError)
def test_specific_exceptions():
raise ValueError("This will start the debugger")
3. Using on a method in a test class:
class TestMyFunctionality(unittest.TestCase):
@debug_on(ZeroDivisionError)
def test_division_by_zero(self):
result = 1 / 0
Note:
When an exception is caught, this decorator will print the exception traceback
and then start the post-mortem debugger, allowing you to inspect the state of
the program at the point where the exception was raised.
"""
if not exceptions:
exceptions = (AssertionError,)
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
return wrapper
return decorator
def timeout(seconds=30, error_message="Test timed out."):
"""Timeout decorator to ensure a test doesn't run for too long.
adapted from https://stackoverflow.com/a/2282656"""
# Support @timeout (without function call)
no_args = bool(callable(seconds))
actual_timeout = 30 if no_args else seconds
actual_error_message = "Test timed out" if no_args else error_message
def decorator(func):
def _handle_timeout(signum, frame):
raise Exception(actual_error_message)
def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(actual_timeout)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
if no_args:
return decorator(seconds)
return decorator

View file

@ -0,0 +1,250 @@
import copy
import logging
from contextlib import AbstractContextManager, contextmanager
import frappe
from frappe.utils import cint
from ..utils.generators import make_test_records
from .unit_test_case import UnitTestCase
logger = logging.Logger(__file__)
class IntegrationTestCase(UnitTestCase):
"""Integration test class for Frappe tests.
Key features:
- Automatic database setup and teardown
- Utilities for managing database connections
- Context managers for query counting and Redis call monitoring
- Lazy loading of test record dependencies
Note: If you override `setUpClass`, make sure to call `super().setUpClass()`
to maintain the functionality of this base class.
"""
TEST_SITE = "test_site"
SHOW_TRANSACTION_COMMIT_WARNINGS = False
maxDiff = 10_000 # prints long diffs but useful in CI
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
# Site initialization
cls.TEST_SITE = getattr(frappe.local, "site", None) or cls.TEST_SITE
frappe.init(cls.TEST_SITE)
cls.ADMIN_PASSWORD = frappe.get_conf(cls.TEST_SITE).admin_password
cls._primary_connection = frappe.local.db
cls._secondary_connection = None
# Create test record dependencies
cls._newly_created_test_records = []
if cls.doctype and cls.doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(cls.doctype)
for doctype in getattr(cls.module, "test_dependencies", []):
if doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(doctype)
# flush changes done so far to avoid flake
frappe.db.commit()
if cls.SHOW_TRANSACTION_COMMIT_WARNINGS:
frappe.db.before_commit.add(_commit_watcher)
# enqueue teardown actions (executed in LIFO order)
cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags))
cls.addClassCleanup(_rollback_db)
@classmethod
def tearDownClass(cls) -> None:
# Add any necessary teardown code here
super().tearDownClass()
def setUp(self) -> None:
super().setUp()
# Add any per-test setup code here
def tearDown(self) -> None:
# Add any per-test teardown code here
super().tearDown()
@contextmanager
def primary_connection(self) -> AbstractContextManager[None]:
"""Switch to primary DB connection
This is used for simulating multiple users performing actions by simulating two DB connections"""
try:
current_conn = frappe.local.db
frappe.local.db = self._primary_connection
yield
finally:
frappe.local.db = current_conn
@contextmanager
def secondary_connection(self) -> AbstractContextManager[None]:
"""Switch to secondary DB connection."""
if self._secondary_connection is None:
frappe.connect() # get second connection
self._secondary_connection = frappe.local.db
try:
current_conn = frappe.local.db
frappe.local.db = self._secondary_connection
yield
finally:
frappe.local.db = current_conn
self.addCleanup(self._rollback_connections)
def _rollback_connections(self) -> None:
self._primary_connection.rollback()
self._secondary_connection.rollback()
@contextmanager
def assertQueryCount(self, count: int) -> AbstractContextManager[None]:
queries = []
def _sql_with_count(*args, **kwargs):
ret = orig_sql(*args, **kwargs)
queries.append(args[0].last_query)
return ret
try:
orig_sql = frappe.db.__class__.sql
frappe.db.__class__.sql = _sql_with_count
yield
self.assertLessEqual(len(queries), count, msg="Queries executed: \n" + "\n\n".join(queries))
finally:
frappe.db.__class__.sql = orig_sql
@contextmanager
def assertRedisCallCounts(self, count: int) -> AbstractContextManager[None]:
commands = []
def execute_command_and_count(*args, **kwargs):
ret = orig_execute(*args, **kwargs)
key_len = 2
if "H" in args[0]:
key_len = 3
commands.append((args)[:key_len])
return ret
try:
orig_execute = frappe.cache.execute_command
frappe.cache.execute_command = execute_command_and_count
yield
self.assertLessEqual(
len(commands), count, msg="commands executed: \n" + "\n".join(str(c) for c in commands)
)
finally:
frappe.cache.execute_command = orig_execute
@contextmanager
def assertRowsRead(self, count: int) -> AbstractContextManager[None]:
rows_read = 0
def _sql_with_count(*args, **kwargs):
nonlocal rows_read
ret = orig_sql(*args, **kwargs)
# count of last touched rows as per DB-API 2.0 https://peps.python.org/pep-0249/#rowcount
rows_read += cint(frappe.db._cursor.rowcount)
return ret
try:
orig_sql = frappe.db.sql
frappe.db.sql = _sql_with_count
yield
self.assertLessEqual(rows_read, count, msg="Queries read more rows than expected")
finally:
frappe.db.sql = orig_sql
@contextmanager
def switch_site(self, site: str) -> AbstractContextManager[None]:
"""Switch connection to different site.
Note: Drops current site connection completely."""
try:
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
finally:
frappe.init(old_site, force=True)
frappe.connect()
@staticmethod
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings):
"""A context manager to ensure that settings are changed before running
function and restored after running it regardless of exceptions occurred.
This is useful in tests where you want to make changes in a function but
don't retain those changes.
import and use as decorator to cover full function or using `with` statement.
example:
@change_settings("Print Settings", {"send_print_as_pdf": 1})
def test_case(self):
...
@change_settings("Print Settings", send_print_as_pdf=1)
def test_case(self):
...
"""
if settings_dict is None:
settings_dict = settings
try:
settings = frappe.get_doc(doctype)
# remember setting
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
# change setting
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield # yield control to calling function
finally:
# restore settings
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
def _commit_watcher():
import traceback
logger.warning("Transaction committed during tests.")
traceback.print_stack(limit=10)
def _rollback_db():
frappe.db.value_cache = {}
frappe.db.rollback()
def _restore_thread_locals(flags):
frappe.local.flags = flags
frappe.local.error_log = []
frappe.local.message_log = []
frappe.local.debug_log = []
frappe.local.conf = frappe._dict(frappe.get_site_config())
frappe.local.cache = {}
frappe.local.lang = "en"
frappe.local.preload_assets = {"style": [], "script": [], "icons": []}
if hasattr(frappe.local, "request"):
delattr(frappe.local, "request")

View file

@ -0,0 +1,18 @@
import logging
from .integration_test_case import IntegrationTestCase
logger = logging.Logger(__file__)
class MockedRequestTestCase(IntegrationTestCase):
def setUp(self):
import responses
self.responses = responses.RequestsMock()
self.responses.start()
self.addCleanup(self.responses.stop)
self.addCleanup(self.responses.reset)
return super().setUp()

View file

@ -0,0 +1,165 @@
import datetime
import json
import logging
import os
import unittest
from collections.abc import Sequence
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytz
import frappe
from frappe.model.base_document import BaseDocument
from frappe.utils import cint
from frappe.utils.data import get_datetime, get_system_timezone
from .context_managers import debug_on
logger = logging.Logger(__file__)
datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
class UnitTestCase(unittest.TestCase):
"""Unit test class for Frappe tests.
This class extends unittest.TestCase and provides additional utilities
specific to Frappe framework. It's designed for testing individual
components or functions in isolation.
Key features:
- Custom assertions for Frappe-specific comparisons
- Utilities for HTML and SQL normalization
- Context managers for user switching and time freezing
Note: If you override `setUpClass`, make sure to call `super().setUpClass()`
to maintain the functionality of this base class.
"""
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.doctype = cls._get_doctype_from_module()
cls.module = frappe.get_module(cls.__module__)
@classmethod
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None
def _apply_debug_decorator(self, exceptions=()):
setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName)))
def assertQueryEqual(self, first: str, second: str) -> None:
self.assertEqual(self.normalize_sql(first), self.normalize_sql(second))
def assertSequenceSubset(self, larger: Sequence, smaller: Sequence, msg: str | None = None) -> None:
"""Assert that `expected` is a subset of `actual`."""
self.assertTrue(set(smaller).issubset(set(larger)), msg=msg)
# --- Frappe Framework specific assertions
def assertDocumentEqual(self, expected: dict | BaseDocument, actual: BaseDocument) -> None:
"""Compare a (partial) expected document with actual Document."""
if isinstance(expected, BaseDocument):
expected = expected.as_dict()
for field, value in expected.items():
if isinstance(value, list):
actual_child_docs = actual.get(field)
self.assertEqual(len(value), len(actual_child_docs), msg=f"{field} length should be same")
for exp_child, actual_child in zip(value, actual_child_docs, strict=False):
self.assertDocumentEqual(exp_child, actual_child)
else:
self._compare_field(value, actual.get(field), actual, field)
def _compare_field(self, expected: Any, actual: Any, doc: BaseDocument, field: str) -> None:
msg = f"{field} should be same."
if isinstance(expected, float):
precision = doc.precision(field)
self.assertAlmostEqual(
expected, actual, places=precision, msg=f"{field} should be same to {precision} digits"
)
elif isinstance(expected, bool | int):
self.assertEqual(expected, cint(actual), msg=msg)
elif isinstance(expected, datetime_like_types) or isinstance(actual, datetime_like_types):
self.assertEqual(str(expected), str(actual), msg=msg)
else:
self.assertEqual(expected, actual, msg=msg)
def normalize_html(self, code: str) -> str:
"""Formats HTML consistently so simple string comparisons can work on them."""
from bs4 import BeautifulSoup
return BeautifulSoup(code, "html.parser").prettify(formatter=None)
@contextmanager
def set_user(self, user: str) -> AbstractContextManager[None]:
try:
old_user = frappe.session.user
frappe.set_user(user)
yield
finally:
frappe.set_user(old_user)
def normalize_sql(self, query: str) -> str:
"""Formats SQL consistently so simple string comparisons can work on them."""
import sqlparse
return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True)
@classmethod
def enable_safe_exec(cls) -> None:
"""Enable safe exec and disable them after test case is completed."""
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
cls._common_conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=cls._common_conf)
cls.addClassCleanup(
lambda: update_site_config(
SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=cls._common_conf
)
)
@staticmethod
@contextmanager
def patch_hooks(overridden_hooks: dict) -> AbstractContextManager[None]:
get_hooks = frappe.get_hooks
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@contextmanager
def freeze_time(
self, time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any
) -> AbstractContextManager[None]:
from freezegun import freeze_time
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freeze_time(time_to_freeze, *args, **kwargs):
yield

View file

@ -33,9 +33,8 @@ import frappe.commands.utils
import frappe.recorder
from frappe.installer import add_to_installed_apps, remove_app
from frappe.query_builder.utils import db_type_is
from frappe.tests import IntegrationTestCase
from frappe.tests import IntegrationTestCase, timeout
from frappe.tests.test_query_builder import run_only_if
from frappe.tests.utils import timeout
from frappe.utils import add_to_date, get_bench_path, get_bench_relative_path, now
from frappe.utils.backups import BackupGenerator, fetch_latest_backups
from frappe.utils.jinja_globals import bundled_asset

View file

@ -14,9 +14,8 @@ from frappe.database.database import get_query_execution_timeout
from frappe.database.utils import FallBackDateTimeStr
from frappe.query_builder import Field
from frappe.query_builder.functions import Concat_ws
from frappe.tests import IntegrationTestCase
from frappe.tests import IntegrationTestCase, timeout
from frappe.tests.test_query_builder import db_type_is, run_only_if
from frappe.tests.utils import timeout
from frappe.utils import add_days, now, random_string, set_request
from frappe.utils.data import now_datetime
from frappe.utils.testutils import clear_custom_fields

View file

@ -8,8 +8,7 @@ import sqlparse
import frappe
import frappe.recorder
from frappe.recorder import normalize_query
from frappe.tests import IntegrationTestCase
from frappe.tests.utils import timeout
from frappe.tests import IntegrationTestCase, timeout
from frappe.utils import set_request
from frappe.utils.doctor import any_job_pending
from frappe.website.serve import get_response_content

View file

@ -1,757 +0,0 @@
import copy
import datetime
import functools
import json
import os
import pdb
import signal
import sys
import traceback
import unittest
from collections.abc import Callable, Sequence
from contextlib import AbstractContextManager, contextmanager
from functools import cache
from importlib import reload
from pathlib import Path
from typing import Any, Union
from unittest.mock import patch
import pytz
import frappe
from frappe.model.base_document import BaseDocument, get_controller
from frappe.model.naming import revert_series_if_last
from frappe.modules import load_doctype_module
from frappe.utils import cint
from frappe.utils.data import convert_utc_to_timezone, get_datetime, get_system_timezone
datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
import logging
logger = logging.Logger(__file__)
@cache
def get_modules(doctype):
"""Get the modules for the specified doctype"""
module = frappe.db.get_value("DocType", doctype, "module")
try:
test_module = load_doctype_module(doctype, module, "test_")
if test_module:
reload(test_module)
except ImportError:
test_module = None
return module, test_module
@cache
def get_dependencies(doctype):
"""Get the dependencies for the specified doctype"""
module, test_module = get_modules(doctype)
meta = frappe.get_meta(doctype)
link_fields = meta.get_link_fields()
for df in meta.get_table_fields():
link_fields.extend(frappe.get_meta(df.options).get_link_fields())
options_list = [df.options for df in link_fields]
if hasattr(test_module, "test_dependencies"):
options_list += test_module.test_dependencies
options_list = list(set(options_list))
if hasattr(test_module, "test_ignore"):
for doctype_name in test_module.test_ignore:
if doctype_name in options_list:
options_list.remove(doctype_name)
options_list.sort()
return options_list
# Test record generation
def make_test_records(doctype, force=False, commit=False):
return list(_make_test_records(doctype, force, commit))
def make_test_records_for_doctype(doctype, force=False, commit=False):
return list(_make_test_records_for_doctype(doctype, force, commit))
def make_test_objects(doctype, test_records=None, reset=False, commit=False):
return list(_make_test_objects(doctype, test_records, reset, commit))
def _make_test_records(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
loadme = False
if doctype not in frappe.local.test_objects:
loadme = True
frappe.local.test_objects[doctype] = [] # infinite recursion guard, here
# First, create test records for dependencies
for dependency in get_dependencies(doctype):
if dependency != "[Select]" and dependency not in frappe.local.test_objects:
yield from _make_test_records(dependency, force, commit)
# Then, create test records for the doctype itself
if loadme:
# Yield the doctype and record length
yield (
doctype,
len(
# Create all test records
list(_make_test_records_for_doctype(doctype, force, commit))
),
)
def _make_test_records_for_doctype(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
test_record_log_instance = TestRecordLog()
if not force and doctype in test_record_log_instance.get():
return
module, test_module = get_modules(doctype)
if hasattr(test_module, "_make_test_records"):
yield from test_module._make_test_records()
elif hasattr(test_module, "test_records"):
yield from _make_test_objects(doctype, test_module.test_records, force, commit=commit)
else:
test_records = frappe.get_test_records(doctype)
if test_records:
yield from _make_test_objects(doctype, test_records, force, commit=commit)
else:
print_mandatory_fields(doctype)
test_record_log_instance.add(doctype)
def _make_test_objects(doctype, test_records=None, reset=False, commit=False):
"""Generator function to make test objects"""
def revert_naming(d):
if getattr(d, "naming_series", None):
revert_series_if_last(d.naming_series, d.name)
if test_records is None:
test_records = frappe.get_test_records(doctype)
for doc in test_records:
if not reset:
frappe.db.savepoint("creating_test_record")
if not doc.get("doctype"):
doc["doctype"] = doctype
d = frappe.copy_doc(doc)
if d.meta.get_field("naming_series"):
if not d.naming_series:
d.naming_series = "_T-" + d.doctype + "-"
if doc.get("name"):
d.name = doc.get("name")
else:
d.set_new_name()
if frappe.db.exists(d.doctype, d.name) and not reset:
frappe.db.rollback(save_point="creating_test_record")
# do not create test records, if already exists
continue
# submit if docstatus is set to 1 for test record
docstatus = d.docstatus
d.docstatus = 0
try:
d.run_method("before_test_insert")
d.insert(ignore_if_duplicate=True)
if docstatus == 1:
d.submit()
except frappe.NameError:
revert_naming(d)
except Exception as e:
if (
d.flags.ignore_these_exceptions_in_test
and e.__class__ in d.flags.ignore_these_exceptions_in_test
):
revert_naming(d)
else:
logger.debug(f"Error in making test record for {d.doctype} {d.name}")
raise
if commit:
frappe.db.commit()
frappe.local.test_objects[doctype] += d.name
yield d.name
def print_mandatory_fields(doctype):
"""Print mandatory fields for the specified doctype"""
meta = frappe.get_meta(doctype)
logger.warning(f"Please setup make_test_records for: {doctype}")
logger.warning("-" * 60)
logger.warning(f"Autoname: {meta.autoname or ''}")
logger.warning("Mandatory Fields:")
for d in meta.get("fields", {"reqd": 1}):
logger.warning(f" - {d.parent}:{d.fieldname} | {d.fieldtype} | {d.options or ''}")
logger.warning("")
class TestRecordLog:
def __init__(self):
self.log_file = Path(frappe.get_site_path(".test_log"))
self._log = None
def get(self):
if self._log is None:
self._log = self._read_log()
return self._log
def add(self, doctype):
log = self.get()
if doctype not in log:
log.append(doctype)
self._write_log(log)
def _read_log(self):
if self.log_file.exists():
with self.log_file.open() as f:
return f.read().splitlines()
return []
def _write_log(self, log):
with self.log_file.open("w") as f:
f.write("\n".join(l for l in log if l is not None))
def debug_on(*exceptions):
"""
A decorator to automatically start the debugger when specified exceptions occur.
This decorator allows you to automatically invoke the debugger (pdb) when certain
exceptions are raised in the decorated function. If no exceptions are specified,
it defaults to catching AssertionError.
Args:
*exceptions: Variable length argument list of exception classes to catch.
If none provided, defaults to (AssertionError,).
Returns:
function: A decorator function.
Usage:
1. Basic usage (catches AssertionError):
@debug_on()
def test_assertion_error():
assert False, "This will start the debugger"
2. Catching specific exceptions:
@debug_on(ValueError, TypeError)
def test_specific_exceptions():
raise ValueError("This will start the debugger")
3. Using on a method in a test class:
class TestMyFunctionality(unittest.TestCase):
@debug_on(ZeroDivisionError)
def test_division_by_zero(self):
result = 1 / 0
Note:
When an exception is caught, this decorator will print the exception traceback
and then start the post-mortem debugger, allowing you to inspect the state of
the program at the point where the exception was raised.
"""
if not exceptions:
exceptions = (AssertionError,)
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except exceptions as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
# Pretty print the exception
print("\n\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[93m" + str(exc_type.__name__) + ": " + str(exc_value) + "\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
# Print the formatted traceback
traceback_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in traceback_lines:
print("\033[96m" + line.rstrip() + "\033[0m") # Cyan color
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
print("\033[92mEntering post-mortem debugging\033[0m")
print("\033[91m" + "=" * 60 + "\033[0m") # Red line
pdb.post_mortem()
raise e
return wrapper
return decorator
class UnitTestCase(unittest.TestCase):
"""Unit test class for Frappe tests.
This class extends unittest.TestCase and provides additional utilities
specific to Frappe framework. It's designed for testing individual
components or functions in isolation.
Key features:
- Custom assertions for Frappe-specific comparisons
- Utilities for HTML and SQL normalization
- Context managers for user switching and time freezing
Note: If you override `setUpClass`, make sure to call `super().setUpClass()`
to maintain the functionality of this base class.
"""
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.doctype = cls._get_doctype_from_module()
cls.module = frappe.get_module(cls.__module__)
@classmethod
def _get_doctype_from_module(cls):
module_path = cls.__module__.split(".")
try:
doctype_index = module_path.index("doctype")
doctype_snake_case = module_path[doctype_index + 1]
json_file_path = Path(*module_path[:-1]).joinpath(f"{doctype_snake_case}.json")
if json_file_path.is_file():
doctype_data = json.loads(json_file_path.read_text())
return doctype_data.get("name")
except (ValueError, IndexError):
# 'doctype' not found in module_path
pass
return None
def _apply_debug_decorator(self, exceptions=()):
setattr(self, self._testMethodName, debug_on(*exceptions)(getattr(self, self._testMethodName)))
def assertQueryEqual(self, first: str, second: str) -> None:
self.assertEqual(self.normalize_sql(first), self.normalize_sql(second))
def assertSequenceSubset(self, larger: Sequence, smaller: Sequence, msg: str | None = None) -> None:
"""Assert that `expected` is a subset of `actual`."""
self.assertTrue(set(smaller).issubset(set(larger)), msg=msg)
# --- Frappe Framework specific assertions
def assertDocumentEqual(self, expected: dict | BaseDocument, actual: BaseDocument) -> None:
"""Compare a (partial) expected document with actual Document."""
if isinstance(expected, BaseDocument):
expected = expected.as_dict()
for field, value in expected.items():
if isinstance(value, list):
actual_child_docs = actual.get(field)
self.assertEqual(len(value), len(actual_child_docs), msg=f"{field} length should be same")
for exp_child, actual_child in zip(value, actual_child_docs, strict=False):
self.assertDocumentEqual(exp_child, actual_child)
else:
self._compare_field(value, actual.get(field), actual, field)
def _compare_field(self, expected: Any, actual: Any, doc: BaseDocument, field: str) -> None:
msg = f"{field} should be same."
if isinstance(expected, float):
precision = doc.precision(field)
self.assertAlmostEqual(
expected, actual, places=precision, msg=f"{field} should be same to {precision} digits"
)
elif isinstance(expected, bool | int):
self.assertEqual(expected, cint(actual), msg=msg)
elif isinstance(expected, datetime_like_types) or isinstance(actual, datetime_like_types):
self.assertEqual(str(expected), str(actual), msg=msg)
else:
self.assertEqual(expected, actual, msg=msg)
def normalize_html(self, code: str) -> str:
"""Formats HTML consistently so simple string comparisons can work on them."""
from bs4 import BeautifulSoup
return BeautifulSoup(code, "html.parser").prettify(formatter=None)
@contextmanager
def set_user(self, user: str) -> AbstractContextManager[None]:
try:
old_user = frappe.session.user
frappe.set_user(user)
yield
finally:
frappe.set_user(old_user)
def normalize_sql(self, query: str) -> str:
"""Formats SQL consistently so simple string comparisons can work on them."""
import sqlparse
return sqlparse.format(query.strip(), keyword_case="upper", reindent=True, strip_comments=True)
@classmethod
def enable_safe_exec(cls) -> None:
"""Enable safe exec and disable them after test case is completed."""
from frappe.installer import update_site_config
from frappe.utils.safe_exec import SAFE_EXEC_CONFIG_KEY
cls._common_conf = os.path.join(frappe.local.sites_path, "common_site_config.json")
update_site_config(SAFE_EXEC_CONFIG_KEY, 1, validate=False, site_config_path=cls._common_conf)
cls.addClassCleanup(
lambda: update_site_config(
SAFE_EXEC_CONFIG_KEY, 0, validate=False, site_config_path=cls._common_conf
)
)
@staticmethod
@contextmanager
def patch_hooks(overridden_hooks: dict) -> AbstractContextManager[None]:
get_hooks = frappe.get_hooks
def patched_hooks(hook=None, default="_KEEP_DEFAULT_LIST", app_name=None):
if hook in overridden_hooks:
return overridden_hooks[hook]
return get_hooks(hook, default, app_name)
with patch.object(frappe, "get_hooks", patched_hooks):
yield
@contextmanager
def freeze_time(
self, time_to_freeze: Any, is_utc: bool = False, *args: Any, **kwargs: Any
) -> AbstractContextManager[None]:
from freezegun import freeze_time
if not is_utc:
# Freeze time expects UTC or tzaware objects. We have neither, so convert to UTC.
timezone = pytz.timezone(get_system_timezone())
time_to_freeze = timezone.localize(get_datetime(time_to_freeze)).astimezone(pytz.utc)
with freeze_time(time_to_freeze, *args, **kwargs):
yield
class IntegrationTestCase(UnitTestCase):
"""Integration test class for Frappe tests.
Key features:
- Automatic database setup and teardown
- Utilities for managing database connections
- Context managers for query counting and Redis call monitoring
- Lazy loading of test record dependencies
Note: If you override `setUpClass`, make sure to call `super().setUpClass()`
to maintain the functionality of this base class.
"""
TEST_SITE = "test_site"
SHOW_TRANSACTION_COMMIT_WARNINGS = False
maxDiff = 10_000 # prints long diffs but useful in CI
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
# Site initialization
cls.TEST_SITE = getattr(frappe.local, "site", None) or cls.TEST_SITE
frappe.init(cls.TEST_SITE)
cls.ADMIN_PASSWORD = frappe.get_conf(cls.TEST_SITE).admin_password
cls._primary_connection = frappe.local.db
cls._secondary_connection = None
# Create test record dependencies
cls._newly_created_test_records = []
if cls.doctype and cls.doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(cls.doctype)
for doctype in getattr(cls.module, "test_dependencies", []):
if doctype not in frappe.local.test_objects:
cls._newly_created_test_records += make_test_records(doctype)
# flush changes done so far to avoid flake
frappe.db.commit()
if cls.SHOW_TRANSACTION_COMMIT_WARNINGS:
frappe.db.before_commit.add(_commit_watcher)
# enqueue teardown actions (executed in LIFO order)
cls.addClassCleanup(_restore_thread_locals, copy.deepcopy(frappe.local.flags))
cls.addClassCleanup(_rollback_db)
@classmethod
def tearDownClass(cls) -> None:
# Add any necessary teardown code here
super().tearDownClass()
def setUp(self) -> None:
super().setUp()
# Add any per-test setup code here
def tearDown(self) -> None:
# Add any per-test teardown code here
super().tearDown()
@contextmanager
def primary_connection(self) -> AbstractContextManager[None]:
"""Switch to primary DB connection
This is used for simulating multiple users performing actions by simulating two DB connections"""
try:
current_conn = frappe.local.db
frappe.local.db = self._primary_connection
yield
finally:
frappe.local.db = current_conn
@contextmanager
def secondary_connection(self) -> AbstractContextManager[None]:
"""Switch to secondary DB connection."""
if self._secondary_connection is None:
frappe.connect() # get second connection
self._secondary_connection = frappe.local.db
try:
current_conn = frappe.local.db
frappe.local.db = self._secondary_connection
yield
finally:
frappe.local.db = current_conn
self.addCleanup(self._rollback_connections)
def _rollback_connections(self) -> None:
self._primary_connection.rollback()
self._secondary_connection.rollback()
@contextmanager
def assertQueryCount(self, count: int) -> AbstractContextManager[None]:
queries = []
def _sql_with_count(*args, **kwargs):
ret = orig_sql(*args, **kwargs)
queries.append(args[0].last_query)
return ret
try:
orig_sql = frappe.db.__class__.sql
frappe.db.__class__.sql = _sql_with_count
yield
self.assertLessEqual(len(queries), count, msg="Queries executed: \n" + "\n\n".join(queries))
finally:
frappe.db.__class__.sql = orig_sql
@contextmanager
def assertRedisCallCounts(self, count: int) -> AbstractContextManager[None]:
commands = []
def execute_command_and_count(*args, **kwargs):
ret = orig_execute(*args, **kwargs)
key_len = 2
if "H" in args[0]:
key_len = 3
commands.append((args)[:key_len])
return ret
try:
orig_execute = frappe.cache.execute_command
frappe.cache.execute_command = execute_command_and_count
yield
self.assertLessEqual(
len(commands), count, msg="commands executed: \n" + "\n".join(str(c) for c in commands)
)
finally:
frappe.cache.execute_command = orig_execute
@contextmanager
def assertRowsRead(self, count: int) -> AbstractContextManager[None]:
rows_read = 0
def _sql_with_count(*args, **kwargs):
nonlocal rows_read
ret = orig_sql(*args, **kwargs)
# count of last touched rows as per DB-API 2.0 https://peps.python.org/pep-0249/#rowcount
rows_read += cint(frappe.db._cursor.rowcount)
return ret
try:
orig_sql = frappe.db.sql
frappe.db.sql = _sql_with_count
yield
self.assertLessEqual(rows_read, count, msg="Queries read more rows than expected")
finally:
frappe.db.sql = orig_sql
@contextmanager
def switch_site(self, site: str) -> AbstractContextManager[None]:
"""Switch connection to different site.
Note: Drops current site connection completely."""
try:
old_site = frappe.local.site
frappe.init(site, force=True)
frappe.connect()
yield
finally:
frappe.init(old_site, force=True)
frappe.connect()
@staticmethod
@contextmanager
def change_settings(doctype, settings_dict=None, /, commit=False, **settings):
"""A context manager to ensure that settings are changed before running
function and restored after running it regardless of exceptions occurred.
This is useful in tests where you want to make changes in a function but
don't retain those changes.
import and use as decorator to cover full function or using `with` statement.
example:
@change_settings("Print Settings", {"send_print_as_pdf": 1})
def test_case(self):
...
@change_settings("Print Settings", send_print_as_pdf=1)
def test_case(self):
...
"""
if settings_dict is None:
settings_dict = settings
try:
settings = frappe.get_doc(doctype)
# remember setting
previous_settings = copy.deepcopy(settings_dict)
for key in previous_settings:
previous_settings[key] = getattr(settings, key)
# change setting
for key, value in settings_dict.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
# singles are cached by default, clear to avoid flake
frappe.db.value_cache[settings] = {}
if commit:
frappe.db.commit()
yield # yield control to calling function
finally:
# restore settings
settings = frappe.get_doc(doctype)
for key, value in previous_settings.items():
setattr(settings, key, value)
settings.save(ignore_permissions=True)
if commit:
frappe.db.commit()
# TODO: move to dumpster
FrappeTestCase = IntegrationTestCase
change_settings = IntegrationTestCase.change_settings
patch_hooks = UnitTestCase.patch_hooks
class MockedRequestTestCase(IntegrationTestCase):
def setUp(self):
import responses
self.responses = responses.RequestsMock()
self.responses.start()
self.addCleanup(self.responses.stop)
self.addCleanup(self.responses.reset)
return super().setUp()
def _commit_watcher():
import traceback
logger.warning("Transaction committed during tests.")
traceback.print_stack(limit=10)
def _rollback_db():
frappe.db.value_cache = {}
frappe.db.rollback()
def _restore_thread_locals(flags):
frappe.local.flags = flags
frappe.local.error_log = []
frappe.local.message_log = []
frappe.local.debug_log = []
frappe.local.conf = frappe._dict(frappe.get_site_config())
frappe.local.cache = {}
frappe.local.lang = "en"
frappe.local.preload_assets = {"style": [], "script": [], "icons": []}
if hasattr(frappe.local, "request"):
delattr(frappe.local, "request")
def timeout(seconds=30, error_message="Test timed out."):
"""Timeout decorator to ensure a test doesn't run for too long.
adapted from https://stackoverflow.com/a/2282656"""
# Support @timeout (without function call)
no_args = bool(callable(seconds))
actual_timeout = 30 if no_args else seconds
actual_error_message = "Test timed out" if no_args else error_message
def decorator(func):
def _handle_timeout(signum, frame):
raise Exception(actual_error_message)
def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(actual_timeout)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
if no_args:
return decorator(seconds)
return decorator
def check_orpahned_doctypes():
"""Check that all doctypes in DB actually exist after patch test"""
doctypes = frappe.get_all("DocType", {"custom": 0}, pluck="name")
orpahned_doctypes = []
for doctype in doctypes:
try:
get_controller(doctype)
except ImportError:
orpahned_doctypes.append(doctype)
if orpahned_doctypes:
frappe.throw(
"Following doctypes exist in DB without controller.\n {}".format("\n".join(orpahned_doctypes))
)

View file

@ -0,0 +1,26 @@
import logging
import frappe
logger = logging.Logger(__file__)
from .generators import *
def check_orpahned_doctypes():
"""Check that all doctypes in DB actually exist after patch test"""
from frappe.model.base_document import get_controller
doctypes = frappe.get_all("DocType", {"custom": 0}, pluck="name")
orpahned_doctypes = []
for doctype in doctypes:
try:
get_controller(doctype)
except ImportError:
orpahned_doctypes.append(doctype)
if orpahned_doctypes:
frappe.throw(
"Following doctypes exist in DB without controller.\n {}".format("\n".join(orpahned_doctypes))
)

View file

@ -0,0 +1,229 @@
import datetime
import logging
from functools import cache
from importlib import reload
from pathlib import Path
import frappe
from frappe.model.naming import revert_series_if_last
from frappe.modules import load_doctype_module
logger = logging.getLogger(__name__)
datetime_like_types = (datetime.datetime, datetime.date, datetime.time, datetime.timedelta)
__all__ = [
"get_modules",
"get_dependencies",
"make_test_records",
"make_test_records_for_doctype",
"make_test_objects",
]
@cache
def get_modules(doctype):
"""Get the modules for the specified doctype"""
module = frappe.db.get_value("DocType", doctype, "module")
try:
test_module = load_doctype_module(doctype, module, "test_")
if test_module:
reload(test_module)
except ImportError:
test_module = None
return module, test_module
@cache
def get_dependencies(doctype):
"""Get the dependencies for the specified doctype"""
module, test_module = get_modules(doctype)
meta = frappe.get_meta(doctype)
link_fields = meta.get_link_fields()
for df in meta.get_table_fields():
link_fields.extend(frappe.get_meta(df.options).get_link_fields())
options_list = [df.options for df in link_fields]
if hasattr(test_module, "test_dependencies"):
options_list += test_module.test_dependencies
options_list = list(set(options_list))
if hasattr(test_module, "test_ignore"):
for doctype_name in test_module.test_ignore:
if doctype_name in options_list:
options_list.remove(doctype_name)
options_list.sort()
return options_list
# Test record generation
def make_test_records(doctype, force=False, commit=False):
return list(_make_test_records(doctype, force, commit))
def make_test_records_for_doctype(doctype, force=False, commit=False):
return list(_make_test_records_for_doctype(doctype, force, commit))
def make_test_objects(doctype, test_records=None, reset=False, commit=False):
return list(_make_test_objects(doctype, test_records, reset, commit))
def _make_test_records(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
loadme = False
if doctype not in frappe.local.test_objects:
loadme = True
frappe.local.test_objects[doctype] = [] # infinite recursion guard, here
# First, create test records for dependencies
for dependency in get_dependencies(doctype):
if dependency != "[Select]" and dependency not in frappe.local.test_objects:
yield from _make_test_records(dependency, force, commit)
# Then, create test records for the doctype itself
if loadme:
# Yield the doctype and record length
yield (
doctype,
len(
# Create all test records
list(_make_test_records_for_doctype(doctype, force, commit))
),
)
def _make_test_records_for_doctype(doctype, force=False, commit=False):
"""Make test records for the specified doctype"""
test_record_log_instance = TestRecordLog()
if not force and doctype in test_record_log_instance.get():
return
module, test_module = get_modules(doctype)
if hasattr(test_module, "_make_test_records"):
yield from test_module._make_test_records()
elif hasattr(test_module, "test_records"):
yield from _make_test_objects(doctype, test_module.test_records, force, commit=commit)
else:
test_records = frappe.get_test_records(doctype)
if test_records:
yield from _make_test_objects(doctype, test_records, force, commit=commit)
else:
print_mandatory_fields(doctype)
test_record_log_instance.add(doctype)
def _make_test_objects(doctype, test_records=None, reset=False, commit=False):
"""Generator function to make test objects"""
def revert_naming(d):
if getattr(d, "naming_series", None):
revert_series_if_last(d.naming_series, d.name)
if test_records is None:
test_records = frappe.get_test_records(doctype)
for doc in test_records:
if not reset:
frappe.db.savepoint("creating_test_record")
if not doc.get("doctype"):
doc["doctype"] = doctype
d = frappe.copy_doc(doc)
if d.meta.get_field("naming_series"):
if not d.naming_series:
d.naming_series = "_T-" + d.doctype + "-"
if doc.get("name"):
d.name = doc.get("name")
else:
d.set_new_name()
if frappe.db.exists(d.doctype, d.name) and not reset:
frappe.db.rollback(save_point="creating_test_record")
# do not create test records, if already exists
continue
# submit if docstatus is set to 1 for test record
docstatus = d.docstatus
d.docstatus = 0
try:
d.run_method("before_test_insert")
d.insert(ignore_if_duplicate=True)
if docstatus == 1:
d.submit()
except frappe.NameError:
revert_naming(d)
except Exception as e:
if (
d.flags.ignore_these_exceptions_in_test
and e.__class__ in d.flags.ignore_these_exceptions_in_test
):
revert_naming(d)
else:
logger.debug(f"Error in making test record for {d.doctype} {d.name}")
raise
if commit:
frappe.db.commit()
frappe.local.test_objects[doctype] += d.name
yield d.name
def print_mandatory_fields(doctype):
"""Print mandatory fields for the specified doctype"""
meta = frappe.get_meta(doctype)
logger.warning(f"Please setup make_test_records for: {doctype}")
logger.warning("-" * 60)
logger.warning(f"Autoname: {meta.autoname or ''}")
logger.warning("Mandatory Fields:")
for d in meta.get("fields", {"reqd": 1}):
logger.warning(f" - {d.parent}:{d.fieldname} | {d.fieldtype} | {d.options or ''}")
logger.warning("")
class TestRecordLog:
def __init__(self):
self.log_file = Path(frappe.get_site_path(".test_log"))
self._log = None
def get(self):
if self._log is None:
self._log = self._read_log()
return self._log
def add(self, doctype):
log = self.get()
if doctype not in log:
log.append(doctype)
self._write_log(log)
def _read_log(self):
if self.log_file.exists():
with self.log_file.open() as f:
return f.read().splitlines()
return []
def _write_log(self, log):
with self.log_file.open("w") as f:
f.write("\n".join(l for l in log if l is not None))