fix: don't force values into the string type (#28185)

This commit is contained in:
David Arnold 2024-10-19 21:00:25 +02:00 committed by GitHub
parent 75b58802ad
commit 2abba7b51b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 51 additions and 43 deletions

View file

@ -11,7 +11,7 @@ import warnings
from collections.abc import Iterable, Sequence
from contextlib import contextmanager, suppress
from time import time
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
@ -22,13 +22,14 @@ from frappe.database.utils import (
DefaultOrderBy,
EmptyQueryValues,
FallBackDateTimeStr,
FilterValue,
LazyMogrify,
Query,
QueryValues,
convert_to_value,
is_query_type,
)
from frappe.exceptions import DoesNotExistError, ImplicitCommitError
from frappe.model.document import DocRef
from frappe.monitor import get_trace_id
from frappe.query_builder.functions import Count
from frappe.utils import CallbackManager, cint, get_datetime, get_table_name, getdate, now, sbool
@ -53,9 +54,6 @@ be ignored. Commit/Rollback from here WILL CAUSE very hard to debug problems wit
concurrent data update bugs."""
Stringable: TypeAlias = str | DocRef
class Database:
"""
Open a database connection with the given parmeters, if use_default is True, use the
@ -479,7 +477,7 @@ class Database:
def get_value(
self,
doctype: str,
filters: Stringable | dict | list | None = None,
filters: FilterValue | dict | list | None = None,
fieldname: str | list[str] = "name",
ignore: bool = False,
as_dict: bool = False,
@ -558,7 +556,7 @@ class Database:
def get_values(
self,
doctype: str,
filters: Stringable | dict | list | None = None,
filters: FilterValue | dict | list | None = None,
fieldname: str | list[str] = "name",
ignore: bool = False,
as_dict: bool = False,
@ -596,8 +594,8 @@ class Database:
"""
out = None
cache_key = None
if cache and isinstance(filters, Stringable):
cache_key = (doctype, str(filters), fieldname)
if cache and isinstance(filters, FilterValue):
cache_key = (doctype, convert_to_value(filters), fieldname)
if cache_key in self.value_cache:
return self.value_cache[cache_key]
@ -941,7 +939,7 @@ class Database:
def set_value(
self,
dt: str,
dn: Stringable | dict,
dn: FilterValue | dict,
field: str,
val=None,
modified=None,
@ -997,8 +995,8 @@ class Database:
validate_filters=True,
)
if isinstance(dn, Stringable):
frappe.clear_document_cache(dt, str(dn))
if isinstance(dn, FilterValue):
frappe.clear_document_cache(dt, convert_to_value(dn))
else:
# No way to guess which documents are modified, clear all of them
frappe.clear_document_cache(dt)

View file

@ -10,8 +10,7 @@ import frappe
from frappe import _
from frappe.database.operator_map import OPERATOR_MAP
from frappe.database.schema import SPECIAL_CHAR_PATTERN
from frappe.database.utils import DefaultOrderBy, get_doctype_name
from frappe.model.document import DocRef
from frappe.database.utils import DefaultOrderBy, FilterValue, convert_to_value, get_doctype_name
from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumnMapper
@ -30,8 +29,6 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
# to allow table names like __Auth
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
FilterValue: TypeAlias = DocRef | str | int
class Engine:
def get_query(
@ -122,7 +119,7 @@ class Engine:
return
if isinstance(filters, FilterValue):
filters = {"name": str(filters)}
filters = {"name": convert_to_value(filters)}
if isinstance(filters, Criterion):
self.query = self.query.where(filters)
@ -132,7 +129,7 @@ class Engine:
elif isinstance(filters, list | tuple):
if all(isinstance(d, FilterValue) for d in filters) and len(filters) > 0:
self.apply_dict_filters({"name": ("in", filters)})
self.apply_dict_filters({"name": ("in", tuple(convert_to_value(f) for f in filters))})
else:
for filter in filters:
if isinstance(filter, FilterValue | Criterion | dict):
@ -162,7 +159,7 @@ class Engine:
def _apply_filter(
self,
field: str,
value: FilterValue | list | None,
value: FilterValue | list | set | None,
operator: str = "=",
doctype: str | None = None,
):
@ -192,13 +189,9 @@ class Engine:
(table.parent == self.table.name) & (table.parenttype == self.doctype)
)
if isinstance(_value, bool):
_value = int(_value)
_value = convert_to_value(_value)
if isinstance(_value, DocRef):
_value = str(_value)
elif not _value and isinstance(_value, list | tuple):
if not _value and isinstance(_value, list | tuple | set):
_value = ("",)
# Nested set

View file

@ -4,11 +4,13 @@
from functools import cached_property, wraps
import frappe
from frappe.model.document import DocRef
from frappe.query_builder.builder import MariaDB, Postgres
from frappe.query_builder.functions import Function
Query = str | MariaDB | Postgres
QueryValues = tuple | list | dict | None
FilterValue = DocRef | str | int | bool
EmptyQueryValues = object()
FallBackDateTimeStr = "0001-01-01 00:00:00.000000"
@ -22,6 +24,14 @@ NestedSetHierarchy = (
)
def convert_to_value(o: FilterValue):
if hasattr(o, "__value__"):
return o.__value__()
if isinstance(o, bool):
return int(o)
return o
def is_query_type(query: str, query_type: str | tuple[str, ...]) -> bool:
return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type)

View file

@ -146,6 +146,9 @@ class BaseDocument:
if hasattr(self, "__setup__"):
self.__setup__()
def __json__(self):
return self.as_dict(no_nulls=True)
@cached_property
def meta(self):
return frappe.get_meta(self.doctype)
@ -433,8 +436,8 @@ class BaseDocument:
else:
value = get_not_null_defaults(df.fieldtype)
if isinstance(value, DocRef):
value = str(value)
if hasattr(value, "__value__"):
value = value.__value__()
d[fieldname] = value

View file

@ -45,11 +45,16 @@ class DocRef:
self.doctype = doctype
self.name = name
def __str__(self):
# ! Used in frappe's query engine in frappe/database/query.py
# ! Keep it stable
def __value__(self):
# Used when requiring its value representation for db interactions, serializations, etc
return self.name
def __str__(self):
return f"{self.doctype} ({self.name or 'n/a'})"
def __repr__(self):
return f"<{self.__class__.__name__}: doctype={self.doctype} name={self.name or 'n/a'}>"
@simple_singledispatch
def get_doc(*args, **kwargs) -> "Document":
@ -1789,14 +1794,16 @@ class Document(BaseDocument, DocRef):
doc = self.get_valid_dict(convert_dates_to_str=True, ignore_virtual=True)
deferred_insert(doctype=self.doctype, records=doc)
def __repr__(self):
name = self.name or "unsaved"
doctype = self.__class__.__name__
def __str__(self):
return f"{self.doctype} ({self.name or 'unsaved'})"
def __repr__(self):
doctype = f"doctype={self.doctype}"
name = self.name or "unsaved"
docstatus = f" docstatus={self.docstatus}" if self.docstatus else ""
parent = f" parent={self.parent}" if getattr(self, "parent", None) else ""
return f"<{doctype}: {name}{docstatus}{parent}>"
return f"<{self.__class__.__name__}: {doctype} {name}{docstatus}{parent}>"
def execute_action(__doctype, __name, __action, **kwargs):

View file

@ -60,10 +60,10 @@ class TestDocRef(IntegrationTestCase):
self.assertTrue("first_name" in [f.fieldname for f in meta.fields])
self.assertTrue("last_name" in [f.fieldname for f in meta.fields])
def test_doc_ref_str_representation(self):
# Test the string representation of DocRef
def test_doc_ref_value_representation(self):
# Test the value representation of DocRef
doc_ref = DocRef("User", "test@example.com")
self.assertEqual(str(doc_ref), "test@example.com")
self.assertEqual(doc_ref.__value__(), "test@example.com")
def test_doc_ref_attributes(self):
# Test DocRef attributes

View file

@ -219,12 +219,6 @@ def json_handler(obj):
elif isinstance(obj, LocalProxy):
return str(obj)
elif isinstance(obj, frappe.model.document.BaseDocument):
return obj.as_dict(no_nulls=True)
elif isinstance(obj, frappe.model.document.DocRef): # if not BaseDocument, but DocRef
return str(obj)
elif isinstance(obj, Iterable):
return list(obj)
@ -246,6 +240,9 @@ def json_handler(obj):
elif hasattr(obj, "__json__"):
return obj.__json__()
elif hasattr(obj, "__value__"): # order imporant: defer to __json__ if implemented
return obj.__value__()
else:
raise TypeError(f"""Object of type {type(obj)} with value of {obj!r} is not JSON serializable""")