feat: docref identifier / proxy (#27973)

* feat: add DocRef

* feat: Add comprehensive test cases for DocRef functionality

* chore(db): add field type hints

* fix: ensure document stringer fulfills the DocRef contract
This commit is contained in:
David Arnold 2024-10-19 06:10:26 +02:00 committed by GitHub
parent 232f45cfd5
commit 7348572af8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 175 additions and 60 deletions

View file

@ -11,7 +11,7 @@ import warnings
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from time import time from time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, TypeAlias
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
@ -28,6 +28,7 @@ from frappe.database.utils import (
is_query_type, is_query_type,
) )
from frappe.exceptions import DoesNotExistError, ImplicitCommitError from frappe.exceptions import DoesNotExistError, ImplicitCommitError
from frappe.model.document import DocRef
from frappe.monitor import get_trace_id from frappe.monitor import get_trace_id
from frappe.query_builder.functions import Count from frappe.query_builder.functions import Count
from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, getdate, now, sbool from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, getdate, now, sbool
@ -52,6 +53,9 @@ be ignored. Commit/Rollback from here WILL CAUSE very hard to debug problems wit
concurrent data update bugs.""" concurrent data update bugs."""
Stringable: TypeAlias = str | DocRef
class Database: class Database:
""" """
Open a database connection with the given parmeters, if use_default is True, use the Open a database connection with the given parmeters, if use_default is True, use the
@ -474,21 +478,21 @@ class Database:
def get_value( def get_value(
self, self,
doctype, doctype: str,
filters=None, filters: Stringable | dict | list | None = None,
fieldname="name", fieldname: str | list[str] = "name",
ignore=None, ignore: bool = False,
as_dict=False, as_dict: bool = False,
debug=False, debug: bool = False,
order_by=DefaultOrderBy, order_by: str = DefaultOrderBy,
cache=False, cache: bool = False,
for_update=False, for_update: bool = False,
*, *,
run=True, run: bool = True,
pluck=False, pluck: bool = False,
distinct=False, distinct: bool = False,
skip_locked=False, skip_locked: bool = False,
wait=True, wait: bool = True,
): ):
"""Return a document property or list of properties. """Return a document property or list of properties.
@ -553,23 +557,23 @@ class Database:
def get_values( def get_values(
self, self,
doctype, doctype: str,
filters=None, filters: Stringable | dict | list | None = None,
fieldname="name", fieldname: str | list[str] = "name",
ignore=None, ignore: bool = False,
as_dict=False, as_dict: bool = False,
debug=False, debug: bool = False,
order_by=DefaultOrderBy, order_by: str = DefaultOrderBy,
update=None, update: dict | None = None,
cache=False, cache: bool = False,
for_update=False, for_update: bool = False,
*, *,
run=True, run: bool = True,
pluck=False, pluck: bool = False,
distinct=False, distinct: bool = False,
limit=None, limit: int | None = None,
skip_locked=False, skip_locked: bool = False,
wait=True, wait: bool = True,
): ):
"""Return multiple document properties. """Return multiple document properties.
@ -591,8 +595,11 @@ class Database:
user = frappe.db.get_values("User", "test@example.com", "*")[0] user = frappe.db.get_values("User", "test@example.com", "*")[0]
""" """
out = None out = None
if cache and isinstance(filters, str) and (doctype, filters, fieldname) in self.value_cache: cache_key = None
return self.value_cache[(doctype, filters, fieldname)] if cache and isinstance(filters, Stringable):
cache_key = (doctype, str(filters), fieldname)
if cache_key in self.value_cache:
return self.value_cache[cache_key]
if distinct: if distinct:
order_by = None order_by = None
@ -660,8 +667,8 @@ class Database:
fields, filters, doctype, as_dict, debug, update, run=run, pluck=pluck, distinct=distinct fields, filters, doctype, as_dict, debug, update, run=run, pluck=pluck, distinct=distinct
) )
if cache and isinstance(filters, str): if cache and cache_key:
self.value_cache[(doctype, filters, fieldname)] = out self.value_cache[cache_key] = out
return out return out
@ -820,7 +827,7 @@ class Database:
if doctype in self.value_cache: if doctype in self.value_cache:
del self.value_cache[doctype] del self.value_cache[doctype]
def get_single_value(self, doctype, fieldname, cache=True): def get_single_value(self, doctype: str, fieldname: str, cache: bool = True):
"""Get property of Single DocType. Cache locally by default """Get property of Single DocType. Cache locally by default
:param doctype: DocType of the single object whose value is requested :param doctype: DocType of the single object whose value is requested
@ -933,9 +940,9 @@ class Database:
def set_value( def set_value(
self, self,
dt, dt: str,
dn, dn: Stringable | dict,
field, field: str,
val=None, val=None,
modified=None, modified=None,
modified_by=None, modified_by=None,
@ -990,8 +997,8 @@ class Database:
validate_filters=True, validate_filters=True,
) )
if isinstance(dn, str): if isinstance(dn, Stringable):
frappe.clear_document_cache(dt, dn) frappe.clear_document_cache(dt, str(dn))
else: else:
# No way to guess which documents are modified, clear all of them # No way to guess which documents are modified, clear all of them
frappe.clear_document_cache(dt) frappe.clear_document_cache(dt)

