fix: refactor

- move operator map in separate file
- remove unnecessary code
- organize functions
This commit is contained in:
Faris Ansari 2023-01-10 16:15:08 +05:30
parent 49471922e4
commit fe13108eec
3 changed files with 208 additions and 241 deletions

View file

@ -0,0 +1,138 @@
# Copyright (c) 2023, Frappe Technologies Pvt. Ltd. and Contributors
# MIT License. See license.txt
import operator
from typing import Callable
import frappe
from frappe.database.utils import NestedSetHierarchy
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Field
def like(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `LIKE`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `LIKE`
"""
return key.like(value)
def func_in(key: Field, value: list | tuple) -> frappe.qb:
"""Wrapper method for `IN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `IN`
"""
if isinstance(value, str):
value = value.split(",")
return key.isin(value)
def not_like(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `NOT LIKE`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `NOT LIKE`
"""
return key.not_like(value)
def func_not_in(key: Field, value: list | tuple | str):
"""Wrapper method for `NOT IN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `NOT IN`
"""
if isinstance(value, str):
value = value.split(",")
return key.notin(value)
def func_regex(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `REGEX`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `REGEX`
"""
return key.regex(value)
def func_between(key: Field, value: list | tuple) -> frappe.qb:
"""Wrapper method for `BETWEEN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `BETWEEN`
"""
return key[slice(*value)]
def func_is(key, value):
"Wrapper for IS"
return key.isnotnull() if value.lower() == "set" else key.isnull()
def func_timespan(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `TIMESPAN`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `TIMESPAN`
"""
return func_between(key, get_timespan_date_range(value))
# default operators
OPERATOR_MAP: dict[str, Callable] = {
"+": operator.add,
"=": operator.eq,
"-": operator.sub,
"!=": operator.ne,
"<": operator.lt,
">": operator.gt,
"<=": operator.le,
"=<": operator.le,
">=": operator.ge,
"=>": operator.ge,
"/": operator.truediv,
"*": operator.mul,
"in": func_in,
"not in": func_not_in,
"like": like,
"not like": not_like,
"regex": func_regex,
"between": func_between,
"is": func_is,
"timespan": func_timespan,
"nested_set": NestedSetHierarchy,
# TODO: Add support for custom operators (WIP) - via filters_config hooks
}

View file

@ -1,20 +1,16 @@
import itertools
import operator
import re
from ast import literal_eval
from functools import cached_property
from types import BuiltinFunctionType
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
import sqlparse
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import QueryBuilder
import frappe
from frappe import _
from frappe.database.utils import NestedSetHierarchy, is_pypika_function_object
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.database.operator_map import OPERATOR_MAP
from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumnMapper
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT
@ -29,186 +25,12 @@ SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
def like(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `LIKE`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `LIKE`
"""
return key.like(value)
def func_in(key: Field, value: list | tuple) -> frappe.qb:
"""Wrapper method for `IN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `IN`
"""
if isinstance(value, str):
value = value.split(",")
return key.isin(value)
def not_like(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `NOT LIKE`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `NOT LIKE`
"""
return key.not_like(value)
def func_not_in(key: Field, value: list | tuple | str):
"""Wrapper method for `NOT IN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `NOT IN`
"""
if isinstance(value, str):
value = value.split(",")
return key.notin(value)
def func_regex(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `REGEX`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `REGEX`
"""
return key.regex(value)
def func_between(key: Field, value: list | tuple) -> frappe.qb:
"""Wrapper method for `BETWEEN`
Args:
key (str): field
value (Union[int, str]): criterion
Returns:
frappe.qb: `frappe.qb object with `BETWEEN`
"""
return key[slice(*value)]
def func_is(key, value):
"Wrapper for IS"
return key.isnotnull() if value.lower() == "set" else key.isnull()
def func_timespan(key: Field, value: str) -> frappe.qb:
"""Wrapper method for `TIMESPAN`
Args:
key (str): field
value (str): criterion
Returns:
frappe.qb: `frappe.qb object with `TIMESPAN`
"""
return func_between(key, get_timespan_date_range(value))
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
def has_function(field):
_field = field.casefold() if (isinstance(field, str) and "`" not in field) else field
if not issubclass(type(_field), Criterion):
if any([f"{func}(" in _field for func in SQL_FUNCTIONS]):
return True
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str):
table = frappe.qb.DocType(doctype)
try:
lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
except IndexError:
lft, rgt = None, None
if hierarchy in ("descendants of", "not descendants of"):
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft > lft)
.where(table.rgt < rgt)
.orderby(table.lft, order=Order.asc)
.run()
)
else:
# Get ancestor elements of a DocType with a tree structure
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft < lft)
.where(table.rgt > rgt)
.orderby(table.lft, order=Order.desc)
.run()
)
return result
# default operators
OPERATOR_MAP: dict[str, Callable] = {
"+": operator.add,
"=": operator.eq,
"-": operator.sub,
"!=": operator.ne,
"<": operator.lt,
">": operator.gt,
"<=": operator.le,
"=<": operator.le,
">=": operator.ge,
"=>": operator.ge,
"/": operator.truediv,
"*": operator.mul,
"in": func_in,
"not in": func_not_in,
"like": like,
"not like": not_like,
"regex": func_regex,
"between": func_between,
"is": func_is,
"timespan": func_timespan,
"nested_set": NestedSetHierarchy,
# TODO: Add support for custom operators (WIP) - via filters_config hooks
}
class Engine:
tables: dict[str, str] = {}
def get_query(
self,
table: str,
fields: list | tuple | None = None,
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
pluck: str | None = None,
order_by: str | None = None,
group_by: str | None = None,
limit: int | None = None,
@ -218,15 +40,11 @@ class Engine:
update: bool = False,
into: bool = False,
delete: bool = False,
) -> MySQLQueryBuilder | PostgreSQLQueryBuilder:
# Clean up state before each query
) -> QueryBuilder:
self.is_mariadb = frappe.db.db_type == "mariadb"
self.is_postgres = frappe.db.db_type == "postgres"
self.tables = {}
self.implicit_joins = set()
self.doctype = table
self.table = self.get_table(table)
self.table = frappe.qb.DocType(table)
if update:
self.query = frappe.qb.update(self.table)
@ -236,19 +54,9 @@ class Engine:
self.query = frappe.qb.from_(self.table).delete()
else:
self.query = frappe.qb.from_(self.table)
# add fields
self.fields = self.parse_fields(fields)
if not self.fields:
self.fields = [getattr(self.table, pluck or "name")]
for field in self.fields:
if isinstance(field, DynamicTableField):
self.query = field.apply_select(self.query)
else:
self.query = self.query.select(field)
self.apply_fields(fields)
self.apply_filters(filters)
self.apply_implicit_joins()
self.apply_order_by(order_by)
if limit:
@ -268,16 +76,21 @@ class Engine:
return self.query
def get_table(self, table_name: str | Table) -> Table:
if isinstance(table_name, Table):
return table_name
table_name = table_name.strip('"').strip("'")
if table_name not in self.tables:
self.tables[table_name] = frappe.qb.DocType(table_name)
return self.tables[table_name]
def apply_fields(self, fields):
# add fields
self.fields = self.parse_fields(fields)
if not self.fields:
self.fields = [getattr(self.table, "name")]
for field in self.fields:
if isinstance(field, DynamicTableField):
self.query = field.apply_select(self.query)
else:
self.query = self.query.select(field)
def apply_filters(
self, filters: dict[str, str | int | list] | str | int | list[list] | None = None
self,
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
):
if not filters:
return
@ -332,12 +145,12 @@ class Engine:
elif not doctype or doctype == self.doctype:
_field = self.table[field]
elif doctype:
_field = self.get_table(doctype)[field]
_field = frappe.qb.DocType(doctype)[field]
# apply implicit join if child table is referenced
if doctype and doctype != self.doctype:
meta = frappe.get_meta(doctype)
table = self.get_table(doctype)
table = frappe.qb.DocType(doctype)
if meta.istable and not self.query.is_joined(table):
self.query = self.query.left_join(table).on(
(table.parent == self.table.name) & (table.parenttype == self.doctype)
@ -354,14 +167,14 @@ class Engine:
_value = self.get_function_object(_value)
# Nested set
if _operator in self.OPERATOR_MAP["nested_set"]:
if _operator in OPERATOR_MAP["nested_set"]:
hierarchy = _operator
docname = _value
result = get_nested_set_hierarchy_result(self.doctype, docname, hierarchy)
operator_fn = (
self.OPERATOR_MAP["not in"]
OPERATOR_MAP["not in"]
if hierarchy in ("not ancestors of", "not descendants of")
else self.OPERATOR_MAP["in"]
else OPERATOR_MAP["in"]
)
if result:
result = list(itertools.chain.from_iterable(result))
@ -370,7 +183,7 @@ class Engine:
self.query = self.query.where(operator_fn(_field, ("",)))
return
operator_fn = self.OPERATOR_MAP[_operator.casefold()]
operator_fn = OPERATOR_MAP[_operator.casefold()]
if _value is None and isinstance(_field, Field):
self.query = self.query.where(_field.isnull())
else:
@ -490,15 +303,6 @@ class Engine:
return _fields
def apply_implicit_joins(self):
for d in self.implicit_joins:
doctype, join_type = d
table = self.get_table(doctype)
if join_type == "child":
self.query = self.query.left_join(table).on(
(table.parent == self.table.name) & (table.parenttype == self.doctype)
)
def apply_order_by(self, order_by: str | None):
if not order_by or order_by == "KEEP_DEFAULT_ORDERING":
return
@ -509,22 +313,6 @@ class Engine:
order_direction = Order.asc if order_direction.lower() == "asc" else Order.desc
self.query = self.query.orderby(order_field, order=order_direction)
@cached_property
def OPERATOR_MAP(self):
# default operators
all_operators = OPERATOR_MAP.copy()
# TODO: update with site-specific custom operators / removed previous buggy implementation
if frappe.get_hooks("filters_config"):
from frappe.utils.commands import warn
warn(
"The 'filters_config' hook used to add custom operators is not yet implemented"
" in frappe.db.query engine. Use db_query (frappe.get_list) instead."
)
return all_operators
class Permission:
@classmethod
@ -657,3 +445,46 @@ class LinkTableField(DynamicTableField):
if not query.is_joined(table):
query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname))
return query
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
def has_function(field):
_field = field.casefold() if (isinstance(field, str) and "`" not in field) else field
if not issubclass(type(_field), Criterion):
if any([f"{func}(" in _field for func in SQL_FUNCTIONS]):
return True
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str):
table = frappe.qb.DocType(doctype)
try:
lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
except IndexError:
lft, rgt = None, None
if hierarchy in ("descendants of", "not descendants of"):
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft > lft)
.where(table.rgt < rgt)
.orderby(table.lft, order=Order.asc)
.run()
)
else:
# Get ancestor elements of a DocType with a tree structure
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft < lft)
.where(table.rgt > rgt)
.orderby(table.lft, order=Order.desc)
.run()
)
return result

View file

@ -2,9 +2,7 @@ from enum import Enum
from importlib import import_module
from typing import Any, Callable, get_type_hints
from pypika import Query
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
from pypika.queries import Column
from pypika.queries import Column, QueryBuilder
from pypika.terms import PseudoColumn
import frappe
@ -56,7 +54,7 @@ def get_query_builder(type_of_db: str) -> Postgres | MariaDB:
return picks[db]
def get_query(*args, **kwargs) -> MySQLQueryBuilder | PostgreSQLQueryBuilder:
def get_query(*args, **kwargs) -> QueryBuilder:
from frappe.database.query import Engine
return Engine().get_query(*args, **kwargs)