From 0addffafb949feb604e5afb01fc319986036f1d1 Mon Sep 17 00:00:00 2001 From: Aradhya Date: Wed, 17 Aug 2022 19:52:51 +0530 Subject: [PATCH] refactor: minor changes --- frappe/database/query.py | 43 ++++++++++++++++++++++++-------------- frappe/database/utils.py | 8 ------- frappe/share.py | 2 +- frappe/tests/test_query.py | 2 +- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index 91163fb607..12c878a567 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -3,18 +3,21 @@ import re from ast import literal_eval from functools import cached_property from types import BuiltinFunctionType -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder import frappe from frappe import _ -from frappe.database.utils import is_function_object, table_from_string +from frappe.database.utils import is_function_object from frappe.model.db_query import get_timespan_date_range from frappe.query_builder import Criterion, Field, Order, Table, functions from frappe.query_builder.functions import Function, SqlFunctions from frappe.query_builder.utils import PseudoColumn +if TYPE_CHECKING: + from frappe.query_builder import DocType + TAB_PATTERN = re.compile("^tab") WORDS_PATTERN = re.compile(r"\w+") BRACKETS_PATTERN = re.compile(r"\(.*?\)|$") @@ -64,7 +67,7 @@ def not_like(key: Field, value: str) -> frappe.qb: return key.not_like(value) -def func_not_in(key: Field, value: list | tuple): +def func_not_in(key: Field, value: list | tuple | str): """Wrapper method for `NOT IN` Args: @@ -158,6 +161,14 @@ def has_function(field): return True +def table_from_string(table: str) -> "DocType": + table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:] + if "`" in table_name: + return frappe.qb.DocType(table_name=table_name.replace("`", "")) + else: + return frappe.qb.DocType(table_name=table_name) + + # default operators OPERATOR_MAP: dict[str, Callable] = { "+": operator.add, @@ -378,7 +389,7 @@ class Engine: return criterion - def make_function_for_filters(self, key: Any, value: int | str): + def make_function_for_filters(self, key, value: int | str): value = list(value) if isinstance(value[1], str) and has_function(value[1]): value[1] = self.get_function_object(value[1]) @@ -445,14 +456,19 @@ class Engine: def remove_string_functions(self, fields, function_objects): """Remove string functions from fields which have already been converted to function objects""" + + def _remove_string_aliasing(function, fields: list | str): + if function.alias: + to_replace = " as " + function.alias.casefold() + if to_replace in fields: + fields = fields.replace(to_replace, "") + elif " as " + f"`{function.alias.casefold()}" in fields: + fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "") + return fields + for function in function_objects: if isinstance(fields, str): - if function.alias: - to_replace = " as " + function.alias.casefold() - if to_replace in fields: - fields = fields.replace(to_replace, "") - elif " as " + f"`{function.alias.casefold()}" in fields: - fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "") + fields = _remove_string_aliasing(function, fields) fields = BRACKETS_PATTERN.sub("", fields.casefold().replace(function.name.casefold(), "")) # Converting back to capitalized doctype names. if "tab" in fields: @@ -466,12 +482,7 @@ class Engine: updated_fields = [] for field in fields: if isinstance(field, str): - if function.alias: - to_replace = " as " + function.alias.casefold() - if to_replace in field: - field = field.replace(to_replace, "") - elif " as " + f"`{function.alias.casefold()}" in field: - field = field.replace(" as " + f"`{function.alias.casefold()}`", "") + field = _remove_string_aliasing(function, field) substituted_string = ( BRACKETS_PATTERN.sub("", field).strip().casefold() if "`" not in field diff --git a/frappe/database/utils.py b/frappe/database/utils.py index 2a23ca485e..1d44b358fe 100644 --- a/frappe/database/utils.py +++ b/frappe/database/utils.py @@ -23,14 +23,6 @@ def is_query_type(query: str, query_type: str | tuple[str]) -> bool: return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type) -def table_from_string(table: str) -> "DocType": - table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:] - if "`" in table_name: - return frappe.qb.DocType(table_name=table_name.replace("`", "")) - else: - return frappe.qb.DocType(table_name=table_name) - - def is_function_object(field: str) -> bool: return getattr(field, "__module__", None) == "pypika.functions" or isinstance(field, Function) diff --git a/frappe/share.py b/frappe/share.py index 7f7ca2033f..97c9a472e6 100644 --- a/frappe/share.py +++ b/frappe/share.py @@ -104,7 +104,7 @@ def get_users(doctype, name): return frappe.db.get_all( "DocShare", fields=[ - "name", # Don't understant the need for pseudocolumns here, don't know why get_all supports it? + "name", "user", "read", "write", diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 7ffd111901..ce819930aa 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -189,7 +189,7 @@ class TestQuery(unittest.TestCase): frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(), ) - def test_indirect_join_query(self): + def test_implicit_join_query(self): self.assertEqual( frappe.qb.engine.get_query( "Note",