View file

@ -1,7 +1,7 @@
import re import re
from ast import literal_eval from ast import literal_eval
from types import BuiltinFunctionType from types import BuiltinFunctionType
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, TypeAlias
import sqlparse import sqlparse
from pypika.queries import QueryBuilder, Table from pypika.queries import QueryBuilder, Table
@ -11,6 +11,7 @@ from frappe import _
from frappe.database.operator_map import OPERATOR_MAP from frappe.database.operator_map import OPERATOR_MAP
from frappe.database.schema import SPECIAL_CHAR_PATTERN from frappe.database.schema import SPECIAL_CHAR_PATTERN
from frappe.database.utils import DefaultOrderBy, get_doctype_name from frappe.database.utils import DefaultOrderBy, get_doctype_name
from frappe.model.document import DocRef
from frappe.query_builder import Criterion, Field, Order, functions from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumnMapper from frappe.query_builder.utils import PseudoColumnMapper
@ -29,13 +30,15 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
# to allow table names like __Auth # to allow table names like __Auth
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII) TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
FilterValue: TypeAlias = DocRef | str | int
class Engine: class Engine:
def get_query( def get_query(
self, self,
table: str | Table, table: str | Table,
fields: list | tuple | None = None, fields: list | tuple | None = None,
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None, filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None,
order_by: str | None = None, order_by: str | None = None,
group_by: str | None = None, group_by: str | None = None,
limit: int | None = None, limit: int | None = None,
@ -113,12 +116,12 @@ class Engine:
def apply_filters( def apply_filters(
self, self,
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None, filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None,
): ):
if filters is None: if filters is None:
return return
if isinstance(filters, str | int): if isinstance(filters, FilterValue):
filters = {"name": str(filters)} filters = {"name": str(filters)}
if isinstance(filters, Criterion): if isinstance(filters, Criterion):
@ -128,11 +131,11 @@ class Engine:
self.apply_dict_filters(filters) self.apply_dict_filters(filters)
elif isinstance(filters, list | tuple): elif isinstance(filters, list | tuple):
if all(isinstance(d, str | int) for d in filters) and len(filters) > 0: if all(isinstance(d, FilterValue) for d in filters) and len(filters) > 0:
self.apply_dict_filters({"name": ("in", filters)}) self.apply_dict_filters({"name": ("in", filters)})
else: else:
for filter in filters: for filter in filters:
if isinstance(filter, str | int | Criterion | dict): if isinstance(filter, FilterValue | Criterion | dict):
self.apply_filters(filter) self.apply_filters(filter)
elif isinstance(filter, list | tuple): elif isinstance(filter, list | tuple):
self.apply_list_filters(filter) self.apply_list_filters(filter)
@ -148,7 +151,7 @@ class Engine:
doctype, field, operator, value = filter doctype, field, operator, value = filter
self._apply_filter(field, value, operator, doctype) self._apply_filter(field, value, operator, doctype)
def apply_dict_filters(self, filters: dict[str, str | int | list]): def apply_dict_filters(self, filters: dict[str, FilterValue | list]):
for field, value in filters.items(): for field, value in filters.items():
operator = "=" operator = "="
if isinstance(value, list | tuple): if isinstance(value, list | tuple):
@ -157,7 +160,11 @@ class Engine:
self._apply_filter(field, value, operator) self._apply_filter(field, value, operator)
def _apply_filter( def _apply_filter(
self, field: str, value: str | int | list | None, operator: str = "=", doctype: str | None = None self,
field: str,
value: FilterValue | list | None,
operator: str = "=",
doctype: str | None = None,
): ):
_field = field _field = field
_value = value _value = value
@ -188,6 +195,9 @@ class Engine:
if isinstance(_value, bool): if isinstance(_value, bool):
_value = int(_value) _value = int(_value)
if isinstance(_value, DocRef):
_value = str(_value)
elif not _value and isinstance(_value, list | tuple): elif not _value and isinstance(_value, list | tuple):
_value = ("",) _value = ("",)

View file

@ -365,6 +365,8 @@ class BaseDocument:
def get_valid_dict( def get_valid_dict(
self, sanitize=True, convert_dates_to_str=False, ignore_nulls=False, ignore_virtual=False self, sanitize=True, convert_dates_to_str=False, ignore_nulls=False, ignore_virtual=False
) -> _dict: ) -> _dict:
from frappe.model.document import DocRef
d = _dict() d = _dict()
field_values = self.__dict__ field_values = self.__dict__
@ -431,6 +433,9 @@ class BaseDocument:
else: else:
value = get_not_null_defaults(df.fieldtype) value = get_not_null_defaults(df.fieldtype)
if isinstance(value, DocRef):
value = str(value)
d[fieldname] = value d[fieldname] = value
return d return d

View file

@ -38,6 +38,19 @@ DOCUMENT_LOCK_EXPIRTY = 12 * 60 * 60 # All locks expire in 12 hours automatical
DOCUMENT_LOCK_SOFT_EXPIRY = 60 * 60 # Let users force-unlock after 60 minutes DOCUMENT_LOCK_SOFT_EXPIRY = 60 * 60 # Let users force-unlock after 60 minutes
class DocRef:
"""A lightweight reference to a document, containing just the doctype and name."""
def __init__(self, doctype: str, name: str):
self.doctype = doctype
self.name = name
def __str__(self):
# ! Used in frappe's query engine in frappe/database/query.py
# ! Keep it stable
return self.name
@simple_singledispatch @simple_singledispatch
def get_doc(*args, **kwargs) -> "Document": def get_doc(*args, **kwargs) -> "Document":
"""Return a `frappe.model.Document` object. """Return a `frappe.model.Document` object.
@ -77,6 +90,11 @@ def _basedoc(doc: BaseDocument, *args, **kwargs) -> "Document":
return doc return doc
@get_doc.register(DocRef)
def _docref(doc_ref: DocRef, **kwargs) -> "Document":
return get_doc(doc_ref.doctype, doc_ref.name, **kwargs)
@get_doc.register(str) @get_doc.register(str)
def get_doc_str(doctype: str, name: str | None = None, **kwargs) -> "Document": def get_doc_str(doctype: str, name: str | None = None, **kwargs) -> "Document":
# if no name: it's a single # if no name: it's a single
@ -157,7 +175,7 @@ def read_only_document(context=None):
del frappe.local.read_only_depth del frappe.local.read_only_depth
class Document(BaseDocument): class Document(BaseDocument, DocRef):
"""All controllers inherit from `Document`.""" """All controllers inherit from `Document`."""
doctype: DF.Data doctype: DF.Data
@ -172,7 +190,7 @@ class Document(BaseDocument):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Constructor. """Constructor.
:param arg1: DocType name as string or document **dict** :param arg1: DocType name as string, document **dict**, or DocRef object
:param arg2: Document name, if `arg1` is DocType name. :param arg2: Document name, if `arg1` is DocType name.
If DocType name and document name are passed, the object will load If DocType name and document name are passed, the object will load
@ -214,6 +232,10 @@ class Document(BaseDocument):
name = doctype if not args else args[0] name = doctype if not args else args[0]
self._init_known_doc(doctype, name, **kwargs) self._init_known_doc(doctype, name, **kwargs)
@_init_dispatch.register(DocRef)
def _init_docref(self, doc_ref, **kwargs):
self._init_known_doc(doc_ref.doctype, doc_ref.name, **kwargs)
@_init_dispatch.register(dict) @_init_dispatch.register(dict)
def _init_dict(self, arg_dict, **kwargs): def _init_dict(self, arg_dict, **kwargs):
# discard any further keyword args # discard any further keyword args
@ -1776,12 +1798,6 @@ class Document(BaseDocument):
return f"<{doctype}: {name}{docstatus}{parent}>" return f"<{doctype}: {name}{docstatus}{parent}>"
def __str__(self):
name = self.name or "unsaved"
doctype = self.__class__.__name__
return f"{doctype}({name})"
def execute_action(__doctype, __name, __action, **kwargs): def execute_action(__doctype, __name, __action, **kwargs):
"""Execute an action on a document (called by background worker)""" """Execute an action on a document (called by background worker)"""

View file

@ -38,7 +38,7 @@ from frappe.model.base_document import (
TABLE_DOCTYPES_FOR_DOCTYPE, TABLE_DOCTYPES_FOR_DOCTYPE,
BaseDocument, BaseDocument,
) )
from frappe.model.document import Document from frappe.model.document import DocRef, Document
from frappe.model.workflow import get_workflow_name from frappe.model.workflow import get_workflow_name
from frappe.modules import load_doctype_module from frappe.modules import load_doctype_module
from frappe.utils import cast, cint, cstr from frappe.utils import cast, cint, cstr
@ -58,11 +58,11 @@ DEFAULT_FIELD_LABELS = {
} }
def get_meta(doctype: str | Document, cached=True) -> "_Meta": def get_meta(doctype: str | DocRef | Document, cached=True) -> "_Meta":
"""Get metadata for a doctype. """Get metadata for a doctype.
Args: Args:
doctype: The doctype as a string or Document object. doctype: The doctype as a string, DocRef, or Document object.
cached: Whether to use cached metadata (default: True). cached: Whether to use cached metadata (default: True).
Returns: Returns:
@ -132,6 +132,11 @@ class Meta(Document):
super().__init__("DocType", doctype) super().__init__("DocType", doctype)
self.process() self.process()
@__init__.register(DocRef)
def _(self, doc_ref):
super().__init__("DocType", doc_ref.doctype)
self.process()
@__init__.register(Document) @__init__.register(Document)
def _(self, doc): def _(self, doc):
super().__init__(doc.as_dict()) super().__init__(doc.as_dict())

View file

@ -0,0 +1,72 @@
import frappe
from frappe.model.document import DocRef, Document, get_doc
from frappe.tests import IntegrationTestCase
EXTRA_TEST_RECORD_DEPENDENCIES = ["User"]
class TestDocRef(IntegrationTestCase):
def test_doc_ref_get_doc(self):
# Test using DocRef with get_doc
doc_ref = DocRef("User", "test@example.com")
user = get_doc(doc_ref)
# Assert that user is an instance of both Document and DocRef
self.assertIsInstance(user, Document)
self.assertIsInstance(user, DocRef)
# Check more attributes
self.assertEqual(user.doctype, "User")
self.assertEqual(user.name, "test@example.com")
self.assertEqual(user.email, "test@example.com")
self.assertEqual(user.first_name, "_Test")
def test_doc_ref_in_query(self):
# Test using DocRef in a database query
user = frappe.get_doc("User", "test@example.com")
# Assert that user is an instance of both Document and DocRef
self.assertIsInstance(user, Document)
self.assertIsInstance(user, DocRef)
# Create a test document that references the user
test_doc = frappe.get_doc(
{
"doctype": "ToDo",
"description": "Test ToDo",
"reference_type": "User",
"reference_name": user, # This should work with DocRef
}
).insert()
# Getter using the DocRef
result = frappe.db.get_value("ToDo", {"reference_name": user}, ["name", "description"])
self.assertEqual(result[0], test_doc.name)
self.assertEqual(result[1], "Test ToDo")
# Setter using Document as DocRef
frappe.db.set_value("ToDo", test_doc, "description", "Revised Test ToDo")
test_doc.reload()
self.assertEqual(test_doc.description, "Revised Test ToDo")
def test_get_meta_with_doc_ref(self):
# Test get_meta with DocRef
doc_ref = DocRef("User", "test@example.com")
meta = frappe.get_meta(doc_ref)
# Check more attributes of the meta
self.assertEqual(meta.name, "User")
self.assertEqual(meta.module, "Core")
self.assertTrue("email" in [f.fieldname for f in meta.fields])
self.assertTrue("first_name" in [f.fieldname for f in meta.fields])
self.assertTrue("last_name" in [f.fieldname for f in meta.fields])
def test_doc_ref_str_representation(self):
# Test the string representation of DocRef
doc_ref = DocRef("User", "test@example.com")
self.assertEqual(str(doc_ref), "test@example.com")
def test_doc_ref_attributes(self):
# Test DocRef attributes
doc_ref = DocRef("User", "test@example.com")
self.assertEqual(doc_ref.doctype, "User")
self.assertEqual(doc_ref.name, "test@example.com")