diff --git a/frappe/database/database.py b/frappe/database/database.py index e01780aabb..6072fe0f03 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -11,7 +11,7 @@ import warnings from collections.abc import Iterable, Sequence from contextlib import contextmanager, suppress from time import time -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder @@ -22,13 +22,14 @@ from frappe.database.utils import ( DefaultOrderBy, EmptyQueryValues, FallBackDateTimeStr, + FilterValue, LazyMogrify, Query, QueryValues, + convert_to_value, is_query_type, ) from frappe.exceptions import DoesNotExistError, ImplicitCommitError -from frappe.model.document import DocRef from frappe.monitor import get_trace_id from frappe.query_builder.functions import Count from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, getdate, now, sbool @@ -53,9 +54,6 @@ be ignored. Commit/Rollback from here WILL CAUSE very hard to debug problems wit concurrent data update bugs.""" -Stringable: TypeAlias = str | DocRef - - class Database: """ Open a database connection with the given parmeters, if use_default is True, use the @@ -479,7 +477,7 @@ class Database: def get_value( self, doctype: str, - filters: Stringable | dict | list | None = None, + filters: FilterValue | dict | list | None = None, fieldname: str | list[str] = "name", ignore: bool = False, as_dict: bool = False, @@ -558,7 +556,7 @@ class Database: def get_values( self, doctype: str, - filters: Stringable | dict | list | None = None, + filters: FilterValue | dict | list | None = None, fieldname: str | list[str] = "name", ignore: bool = False, as_dict: bool = False, @@ -596,8 +594,8 @@ class Database: """ out = None cache_key = None - if cache and isinstance(filters, Stringable): - cache_key = (doctype, str(filters), fieldname) + if cache and isinstance(filters, FilterValue): + cache_key = (doctype, convert_to_value(filters), fieldname) if cache_key in self.value_cache: return self.value_cache[cache_key] @@ -941,7 +939,7 @@ class Database: def set_value( self, dt: str, - dn: Stringable | dict, + dn: FilterValue | dict, field: str, val=None, modified=None, @@ -997,8 +995,8 @@ class Database: validate_filters=True, ) - if isinstance(dn, Stringable): - frappe.clear_document_cache(dt, str(dn)) + if isinstance(dn, FilterValue): + frappe.clear_document_cache(dt, convert_to_value(dn)) else: # No way to guess which documents are modified, clear all of them frappe.clear_document_cache(dt) diff --git a/frappe/database/query.py b/frappe/database/query.py index bd6eef1b0f..5248c9d5c1 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -10,8 +10,7 @@ import frappe from frappe import _ from frappe.database.operator_map import OPERATOR_MAP from frappe.database.schema import SPECIAL_CHAR_PATTERN -from frappe.database.utils import DefaultOrderBy, get_doctype_name -from frappe.model.document import DocRef +from frappe.database.utils import DefaultOrderBy, FilterValue, convert_to_value, get_doctype_name from frappe.query_builder import Criterion, Field, Order, functions from frappe.query_builder.functions import Function, SqlFunctions from frappe.query_builder.utils import PseudoColumnMapper @@ -30,8 +29,6 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") # to allow table names like __Auth TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII) -FilterValue: TypeAlias = DocRef | str | int - class Engine: def get_query( @@ -122,7 +119,7 @@ class Engine: return if isinstance(filters, FilterValue): - filters = {"name": str(filters)} + filters = {"name": convert_to_value(filters)} if isinstance(filters, Criterion): self.query = self.query.where(filters) @@ -132,7 +129,7 @@ class Engine: elif isinstance(filters, list | tuple): if all(isinstance(d, FilterValue) for d in filters) and len(filters) > 0: - self.apply_dict_filters({"name": ("in", filters)}) + self.apply_dict_filters({"name": ("in", tuple(convert_to_value(f) for f in filters))}) else: for filter in filters: if isinstance(filter, FilterValue | Criterion | dict): @@ -162,7 +159,7 @@ class Engine: def _apply_filter( self, field: str, - value: FilterValue | list | None, + value: FilterValue | list | set | None, operator: str = "=", doctype: str | None = None, ): @@ -192,13 +189,9 @@ class Engine: (table.parent == self.table.name) & (table.parenttype == self.doctype) ) - if isinstance(_value, bool): - _value = int(_value) + _value = convert_to_value(_value) - if isinstance(_value, DocRef): - _value = str(_value) - - elif not _value and isinstance(_value, list | tuple): + if not _value and isinstance(_value, list | tuple | set): _value = ("",) # Nested set diff --git a/frappe/database/utils.py b/frappe/database/utils.py index 64ae6b9865..bb936605d5 100644 --- a/frappe/database/utils.py +++ b/frappe/database/utils.py @@ -4,11 +4,13 @@ from functools import cached_property, wraps import frappe +from frappe.model.document import DocRef from frappe.query_builder.builder import MariaDB, Postgres from frappe.query_builder.functions import Function Query = str | MariaDB | Postgres QueryValues = tuple | list | dict | None +FilterValue = DocRef | str | int | bool EmptyQueryValues = object() FallBackDateTimeStr = "0001-01-01 00:00:00.000000" @@ -22,6 +24,14 @@ NestedSetHierarchy = ( ) +def convert_to_value(o: FilterValue): + if hasattr(o, "__value__"): + return o.__value__() + if isinstance(o, bool): + return int(o) + return o + + def is_query_type(query: str, query_type: str | tuple[str, ...]) -> bool: return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type) diff --git a/frappe/model/base_document.py b/frappe/model/base_document.py index 1ff268dd16..4fe1fa2fde 100644 --- a/frappe/model/base_document.py +++ b/frappe/model/base_document.py @@ -146,6 +146,9 @@ class BaseDocument: if hasattr(self, "__setup__"): self.__setup__() + def __json__(self): + return self.as_dict(no_nulls=True) + @cached_property def meta(self): return frappe.get_meta(self.doctype) @@ -433,8 +436,8 @@ class BaseDocument: else: value = get_not_null_defaults(df.fieldtype) - if isinstance(value, DocRef): - value = str(value) + if hasattr(value, "__value__"): + value = value.__value__() d[fieldname] = value diff --git a/frappe/model/document.py b/frappe/model/document.py index 682ec02134..6cf7b34323 100644 --- a/frappe/model/document.py +++ b/frappe/model/document.py @@ -45,11 +45,16 @@ class DocRef: self.doctype = doctype self.name = name - def __str__(self): - # ! Used in frappe's query engine in frappe/database/query.py - # ! Keep it stable + def __value__(self): + # Used when requiring its value representation for db interactions, serializations, etc return self.name + def __str__(self): + return f"{self.doctype} ({self.name or 'n/a'})" + + def __repr__(self): + return f"<{self.__class__.__name__}: doctype={self.doctype} name={self.name or 'n/a'}>" + @simple_singledispatch def get_doc(*args, **kwargs) -> "Document": @@ -1789,14 +1794,16 @@ class Document(BaseDocument, DocRef): doc = self.get_valid_dict(convert_dates_to_str=True, ignore_virtual=True) deferred_insert(doctype=self.doctype, records=doc) - def __repr__(self): - name = self.name or "unsaved" - doctype = self.__class__.__name__ + def __str__(self): + return f"{self.doctype} ({self.name or 'unsaved'})" + def __repr__(self): + doctype = f"doctype={self.doctype}" + name = self.name or "unsaved" docstatus = f" docstatus={self.docstatus}" if self.docstatus else "" parent = f" parent={self.parent}" if getattr(self, "parent", None) else "" - return f"<{doctype}: {name}{docstatus}{parent}>" + return f"<{self.__class__.__name__}: {doctype} {name}{docstatus}{parent}>" def execute_action(__doctype, __name, __action, **kwargs): diff --git a/frappe/tests/test_doc_ref.py b/frappe/tests/test_doc_ref.py index 47deb87b00..1422a457d4 100644 --- a/frappe/tests/test_doc_ref.py +++ b/frappe/tests/test_doc_ref.py @@ -60,10 +60,10 @@ class TestDocRef(IntegrationTestCase): self.assertTrue("first_name" in [f.fieldname for f in meta.fields]) self.assertTrue("last_name" in [f.fieldname for f in meta.fields]) - def test_doc_ref_str_representation(self): - # Test the string representation of DocRef + def test_doc_ref_value_representation(self): + # Test the value representation of DocRef doc_ref = DocRef("User", "test@example.com") - self.assertEqual(str(doc_ref), "test@example.com") + self.assertEqual(doc_ref.__value__(), "test@example.com") def test_doc_ref_attributes(self): # Test DocRef attributes diff --git a/frappe/utils/response.py b/frappe/utils/response.py index b07fd49fa1..240833cffe 100644 --- a/frappe/utils/response.py +++ b/frappe/utils/response.py @@ -219,12 +219,6 @@ def json_handler(obj): elif isinstance(obj, LocalProxy): return str(obj) - elif isinstance(obj, frappe.model.document.BaseDocument): - return obj.as_dict(no_nulls=True) - - elif isinstance(obj, frappe.model.document.DocRef): # if not BaseDocument, but DocRef - return str(obj) - elif isinstance(obj, Iterable): return list(obj) @@ -246,6 +240,9 @@ def json_handler(obj): elif hasattr(obj, "__json__"): return obj.__json__() + elif hasattr(obj, "__value__"): # order imporant: defer to __json__ if implemented + return obj.__value__() + else: raise TypeError(f"""Object of type {type(obj)} with value of {obj!r} is not JSON serializable""")