500 lines
15 KiB
Python
500 lines
15 KiB
Python
import itertools
|
|
import re
|
|
from ast import literal_eval
|
|
from types import BuiltinFunctionType
|
|
from typing import TYPE_CHECKING
|
|
|
|
import sqlparse
|
|
from pypika.queries import QueryBuilder, Table
|
|
|
|
import frappe
|
|
from frappe import _
|
|
from frappe.database.operator_map import OPERATOR_MAP
|
|
from frappe.database.utils import get_doctype_name
|
|
from frappe.query_builder import Criterion, Field, Order, functions
|
|
from frappe.query_builder.functions import Function, SqlFunctions
|
|
from frappe.query_builder.utils import PseudoColumnMapper
|
|
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT
|
|
|
|
if TYPE_CHECKING:
|
|
from frappe.query_builder import DocType
|
|
|
|
TAB_PATTERN = re.compile("^tab")
|
|
WORDS_PATTERN = re.compile(r"\w+")
|
|
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
|
|
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
|
|
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
|
|
|
|
|
|
class Engine:
|
|
def get_query(
|
|
self,
|
|
table: str | Table,
|
|
fields: list | tuple | None = None,
|
|
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
|
|
order_by: str | None = None,
|
|
group_by: str | None = None,
|
|
limit: int | None = None,
|
|
offset: int | None = None,
|
|
distinct: bool = False,
|
|
for_update: bool = False,
|
|
update: bool = False,
|
|
into: bool = False,
|
|
delete: bool = False,
|
|
) -> QueryBuilder:
|
|
self.is_mariadb = frappe.db.db_type == "mariadb"
|
|
self.is_postgres = frappe.db.db_type == "postgres"
|
|
|
|
if isinstance(table, Table):
|
|
self.table = table
|
|
self.doctype = get_doctype_name(table.get_sql())
|
|
else:
|
|
self.doctype = table
|
|
self.table = frappe.qb.DocType(table)
|
|
|
|
if update:
|
|
self.query = frappe.qb.update(self.table)
|
|
elif into:
|
|
self.query = frappe.qb.into(self.table)
|
|
elif delete:
|
|
self.query = frappe.qb.from_(self.table).delete()
|
|
else:
|
|
self.query = frappe.qb.from_(self.table)
|
|
self.apply_fields(fields)
|
|
|
|
self.apply_filters(filters)
|
|
self.apply_order_by(order_by)
|
|
|
|
if limit:
|
|
self.query = self.query.limit(limit)
|
|
|
|
if offset:
|
|
self.query = self.query.offset(offset)
|
|
|
|
if distinct:
|
|
self.query = self.query.distinct()
|
|
|
|
if for_update:
|
|
self.query = self.query.for_update()
|
|
|
|
if group_by:
|
|
self.query = self.query.groupby(group_by)
|
|
|
|
return self.query
|
|
|
|
def apply_fields(self, fields):
|
|
# add fields
|
|
self.fields = self.parse_fields(fields)
|
|
if not self.fields:
|
|
self.fields = [getattr(self.table, "name")]
|
|
|
|
for field in self.fields:
|
|
if isinstance(field, DynamicTableField):
|
|
self.query = field.apply_select(self.query)
|
|
else:
|
|
self.query = self.query.select(field)
|
|
|
|
def apply_filters(
|
|
self,
|
|
filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
|
|
):
|
|
if filters is None:
|
|
return
|
|
|
|
if isinstance(filters, (str, int)):
|
|
filters = {"name": str(filters)}
|
|
|
|
if isinstance(filters, Criterion):
|
|
self.query = self.query.where(filters)
|
|
|
|
elif isinstance(filters, dict):
|
|
self.apply_dict_filters(filters)
|
|
|
|
elif isinstance(filters, (list, tuple)):
|
|
if all(isinstance(d, (str, int)) for d in filters) and len(filters) > 0:
|
|
self.apply_dict_filters({"name": ("in", filters)})
|
|
else:
|
|
for filter in filters:
|
|
if isinstance(filter, (str, int, Criterion, dict)):
|
|
self.apply_filters(filter)
|
|
elif isinstance(filter, (list, tuple)):
|
|
self.apply_list_filters(filter)
|
|
|
|
def apply_list_filters(self, filter: list):
|
|
if len(filter) == 2:
|
|
field, value = filter
|
|
self._apply_filter(field, value)
|
|
elif len(filter) == 3:
|
|
field, operator, value = filter
|
|
self._apply_filter(field, value, operator)
|
|
elif len(filter) == 4:
|
|
doctype, field, operator, value = filter
|
|
self._apply_filter(field, value, operator, doctype)
|
|
|
|
def apply_dict_filters(self, filters: dict[str, str | int | list]):
|
|
for key in filters:
|
|
value = filters.get(key)
|
|
self._apply_filter(key, value)
|
|
|
|
def _apply_filter(
|
|
self, field: str, value: str | int | list | None, operator: str = "=", doctype: str | None = None
|
|
):
|
|
_field = field
|
|
_value = value
|
|
_operator = operator
|
|
|
|
if isinstance(_field, Field):
|
|
pass
|
|
elif dynamic_field := DynamicTableField.parse(field, self.doctype):
|
|
# apply implicit join if link field's field is referenced
|
|
self.query = dynamic_field.apply_join(self.query)
|
|
_field = dynamic_field.field
|
|
elif has_function(field):
|
|
_field = self.get_function_object(field)
|
|
elif not doctype or doctype == self.doctype:
|
|
_field = self.table[field]
|
|
elif doctype:
|
|
_field = frappe.qb.DocType(doctype)[field]
|
|
|
|
# apply implicit join if child table is referenced
|
|
if doctype and doctype != self.doctype:
|
|
meta = frappe.get_meta(doctype)
|
|
table = frappe.qb.DocType(doctype)
|
|
if meta.istable and not self.query.is_joined(table):
|
|
self.query = self.query.left_join(table).on(
|
|
(table.parent == self.table.name) & (table.parenttype == self.doctype)
|
|
)
|
|
|
|
if isinstance(_value, (list, tuple)):
|
|
_operator, _value = _value
|
|
elif isinstance(_value, bool):
|
|
_value = int(_value)
|
|
|
|
if isinstance(_value, str) and has_function(_value):
|
|
_value = self.get_function_object(_value)
|
|
|
|
if isinstance(_value, (list, tuple)) and not _value:
|
|
_value = ("",)
|
|
|
|
# Nested set
|
|
if _operator in OPERATOR_MAP["nested_set"]:
|
|
hierarchy = _operator
|
|
docname = _value
|
|
result = get_nested_set_hierarchy_result(self.doctype, docname, hierarchy)
|
|
operator_fn = (
|
|
OPERATOR_MAP["not in"]
|
|
if hierarchy in ("not ancestors of", "not descendants of")
|
|
else OPERATOR_MAP["in"]
|
|
)
|
|
if result:
|
|
result = list(itertools.chain.from_iterable(result))
|
|
self.query = self.query.where(operator_fn(_field, result))
|
|
else:
|
|
self.query = self.query.where(operator_fn(_field, ("",)))
|
|
return
|
|
|
|
operator_fn = OPERATOR_MAP[_operator.casefold()]
|
|
if _value is None and isinstance(_field, Field):
|
|
self.query = self.query.where(_field.isnull())
|
|
else:
|
|
self.query = self.query.where(operator_fn(_field, _value))
|
|
|
|
def get_function_object(self, field: str) -> "Function":
|
|
"""Expects field to look like 'SUM(*)' or 'name' or something similar. Returns PyPika Function object"""
|
|
func = field.split("(", maxsplit=1)[0].capitalize()
|
|
args_start, args_end = len(func) + 1, field.index(")")
|
|
args = field[args_start:args_end].split(",")
|
|
|
|
_, alias = field.split(" as ") if " as " in field else (None, None)
|
|
|
|
to_cast = "*" not in args
|
|
_args = []
|
|
|
|
for arg in args:
|
|
initial_fields = literal_eval_(arg.strip())
|
|
if to_cast:
|
|
has_primitive_operator = False
|
|
for _operator in OPERATOR_MAP.keys():
|
|
if _operator in initial_fields:
|
|
operator_mapping = OPERATOR_MAP[_operator]
|
|
# Only perform this if operator is of primitive type.
|
|
if isinstance(operator_mapping, BuiltinFunctionType):
|
|
has_primitive_operator = True
|
|
field = operator_mapping(
|
|
*map(
|
|
lambda field: Field(field.strip())
|
|
if "`" not in field
|
|
else PseudoColumnMapper(field.strip()),
|
|
arg.split(_operator),
|
|
),
|
|
)
|
|
|
|
field = (
|
|
(Field(initial_fields) if "`" not in initial_fields else PseudoColumnMapper(initial_fields))
|
|
if not has_primitive_operator
|
|
else field
|
|
)
|
|
else:
|
|
field = initial_fields
|
|
|
|
_args.append(field)
|
|
|
|
if alias and "`" in alias:
|
|
alias = alias.replace("`", "")
|
|
try:
|
|
if func.casefold() == "now":
|
|
return getattr(functions, func)()
|
|
return getattr(functions, func)(*_args, alias=alias or None)
|
|
except AttributeError:
|
|
# Fall back for functions not present in `SqlFunctions``
|
|
return Function(func, *_args, alias=alias or None)
|
|
|
|
def sanitize_fields(self, fields: str | list | tuple):
|
|
def _sanitize_field(field: str):
|
|
if not isinstance(field, str):
|
|
return field
|
|
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
|
|
if self.is_mariadb:
|
|
return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
|
|
return stripped_field
|
|
|
|
if isinstance(fields, (list, tuple)):
|
|
return [_sanitize_field(field) for field in fields]
|
|
elif isinstance(fields, str):
|
|
return _sanitize_field(fields)
|
|
|
|
return fields
|
|
|
|
def parse_string_field(self, field: str):
|
|
if field == "*":
|
|
return self.table.star
|
|
alias = None
|
|
if " as " in field:
|
|
field, alias = field.split(" as ")
|
|
if "`" in field:
|
|
if alias:
|
|
return PseudoColumnMapper(f"{field} {alias}")
|
|
return PseudoColumnMapper(field)
|
|
if alias:
|
|
return self.table[field].as_(alias)
|
|
return self.table[field]
|
|
|
|
def parse_fields(self, fields: str | list | tuple | None) -> list:
|
|
if not fields:
|
|
return []
|
|
fields = self.sanitize_fields(fields)
|
|
if isinstance(fields, (list, tuple, set)) and None in fields and Field not in fields:
|
|
return []
|
|
|
|
if not isinstance(fields, (list, tuple)):
|
|
fields = [fields]
|
|
|
|
def parse_field(field: str):
|
|
if has_function(field):
|
|
return self.get_function_object(field)
|
|
elif parsed := DynamicTableField.parse(field, self.doctype):
|
|
return parsed
|
|
else:
|
|
return self.parse_string_field(field)
|
|
|
|
_fields = []
|
|
for field in fields:
|
|
if isinstance(field, Criterion):
|
|
_fields.append(field)
|
|
elif isinstance(field, str):
|
|
if "," in field:
|
|
field = field.casefold() if "`" not in field else field
|
|
field_list = COMMA_PATTERN.split(field)
|
|
for field in field_list:
|
|
if _field := field.strip():
|
|
_fields.append(parse_field(_field))
|
|
else:
|
|
_fields.append(parse_field(field))
|
|
|
|
return _fields
|
|
|
|
def apply_order_by(self, order_by: str | None):
|
|
if not order_by or order_by == "KEEP_DEFAULT_ORDERING":
|
|
return
|
|
for declaration in order_by.split(","):
|
|
if _order_by := declaration.strip():
|
|
parts = _order_by.split(" ")
|
|
order_field, order_direction = parts[0], parts[1] if len(parts) > 1 else "desc"
|
|
order_direction = Order.asc if order_direction.lower() == "asc" else Order.desc
|
|
self.query = self.query.orderby(order_field, order=order_direction)
|
|
|
|
|
|
class Permission:
|
|
@classmethod
|
|
def check_permissions(cls, query, **kwargs):
|
|
if not isinstance(query, str):
|
|
query = query.get_sql()
|
|
|
|
doctype = cls.get_tables_from_query(query)
|
|
if isinstance(doctype, str):
|
|
doctype = [doctype]
|
|
|
|
for dt in doctype:
|
|
dt = TAB_PATTERN.sub("", dt)
|
|
if not frappe.has_permission(
|
|
dt,
|
|
"select",
|
|
user=kwargs.get("user"),
|
|
parent_doctype=kwargs.get("parent_doctype"),
|
|
) and not frappe.has_permission(
|
|
dt,
|
|
"read",
|
|
user=kwargs.get("user"),
|
|
parent_doctype=kwargs.get("parent_doctype"),
|
|
):
|
|
frappe.throw(_("Insufficient Permission for {0}").format(frappe.bold(dt)))
|
|
|
|
@staticmethod
|
|
def get_tables_from_query(query: str):
|
|
return [table for table in WORDS_PATTERN.findall(query) if table.startswith("tab")]
|
|
|
|
|
|
class DynamicTableField:
|
|
def __init__(
|
|
self,
|
|
doctype: str,
|
|
fieldname: str,
|
|
parent_doctype: str,
|
|
alias: str | None = None,
|
|
) -> None:
|
|
self.doctype = doctype
|
|
self.fieldname = fieldname
|
|
self.alias = alias
|
|
self.parent_doctype = parent_doctype
|
|
|
|
def __str__(self) -> str:
|
|
table_name = f"`tab{self.doctype}`"
|
|
fieldname = f"`{self.fieldname}`"
|
|
if frappe.db.db_type == "postgres":
|
|
table_name = table_name.replace("`", '"')
|
|
fieldname = fieldname.replace("`", '"')
|
|
alias = f"AS {self.alias}" if self.alias else ""
|
|
return f"{table_name}.{fieldname} {alias}".strip()
|
|
|
|
@staticmethod
|
|
def parse(field: str, doctype: str):
|
|
if "." in field:
|
|
alias = None
|
|
if " as " in field:
|
|
field, alias = field.split(" as ")
|
|
if field.startswith("`tab") or field.startswith('"tab'):
|
|
_, child_doctype, child_field = re.search(r'([`"])tab(.+?)\1.\1(.+)\1', field).groups()
|
|
if child_doctype == doctype:
|
|
return
|
|
return ChildTableField(child_doctype, child_field, doctype, alias=alias)
|
|
else:
|
|
linked_fieldname, fieldname = field.split(".")
|
|
linked_field = frappe.get_meta(doctype).get_field(linked_fieldname)
|
|
linked_doctype = linked_field.options
|
|
if linked_field.fieldtype == "Link":
|
|
return LinkTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias)
|
|
elif linked_field.fieldtype in frappe.model.table_fields:
|
|
return ChildTableField(linked_doctype, fieldname, doctype, alias=alias)
|
|
|
|
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
|
|
raise NotImplementedError
|
|
|
|
|
|
class ChildTableField(DynamicTableField):
|
|
def __init__(
|
|
self,
|
|
doctype: str,
|
|
fieldname: str,
|
|
parent_doctype: str,
|
|
alias: str | None = None,
|
|
) -> None:
|
|
self.doctype = doctype
|
|
self.fieldname = fieldname
|
|
self.alias = alias
|
|
self.parent_doctype = parent_doctype
|
|
self.table = frappe.qb.DocType(self.doctype)
|
|
self.field = self.table[self.fieldname]
|
|
|
|
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
|
|
table = frappe.qb.DocType(self.doctype)
|
|
query = self.apply_join(query)
|
|
return query.select(getattr(table, self.fieldname).as_(self.alias or None))
|
|
|
|
def apply_join(self, query: QueryBuilder) -> QueryBuilder:
|
|
table = frappe.qb.DocType(self.doctype)
|
|
main_table = frappe.qb.DocType(self.parent_doctype)
|
|
if not query.is_joined(table):
|
|
query = query.left_join(table).on(
|
|
(table.parent == main_table.name) & (table.parenttype == self.parent_doctype)
|
|
)
|
|
return query
|
|
|
|
|
|
class LinkTableField(DynamicTableField):
|
|
def __init__(
|
|
self,
|
|
doctype: str,
|
|
fieldname: str,
|
|
parent_doctype: str,
|
|
link_fieldname: str,
|
|
alias: str | None = None,
|
|
) -> None:
|
|
super().__init__(doctype, fieldname, parent_doctype, alias=alias)
|
|
self.link_fieldname = link_fieldname
|
|
self.table = frappe.qb.DocType(self.doctype)
|
|
self.field = self.table[self.fieldname]
|
|
|
|
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
|
|
table = frappe.qb.DocType(self.doctype)
|
|
query = self.apply_join(query)
|
|
return query.select(getattr(table, self.fieldname).as_(self.alias or None))
|
|
|
|
def apply_join(self, query: QueryBuilder) -> QueryBuilder:
|
|
table = frappe.qb.DocType(self.doctype)
|
|
main_table = frappe.qb.DocType(self.parent_doctype)
|
|
if not query.is_joined(table):
|
|
query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname))
|
|
return query
|
|
|
|
|
|
def literal_eval_(literal):
|
|
try:
|
|
return literal_eval(literal)
|
|
except (ValueError, SyntaxError):
|
|
return literal
|
|
|
|
|
|
def has_function(field):
|
|
_field = field.casefold() if (isinstance(field, str) and "`" not in field) else field
|
|
if not issubclass(type(_field), Criterion):
|
|
if any([f"{func}(" in _field for func in SQL_FUNCTIONS]):
|
|
return True
|
|
|
|
|
|
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str):
|
|
table = frappe.qb.DocType(doctype)
|
|
try:
|
|
lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
|
|
except IndexError:
|
|
lft, rgt = None, None
|
|
|
|
if hierarchy in ("descendants of", "not descendants of"):
|
|
result = (
|
|
frappe.qb.from_(table)
|
|
.select(table.name)
|
|
.where(table.lft > lft)
|
|
.where(table.rgt < rgt)
|
|
.orderby(table.lft, order=Order.asc)
|
|
.run()
|
|
)
|
|
else:
|
|
# Get ancestor elements of a DocType with a tree structure
|
|
result = (
|
|
frappe.qb.from_(table)
|
|
.select(table.name)
|
|
.where(table.lft < lft)
|
|
.where(table.rgt > rgt)
|
|
.orderby(table.lft, order=Order.desc)
|
|
.run()
|
|
)
|
|
return result
|