diff --git a/frappe/database/database.py b/frappe/database/database.py index b86c38fe3a..e01780aabb 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -11,7 +11,7 @@ import warnings from collections.abc import Iterable, Sequence from contextlib import contextmanager, suppress from time import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder @@ -28,6 +28,7 @@ from frappe.database.utils import ( is_query_type, ) from frappe.exceptions import DoesNotExistError, ImplicitCommitError +from frappe.model.document import DocRef from frappe.monitor import get_trace_id from frappe.query_builder.functions import Count from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, getdate, now, sbool @@ -52,6 +53,9 @@ be ignored. Commit/Rollback from here WILL CAUSE very hard to debug problems wit concurrent data update bugs.""" +Stringable: TypeAlias = str | DocRef + + class Database: """ Open a database connection with the given parmeters, if use_default is True, use the @@ -474,21 +478,21 @@ class Database: def get_value( self, - doctype, - filters=None, - fieldname="name", - ignore=None, - as_dict=False, - debug=False, - order_by=DefaultOrderBy, - cache=False, - for_update=False, + doctype: str, + filters: Stringable | dict | list | None = None, + fieldname: str | list[str] = "name", + ignore: bool = False, + as_dict: bool = False, + debug: bool = False, + order_by: str = DefaultOrderBy, + cache: bool = False, + for_update: bool = False, *, - run=True, - pluck=False, - distinct=False, - skip_locked=False, - wait=True, + run: bool = True, + pluck: bool = False, + distinct: bool = False, + skip_locked: bool = False, + wait: bool = True, ): """Return a document property or list of properties. @@ -553,23 +557,23 @@ class Database: def get_values( self, - doctype, - filters=None, - fieldname="name", - ignore=None, - as_dict=False, - debug=False, - order_by=DefaultOrderBy, - update=None, - cache=False, - for_update=False, + doctype: str, + filters: Stringable | dict | list | None = None, + fieldname: str | list[str] = "name", + ignore: bool = False, + as_dict: bool = False, + debug: bool = False, + order_by: str = DefaultOrderBy, + update: dict | None = None, + cache: bool = False, + for_update: bool = False, *, - run=True, - pluck=False, - distinct=False, - limit=None, - skip_locked=False, - wait=True, + run: bool = True, + pluck: bool = False, + distinct: bool = False, + limit: int | None = None, + skip_locked: bool = False, + wait: bool = True, ): """Return multiple document properties. @@ -591,8 +595,11 @@ class Database: user = frappe.db.get_values("User", "test@example.com", "*")[0] """ out = None - if cache and isinstance(filters, str) and (doctype, filters, fieldname) in self.value_cache: - return self.value_cache[(doctype, filters, fieldname)] + cache_key = None + if cache and isinstance(filters, Stringable): + cache_key = (doctype, str(filters), fieldname) + if cache_key in self.value_cache: + return self.value_cache[cache_key] if distinct: order_by = None @@ -660,8 +667,8 @@ class Database: fields, filters, doctype, as_dict, debug, update, run=run, pluck=pluck, distinct=distinct ) - if cache and isinstance(filters, str): - self.value_cache[(doctype, filters, fieldname)] = out + if cache and cache_key: + self.value_cache[cache_key] = out return out @@ -820,7 +827,7 @@ class Database: if doctype in self.value_cache: del self.value_cache[doctype] - def get_single_value(self, doctype, fieldname, cache=True): + def get_single_value(self, doctype: str, fieldname: str, cache: bool = True): """Get property of Single DocType. Cache locally by default :param doctype: DocType of the single object whose value is requested @@ -933,9 +940,9 @@ class Database: def set_value( self, - dt, - dn, - field, + dt: str, + dn: Stringable | dict, + field: str, val=None, modified=None, modified_by=None, @@ -990,8 +997,8 @@ class Database: validate_filters=True, ) - if isinstance(dn, str): - frappe.clear_document_cache(dt, dn) + if isinstance(dn, Stringable): + frappe.clear_document_cache(dt, str(dn)) else: # No way to guess which documents are modified, clear all of them frappe.clear_document_cache(dt) diff --git a/frappe/database/query.py b/frappe/database/query.py index 2a28375428..bd6eef1b0f 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,7 +1,7 @@ import re from ast import literal_eval from types import BuiltinFunctionType -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias import sqlparse from pypika.queries import QueryBuilder, Table @@ -11,6 +11,7 @@ from frappe import _ from frappe.database.operator_map import OPERATOR_MAP from frappe.database.schema import SPECIAL_CHAR_PATTERN from frappe.database.utils import DefaultOrderBy, get_doctype_name +from frappe.model.document import DocRef from frappe.query_builder import Criterion, Field, Order, functions from frappe.query_builder.functions import Function, SqlFunctions from frappe.query_builder.utils import PseudoColumnMapper @@ -29,13 +30,15 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") # to allow table names like __Auth TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII) +FilterValue: TypeAlias = DocRef | str | int + class Engine: def get_query( self, table: str | Table, fields: list | tuple | None = None, - filters: dict[str, str | int] | str | int | list[list | str | int] | None = None, + filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None, order_by: str | None = None, group_by: str | None = None, limit: int | None = None, @@ -113,12 +116,12 @@ class Engine: def apply_filters( self, - filters: dict[str, str | int] | str | int | list[list | str | int] | None = None, + filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None, ): if filters is None: return - if isinstance(filters, str | int): + if isinstance(filters, FilterValue): filters = {"name": str(filters)} if isinstance(filters, Criterion): @@ -128,11 +131,11 @@ class Engine: self.apply_dict_filters(filters) elif isinstance(filters, list | tuple): - if all(isinstance(d, str | int) for d in filters) and len(filters) > 0: + if all(isinstance(d, FilterValue) for d in filters) and len(filters) > 0: self.apply_dict_filters({"name": ("in", filters)}) else: for filter in filters: - if isinstance(filter, str | int | Criterion | dict): + if isinstance(filter, FilterValue | Criterion | dict): self.apply_filters(filter) elif isinstance(filter, list | tuple): self.apply_list_filters(filter) @@ -148,7 +151,7 @@ class Engine: doctype, field, operator, value = filter self._apply_filter(field, value, operator, doctype) - def apply_dict_filters(self, filters: dict[str, str | int | list]): + def apply_dict_filters(self, filters: dict[str, FilterValue | list]): for field, value in filters.items(): operator = "=" if isinstance(value, list | tuple): @@ -157,7 +160,11 @@ class Engine: self._apply_filter(field, value, operator) def _apply_filter( - self, field: str, value: str | int | list | None, operator: str = "=", doctype: str | None = None + self, + field: str, + value: FilterValue | list | None, + operator: str = "=", + doctype: str | None = None, ): _field = field _value = value @@ -188,6 +195,9 @@ class Engine: if isinstance(_value, bool): _value = int(_value) + if isinstance(_value, DocRef): + _value = str(_value) + elif not _value and isinstance(_value, list | tuple): _value = ("",) diff --git a/frappe/model/base_document.py b/frappe/model/base_document.py index f949fda568..1ff268dd16 100644 --- a/frappe/model/base_document.py +++ b/frappe/model/base_document.py @@ -365,6 +365,8 @@ class BaseDocument: def get_valid_dict( self, sanitize=True, convert_dates_to_str=False, ignore_nulls=False, ignore_virtual=False ) -> _dict: + from frappe.model.document import DocRef + d = _dict() field_values = self.__dict__ @@ -431,6 +433,9 @@ class BaseDocument: else: value = get_not_null_defaults(df.fieldtype) + if isinstance(value, DocRef): + value = str(value) + d[fieldname] = value return d diff --git a/frappe/model/document.py b/frappe/model/document.py index 0e582a8616..682ec02134 100644 --- a/frappe/model/document.py +++ b/frappe/model/document.py @@ -38,6 +38,19 @@ DOCUMENT_LOCK_EXPIRTY = 12 * 60 * 60 # All locks expire in 12 hours automatical DOCUMENT_LOCK_SOFT_EXPIRY = 60 * 60 # Let users force-unlock after 60 minutes +class DocRef: + """A lightweight reference to a document, containing just the doctype and name.""" + + def __init__(self, doctype: str, name: str): + self.doctype = doctype + self.name = name + + def __str__(self): + # ! Used in frappe's query engine in frappe/database/query.py + # ! Keep it stable + return self.name + + @simple_singledispatch def get_doc(*args, **kwargs) -> "Document": """Return a `frappe.model.Document` object. @@ -77,6 +90,11 @@ def _basedoc(doc: BaseDocument, *args, **kwargs) -> "Document": return doc +@get_doc.register(DocRef) +def _docref(doc_ref: DocRef, **kwargs) -> "Document": + return get_doc(doc_ref.doctype, doc_ref.name, **kwargs) + + @get_doc.register(str) def get_doc_str(doctype: str, name: str | None = None, **kwargs) -> "Document": # if no name: it's a single @@ -157,7 +175,7 @@ def read_only_document(context=None): del frappe.local.read_only_depth -class Document(BaseDocument): +class Document(BaseDocument, DocRef): """All controllers inherit from `Document`.""" doctype: DF.Data @@ -172,7 +190,7 @@ class Document(BaseDocument): def __init__(self, *args, **kwargs): """Constructor. - :param arg1: DocType name as string or document **dict** + :param arg1: DocType name as string, document **dict**, or DocRef object :param arg2: Document name, if `arg1` is DocType name. If DocType name and document name are passed, the object will load @@ -214,6 +232,10 @@ class Document(BaseDocument): name = doctype if not args else args[0] self._init_known_doc(doctype, name, **kwargs) + @_init_dispatch.register(DocRef) + def _init_docref(self, doc_ref, **kwargs): + self._init_known_doc(doc_ref.doctype, doc_ref.name, **kwargs) + @_init_dispatch.register(dict) def _init_dict(self, arg_dict, **kwargs): # discard any further keyword args @@ -1776,12 +1798,6 @@ class Document(BaseDocument): return f"<{doctype}: {name}{docstatus}{parent}>" - def __str__(self): - name = self.name or "unsaved" - doctype = self.__class__.__name__ - - return f"{doctype}({name})" - def execute_action(__doctype, __name, __action, **kwargs): """Execute an action on a document (called by background worker)""" diff --git a/frappe/model/meta.py b/frappe/model/meta.py index c79854d3cc..78d30f1c0e 100644 --- a/frappe/model/meta.py +++ b/frappe/model/meta.py @@ -38,7 +38,7 @@ from frappe.model.base_document import ( TABLE_DOCTYPES_FOR_DOCTYPE, BaseDocument, ) -from frappe.model.document import Document +from frappe.model.document import DocRef, Document from frappe.model.workflow import get_workflow_name from frappe.modules import load_doctype_module from frappe.utils import cast, cint, cstr @@ -58,11 +58,11 @@ DEFAULT_FIELD_LABELS = { } -def get_meta(doctype: str | Document, cached=True) -> "_Meta": +def get_meta(doctype: str | DocRef | Document, cached=True) -> "_Meta": """Get metadata for a doctype. Args: - doctype: The doctype as a string or Document object. + doctype: The doctype as a string, DocRef, or Document object. cached: Whether to use cached metadata (default: True). Returns: @@ -132,6 +132,11 @@ class Meta(Document): super().__init__("DocType", doctype) self.process() + @__init__.register(DocRef) + def _(self, doc_ref): + super().__init__("DocType", doc_ref.doctype) + self.process() + @__init__.register(Document) def _(self, doc): super().__init__(doc.as_dict()) diff --git a/frappe/tests/test_doc_ref.py b/frappe/tests/test_doc_ref.py new file mode 100644 index 0000000000..47deb87b00 --- /dev/null +++ b/frappe/tests/test_doc_ref.py @@ -0,0 +1,72 @@ +import frappe +from frappe.model.document import DocRef, Document, get_doc +from frappe.tests import IntegrationTestCase + +EXTRA_TEST_RECORD_DEPENDENCIES = ["User"] + + +class TestDocRef(IntegrationTestCase): + def test_doc_ref_get_doc(self): + # Test using DocRef with get_doc + doc_ref = DocRef("User", "test@example.com") + user = get_doc(doc_ref) + + # Assert that user is an instance of both Document and DocRef + self.assertIsInstance(user, Document) + self.assertIsInstance(user, DocRef) + + # Check more attributes + self.assertEqual(user.doctype, "User") + self.assertEqual(user.name, "test@example.com") + self.assertEqual(user.email, "test@example.com") + self.assertEqual(user.first_name, "_Test") + + def test_doc_ref_in_query(self): + # Test using DocRef in a database query + user = frappe.get_doc("User", "test@example.com") + + # Assert that user is an instance of both Document and DocRef + self.assertIsInstance(user, Document) + self.assertIsInstance(user, DocRef) + + # Create a test document that references the user + test_doc = frappe.get_doc( + { + "doctype": "ToDo", + "description": "Test ToDo", + "reference_type": "User", + "reference_name": user, # This should work with DocRef + } + ).insert() + + # Getter using the DocRef + result = frappe.db.get_value("ToDo", {"reference_name": user}, ["name", "description"]) + self.assertEqual(result[0], test_doc.name) + self.assertEqual(result[1], "Test ToDo") + # Setter using Document as DocRef + frappe.db.set_value("ToDo", test_doc, "description", "Revised Test ToDo") + test_doc.reload() + self.assertEqual(test_doc.description, "Revised Test ToDo") + + def test_get_meta_with_doc_ref(self): + # Test get_meta with DocRef + doc_ref = DocRef("User", "test@example.com") + meta = frappe.get_meta(doc_ref) + + # Check more attributes of the meta + self.assertEqual(meta.name, "User") + self.assertEqual(meta.module, "Core") + self.assertTrue("email" in [f.fieldname for f in meta.fields]) + self.assertTrue("first_name" in [f.fieldname for f in meta.fields]) + self.assertTrue("last_name" in [f.fieldname for f in meta.fields]) + + def test_doc_ref_str_representation(self): + # Test the string representation of DocRef + doc_ref = DocRef("User", "test@example.com") + self.assertEqual(str(doc_ref), "test@example.com") + + def test_doc_ref_attributes(self): + # Test DocRef attributes + doc_ref = DocRef("User", "test@example.com") + self.assertEqual(doc_ref.doctype, "User") + self.assertEqual(doc_ref.name, "test@example.com")