refactor: minor changes

This commit is contained in:
Aradhya 2022-08-17 19:52:51 +05:30
parent 40bfad9aeb
commit 0addffafb9
4 changed files with 29 additions and 26 deletions

View file

@ -3,18 +3,21 @@ import re
from ast import literal_eval
from functools import cached_property
from types import BuiltinFunctionType
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
import frappe
from frappe import _
from frappe.database.utils import is_function_object, table_from_string
from frappe.database.utils import is_function_object
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumn
if TYPE_CHECKING:
from frappe.query_builder import DocType
TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
@ -64,7 +67,7 @@ def not_like(key: Field, value: str) -> frappe.qb:
return key.not_like(value)
def func_not_in(key: Field, value: list | tuple):
def func_not_in(key: Field, value: list | tuple | str):
"""Wrapper method for `NOT IN`
Args:
@ -158,6 +161,14 @@ def has_function(field):
return True
def table_from_string(table: str) -> "DocType":
table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:]
if "`" in table_name:
return frappe.qb.DocType(table_name=table_name.replace("`", ""))
else:
return frappe.qb.DocType(table_name=table_name)
# default operators
OPERATOR_MAP: dict[str, Callable] = {
"+": operator.add,
@ -378,7 +389,7 @@ class Engine:
return criterion
def make_function_for_filters(self, key: Any, value: int | str):
def make_function_for_filters(self, key, value: int | str):
value = list(value)
if isinstance(value[1], str) and has_function(value[1]):
value[1] = self.get_function_object(value[1])
@ -445,14 +456,19 @@ class Engine:
def remove_string_functions(self, fields, function_objects):
"""Remove string functions from fields which have already been converted to function objects"""
def _remove_string_aliasing(function, fields: list | str):
if function.alias:
to_replace = " as " + function.alias.casefold()
if to_replace in fields:
fields = fields.replace(to_replace, "")
elif " as " + f"`{function.alias.casefold()}" in fields:
fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "")
return fields
for function in function_objects:
if isinstance(fields, str):
if function.alias:
to_replace = " as " + function.alias.casefold()
if to_replace in fields:
fields = fields.replace(to_replace, "")
elif " as " + f"`{function.alias.casefold()}" in fields:
fields = fields.replace(" as " + f"`{function.alias.casefold()}`", "")
fields = _remove_string_aliasing(function, fields)
fields = BRACKETS_PATTERN.sub("", fields.casefold().replace(function.name.casefold(), ""))
# Converting back to capitalized doctype names.
if "tab" in fields:
@ -466,12 +482,7 @@ class Engine:
updated_fields = []
for field in fields:
if isinstance(field, str):
if function.alias:
to_replace = " as " + function.alias.casefold()
if to_replace in field:
field = field.replace(to_replace, "")
elif " as " + f"`{function.alias.casefold()}" in field:
field = field.replace(" as " + f"`{function.alias.casefold()}`", "")
field = _remove_string_aliasing(function, field)
substituted_string = (
BRACKETS_PATTERN.sub("", field).strip().casefold()
if "`" not in field

View file

@ -23,14 +23,6 @@ def is_query_type(query: str, query_type: str | tuple[str]) -> bool:
return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type)
def table_from_string(table: str) -> "DocType":
table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:]
if "`" in table_name:
return frappe.qb.DocType(table_name=table_name.replace("`", ""))
else:
return frappe.qb.DocType(table_name=table_name)
def is_function_object(field: str) -> bool:
return getattr(field, "__module__", None) == "pypika.functions" or isinstance(field, Function)

View file

@ -104,7 +104,7 @@ def get_users(doctype, name):
return frappe.db.get_all(
"DocShare",
fields=[
"name", # Don't understant the need for pseudocolumns here, don't know why get_all supports it?
"name",
"user",
"read",
"write",

View file

@ -189,7 +189,7 @@ class TestQuery(unittest.TestCase):
frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(),
)
def test_indirect_join_query(self):
def test_implicit_join_query(self):
self.assertEqual(
frappe.qb.engine.get_query(
"Note",