import itertools
import re
from ast import literal_eval
from types import BuiltinFunctionType
from typing import TYPE_CHECKING

import sqlparse
from pypika.queries import QueryBuilder, Table

import frappe
from frappe import _
from frappe.database.operator_map import OPERATOR_MAP
from frappe.database.utils import get_doctype_name
from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumnMapper
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT

if TYPE_CHECKING:
	from frappe.query_builder import DocType

TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
# Matches a quoted child-table column such as `tabChild`.`field` (or "tabChild"."field"),
# capturing the quote character, the doctype and the fieldname.
CHILD_TABLE_FIELD_PATTERN = re.compile(r'([`"])tab(.+?)\1\.\1(.+)\1')


class Engine:
	"""Translate frappe-style ``fields`` / ``filters`` / ``order_by`` arguments into a PyPika query."""

	def get_query(
		self,
		table: str | Table,
		fields: list | tuple | None = None,
		filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
		order_by: str | None = None,
		group_by: str | None = None,
		limit: int | None = None,
		offset: int | None = None,
		distinct: bool = False,
		for_update: bool = False,
		update: bool = False,
		into: bool = False,
		delete: bool = False,
	) -> QueryBuilder:
		"""Build a query against ``table``.

		:param table: DocType name or an existing PyPika ``Table``.
		:param fields: columns/expressions to select; ``name`` is selected when none are given.
		:param filters: conditions, in any shape accepted by :meth:`apply_filters`.
		:param order_by: comma-separated ``"<field> [asc|desc]"`` declarations.
		:param group_by: passed to ``QueryBuilder.groupby``.
		:param limit: row limit; ignored when falsy.
		:param offset: row offset; ignored when falsy.
		:param distinct: apply ``QueryBuilder.distinct()``.
		:param for_update: apply ``QueryBuilder.for_update()``.
		:param update: start from ``frappe.qb.update`` instead of a SELECT.
		:param into: start from ``frappe.qb.into`` instead of a SELECT.
		:param delete: start from a DELETE query instead of a SELECT.
		:return: the built query, which is also stored on ``self.query``.
		"""
		self.is_mariadb = frappe.db.db_type == "mariadb"
		self.is_postgres = frappe.db.db_type == "postgres"

		if isinstance(table, Table):
			self.table = table
			self.doctype = get_doctype_name(table.get_sql())
		else:
			self.doctype = table
			self.table = frappe.qb.DocType(table)

		# Only the first truthy flag among update / into / delete takes effect.
		if update:
			self.query = frappe.qb.update(self.table)
		elif into:
			self.query = frappe.qb.into(self.table)
		elif delete:
			self.query = frappe.qb.from_(self.table).delete()
		else:
			self.query = frappe.qb.from_(self.table)

		self.apply_fields(fields)
		self.apply_filters(filters)
		self.apply_order_by(order_by)

		if limit:
			self.query = self.query.limit(limit)
		if offset:
			self.query = self.query.offset(offset)
		if distinct:
			self.query = self.query.distinct()
		if for_update:
			self.query = self.query.for_update()
		if group_by:
			self.query = self.query.groupby(group_by)

		return self.query

	def apply_fields(self, fields):
		"""Add the parsed ``fields`` to the SELECT list, falling back to the ``name`` column."""
		self.fields = self.parse_fields(fields) or [self.table.name]

		for field in self.fields:
			if isinstance(field, DynamicTableField):
				# Link / child-table columns also need their JOIN, which apply_select adds.
				self.query = field.apply_select(self.query)
			else:
				self.query = self.query.select(field)

	def apply_filters(
		self,
		filters: dict[str, str | int] | str | int | list[list | str | int] | None = None,
	):
		"""Add WHERE conditions for ``filters``.

		Accepted shapes:

		- ``None``: no-op.
		- a ``str`` / ``int``: treated as ``{"name": str(filters)}``.
		- a ``Criterion``: applied as-is.
		- a ``dict``: ``{field: value}`` or ``{field: (operator, value)}``.
		- a non-empty list/tuple containing only ``str`` / ``int``: ``name in (...)``.
		- any other list/tuple: scalar, ``Criterion`` and ``dict`` items are applied
		  recursively; list/tuple items go through :meth:`apply_list_filters`.
		"""
		if filters is None:
			return

		if isinstance(filters, (str, int)):
			filters = {"name": str(filters)}

		if isinstance(filters, Criterion):
			self.query = self.query.where(filters)

		elif isinstance(filters, dict):
			self.apply_dict_filters(filters)

		elif isinstance(filters, (list, tuple)):
			if filters and all(isinstance(d, (str, int)) for d in filters):
				self.apply_dict_filters({"name": ("in", filters)})
			else:
				for filter_ in filters:
					if isinstance(filter_, (str, int, Criterion, dict)):
						self.apply_filters(filter_)
					elif isinstance(filter_, (list, tuple)):
						self.apply_list_filters(filter_)

	def apply_list_filters(self, filter: list):
		"""Apply one list-style filter.

		Supported forms are ``[field, value]``, ``[field, operator, value]`` and
		``[doctype, field, operator, value]``. Lists of any other length are ignored.
		"""
		if len(filter) == 2:
			field, value = filter
			self._apply_filter(field, value)
		elif len(filter) == 3:
			field, operator, value = filter
			self._apply_filter(field, value, operator)
		elif len(filter) == 4:
			doctype, field, operator, value = filter
			self._apply_filter(field, value, operator, doctype)

	def apply_dict_filters(self, filters: dict[str, str | int | list]):
		"""Apply ``{field: value}`` / ``{field: (operator, value)}`` filters."""
		for key, value in filters.items():
			self._apply_filter(key, value)

	def _apply_filter(
		self,
		field: str,
		value: str | int | list | None,
		operator: str | None = None,
		doctype: str | None = None,
	):
		"""Add a single ``field <operator> value`` condition to the WHERE clause.

		:param field: a fieldname, a ``Field``, a ``link_field.fieldname`` or quoted
			child-table column reference, or an SQL function string.
		:param value: the value to compare with. When ``operator`` is not given, a
			two-item ``(operator, value)`` list/tuple is unpacked from it.
		:param operator: a key of ``OPERATOR_MAP``; ``"="`` when omitted and ``value``
			is not an ``(operator, value)`` pair.
		:param doctype: the doctype owning ``field`` when it differs from the main
			doctype; child tables referenced this way are LEFT JOINed on
			``parent`` / ``parenttype``.
		"""
		_field = field
		_value = value
		_operator = operator

		if isinstance(_field, Field):
			pass
		elif dynamic_field := DynamicTableField.parse(field, self.doctype):
			# apply implicit join if link field's field is referenced
			self.query = dynamic_field.apply_join(self.query)
			_field = dynamic_field.field
		elif has_function(field):
			_field = self.get_function_object(field)
		elif not doctype or doctype == self.doctype:
			_field = self.table[field]
		else:
			_field = frappe.qb.DocType(doctype)[field]

		# apply implicit join if child table is referenced
		if doctype and doctype != self.doctype:
			meta = frappe.get_meta(doctype)
			table = frappe.qb.DocType(doctype)
			if meta.istable and not self.query.is_joined(table):
				self.query = self.query.left_join(table).on(
					(table.parent == self.table.name) & (table.parenttype == self.doctype)
				)

		if _operator is None:
			# Only unpack ``(operator, value)`` when no operator was passed explicitly;
			# otherwise ``[field, "in", ["a", "b"]]`` would be misread as operator "a".
			if isinstance(_value, (list, tuple)):
				_operator, _value = _value
			else:
				_operator = "="

		if isinstance(_value, bool):
			_value = int(_value)

		if isinstance(_value, str) and has_function(_value):
			_value = self.get_function_object(_value)

		if isinstance(_value, (list, tuple)) and not _value:
			# Presumably avoids rendering an empty ``IN ()`` list.
			_value = ("",)

		# Nested set
		if _operator in OPERATOR_MAP["nested_set"]:
			hierarchy = _operator
			docname = _value
			result = get_nested_set_hierarchy_result(self.doctype, docname, hierarchy)

			operator_fn = (
				OPERATOR_MAP["not in"]
				if hierarchy in ("not ancestors of", "not descendants of")
				else OPERATOR_MAP["in"]
			)
			if result:
				result = list(itertools.chain.from_iterable(result))
				self.query = self.query.where(operator_fn(_field, result))
			else:
				self.query = self.query.where(operator_fn(_field, ("",)))
			return

		operator_fn = OPERATOR_MAP[_operator.casefold()]
		if _value is None and isinstance(_field, Field):
			# Comparisons with NULL need IS [NOT] NULL rather than = / !=.
			if _operator == "!=":
				self.query = self.query.where(_field.isnotnull())
			else:
				self.query = self.query.where(_field.isnull())
		else:
			self.query = self.query.where(operator_fn(_field, _value))

	def get_function_object(self, field: str) -> "Function":
		"""Build a PyPika function term from a string such as ``SUM(*)`` or ``COUNT(name) as total``.

		- String arguments become ``Field`` objects, or ``PseudoColumnMapper`` when they
		  contain a backtick.
		- String arguments containing a key of ``OPERATOR_MAP`` whose mapping is a builtin
		  function are split on that key and combined with it.
		- ``*`` arguments are passed through unchanged.
		- Arguments that ``literal_eval`` turns into a non-string literal (e.g. the ``0``
		  in ``ifnull(x, 0)``) are also passed through unchanged.

		Functions not found in ``frappe.query_builder.functions`` fall back to a generic
		``Function``.

		NOTE(review): only the text up to the first ``)`` is read as arguments, so nested
		calls are not supported.
		"""

		def _to_column(text: str):
			# Backquoted text is kept as a raw SQL expression.
			return Field(text) if "`" not in text else PseudoColumnMapper(text)

		func = field.split("(", maxsplit=1)[0].capitalize()
		args_start, args_end = len(func) + 1, field.index(")")
		args = field[args_start:args_end].split(",")
		_, alias = field.split(" as ") if " as " in field else (None, None)
		to_cast = "*" not in args

		_args = []
		for arg in args:
			literal = literal_eval_(arg.strip())
			if not to_cast or not isinstance(literal, str):
				# "*" arguments and non-string literals (numbers, None, ...) are passed as-is.
				_args.append(literal)
				continue

			term = None
			for operator_key, operator_mapping in OPERATOR_MAP.items():
				# Only perform this if operator is of primitive type.
				if operator_key in literal and isinstance(operator_mapping, BuiltinFunctionType):
					term = operator_mapping(*(_to_column(part.strip()) for part in arg.split(operator_key)))
			_args.append(term if term is not None else _to_column(literal))

		if alias and "`" in alias:
			alias = alias.replace("`", "")
		try:
			if func.casefold() == "now":
				return getattr(functions, func)()
			return getattr(functions, func)(*_args, alias=alias or None)
		except AttributeError:
			# Fall back for functions not present in `SqlFunctions`
			return Function(func, *_args, alias=alias or None)

	def sanitize_fields(self, fields: str | list | tuple):
		"""Strip SQL comments (and MariaDB-specific comments on MariaDB) from string fields.

		``sqlparse.format`` also lower-cases SQL keywords. Non-string fields and
		non-string/list/tuple inputs are returned unchanged.
		"""

		def _sanitize_field(field: str):
			if not isinstance(field, str):
				return field
			stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
			if self.is_mariadb:
				return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
			return stripped_field

		if isinstance(fields, (list, tuple)):
			return [_sanitize_field(field) for field in fields]
		elif isinstance(fields, str):
			return _sanitize_field(fields)

		return fields

	def parse_string_field(self, field: str):
		"""Turn a plain field string into a PyPika term.

		- ``*`` selects every column of the main table.
		- Backquoted text is wrapped verbatim (plus alias) in ``PseudoColumnMapper``.
		- Anything else is a column of the main table, aliased when `` as `` is present.
		"""
		if field == "*":
			return self.table.star

		alias = None
		if " as " in field:
			field, alias = field.split(" as ")

		if "`" in field:
			if alias:
				return PseudoColumnMapper(f"{field} {alias}")
			return PseudoColumnMapper(field)

		if alias:
			return self.table[field].as_(alias)
		return self.table[field]

	def parse_fields(self, fields: str | list | tuple | None) -> list:
		"""Turn ``fields`` into a list of PyPika terms and ``DynamicTableField`` objects.

		Comma-separated strings are split on top-level commas; such strings are
		case-folded unless they contain a backtick.
		"""
		if not fields:
			return []

		fields = self.sanitize_fields(fields)

		# NOTE(review): a list containing None (and not the Field class itself) selects
		# nothing here; confirm this is intended.
		if isinstance(fields, (list, tuple, set)) and None in fields and Field not in fields:
			return []

		if not isinstance(fields, (list, tuple)):
			fields = [fields]

		def parse_field(field: str):
			if has_function(field):
				return self.get_function_object(field)
			elif parsed := DynamicTableField.parse(field, self.doctype):
				return parsed
			else:
				return self.parse_string_field(field)

		_fields = []
		for field in fields:
			if isinstance(field, Criterion):
				_fields.append(field)
			elif isinstance(field, str):
				if "," in field:
					field = field.casefold() if "`" not in field else field
					for part in COMMA_PATTERN.split(field):
						if _part := part.strip():
							_fields.append(parse_field(_part))
				else:
					_fields.append(parse_field(field))

		return _fields

	def apply_order_by(self, order_by: str | None):
		"""Apply comma-separated ``"<field> [asc|desc]"`` declarations.

		The direction is descending unless the second word is ``asc`` (case-insensitive).
		Empty input or ``"KEEP_DEFAULT_ORDERING"`` adds no ORDER BY.
		"""
		if not order_by or order_by == "KEEP_DEFAULT_ORDERING":
			return

		for declaration in order_by.split(","):
			# split() without a separator tolerates repeated whitespace ("name  asc").
			parts = declaration.split()
			if not parts:
				continue
			order_field = parts[0]
			order_direction = Order.asc if len(parts) > 1 and parts[1].lower() == "asc" else Order.desc
			self.query = self.query.orderby(order_field, order=order_direction)


class Permission:
	"""Permission checks for raw or built queries."""

	@classmethod
	def check_permissions(cls, query, **kwargs):
		"""Throw unless every ``tab*`` table in ``query`` passes a "select" or "read" permission check.

		:param query: SQL string, or an object with ``get_sql()``.
		:param kwargs: ``user`` and ``parent_doctype`` are forwarded to ``frappe.has_permission``.
		"""
		if not isinstance(query, str):
			query = query.get_sql()

		doctypes = cls.get_tables_from_query(query)
		if isinstance(doctypes, str):
			doctypes = [doctypes]

		user = kwargs.get("user")
		parent_doctype = kwargs.get("parent_doctype")

		# A table may appear several times (joins, sub-queries); check each one once, in order.
		for dt in dict.fromkeys(TAB_PATTERN.sub("", table) for table in doctypes):
			if not frappe.has_permission(
				dt,
				"select",
				user=user,
				parent_doctype=parent_doctype,
			) and not frappe.has_permission(
				dt,
				"read",
				user=user,
				parent_doctype=parent_doctype,
			):
				frappe.throw(_("Insufficient Permission for {0}").format(frappe.bold(dt)))

	@staticmethod
	def get_tables_from_query(query: str):
		"""Return every word in ``query`` that starts with ``tab``.

		NOTE(review): ``\\w+`` also matches unrelated words such as ``table_name``, and
		truncates doctype names containing spaces at the first space.
		"""
		return [table for table in WORDS_PATTERN.findall(query) if table.startswith("tab")]


class DynamicTableField:
	"""A column on another table: a child table, or a table reached through a Link field."""

	def __init__(
		self,
		doctype: str,
		fieldname: str,
		parent_doctype: str,
		alias: str | None = None,
	) -> None:
		self.doctype = doctype  # doctype that owns the column
		self.fieldname = fieldname
		self.alias = alias
		self.parent_doctype = parent_doctype  # doctype of the main query table

	def __str__(self) -> str:
		table_name = f"`tab{self.doctype}`"
		fieldname = f"`{self.fieldname}`"
		if frappe.db.db_type == "postgres":
			table_name = table_name.replace("`", '"')
			fieldname = fieldname.replace("`", '"')
		alias = f"AS {self.alias}" if self.alias else ""
		return f"{table_name}.{fieldname} {alias}".strip()

	@staticmethod
	def parse(field: str, doctype: str):
		"""Return a ``ChildTableField`` / ``LinkTableField`` for ``field``, or ``None``.

		Recognised forms, optionally followed by `` as <alias>``:

		- a quoted child-table column such as ```tabChild`.`field```. This gives ``None``
		  when the table is ``doctype`` itself or the reference does not parse.
		- ``<link_or_table_field>.<fieldname>``, where the prefix is a Link or table field
		  of ``doctype``. This gives ``None`` when ``get_field`` finds nothing or the field
		  has another type.
		"""
		if "." not in field:
			return None

		alias = None
		if " as " in field:
			field, alias = field.split(" as ")

		if field.startswith(("`tab", '"tab')):
			match = CHILD_TABLE_FIELD_PATTERN.search(field)
			if not match:
				return None
			_, child_doctype, child_field = match.groups()
			if child_doctype == doctype:
				return None
			return ChildTableField(child_doctype, child_field, doctype, alias=alias)

		linked_fieldname, fieldname = field.split(".")
		linked_field = frappe.get_meta(doctype).get_field(linked_fieldname)
		if not linked_field:
			# e.g. "sum(a.b)" or an unknown prefix: not a dynamic table reference.
			return None
		linked_doctype = linked_field.options
		if linked_field.fieldtype == "Link":
			return LinkTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias)
		if linked_field.fieldtype in frappe.model.table_fields:
			return ChildTableField(linked_doctype, fieldname, doctype, alias=alias)
		return None

	def apply_select(self, query: QueryBuilder) -> QueryBuilder:
		raise NotImplementedError


class ChildTableField(DynamicTableField):
	"""A child-table column, joined to the parent on ``parent`` / ``parenttype``."""

	def __init__(
		self,
		doctype: str,
		fieldname: str,
		parent_doctype: str,
		alias: str | None = None,
	) -> None:
		super().__init__(doctype, fieldname, parent_doctype, alias=alias)
		self.table = frappe.qb.DocType(self.doctype)
		self.field = self.table[self.fieldname]

	def apply_select(self, query: QueryBuilder) -> QueryBuilder:
		"""Join the child table if needed and select the column (aliased when ``alias`` is set)."""
		table = frappe.qb.DocType(self.doctype)
		query = self.apply_join(query)
		return query.select(getattr(table, self.fieldname).as_(self.alias or None))

	def apply_join(self, query: QueryBuilder) -> QueryBuilder:
		"""LEFT JOIN the child table on ``parent = <parent>.name AND parenttype = parent_doctype`` once."""
		table = frappe.qb.DocType(self.doctype)
		main_table = frappe.qb.DocType(self.parent_doctype)
		if not query.is_joined(table):
			query = query.left_join(table).on(
				(table.parent == main_table.name) & (table.parenttype == self.parent_doctype)
			)
		return query


class LinkTableField(DynamicTableField):
	"""A column of the doctype referenced by a Link field of the parent doctype."""

	def __init__(
		self,
		doctype: str,
		fieldname: str,
		parent_doctype: str,
		link_fieldname: str,
		alias: str | None = None,
	) -> None:
		super().__init__(doctype, fieldname, parent_doctype, alias=alias)
		self.link_fieldname = link_fieldname  # Link field on the parent doctype
		self.table = frappe.qb.DocType(self.doctype)
		self.field = self.table[self.fieldname]

	def apply_select(self, query: QueryBuilder) -> QueryBuilder:
		"""Join the linked table if needed and select the column (aliased when ``alias`` is set)."""
		table = frappe.qb.DocType(self.doctype)
		query = self.apply_join(query)
		return query.select(getattr(table, self.fieldname).as_(self.alias or None))

	def apply_join(self, query: QueryBuilder) -> QueryBuilder:
		"""LEFT JOIN the linked table on ``<linked>.name = <parent>.<link_fieldname>`` once."""
		table = frappe.qb.DocType(self.doctype)
		main_table = frappe.qb.DocType(self.parent_doctype)
		if not query.is_joined(table):
			query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname))
		return query


def literal_eval_(literal):
	"""Return ``ast.literal_eval(literal)``, or ``literal`` unchanged if it is not a valid literal."""
	try:
		return literal_eval(literal)
	except (ValueError, SyntaxError):
		return literal


def has_function(field) -> bool:
	"""Return True if the string ``field`` contains a call to one of ``SQL_FUNCTIONS``.

	Strings without backticks are case-folded before matching. Criteria and other
	non-string values return False.
	"""
	if isinstance(field, Criterion) or not isinstance(field, str):
		return False
	_field = field.casefold() if "`" not in field else field
	return any(f"{func}(" in _field for func in SQL_FUNCTIONS)


def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str):
	"""Return rows holding the ``name`` of the ancestors or descendants of ``name``.

	"descendants of" / "not descendants of" select rows strictly inside the
	``lft``/``rgt`` interval of ``name``, ordered by ``lft`` ascending. Any other
	``hierarchy`` selects ancestors, ordered by ``lft`` descending. An empty tuple is
	returned, without querying further, when ``name`` is not found or lacks
	``lft``/``rgt`` values.
	"""
	table = frappe.qb.DocType(doctype)
	try:
		lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
	except IndexError:
		lft, rgt = None, None

	if lft is None or rgt is None:
		return ()

	if hierarchy in ("descendants of", "not descendants of"):
		result = (
			frappe.qb.from_(table)
			.select(table.name)
			.where(table.lft > lft)
			.where(table.rgt < rgt)
			.orderby(table.lft, order=Order.asc)
			.run()
		)
	else:
		# Get ancestor elements of a DocType with a tree structure
		result = (
			frappe.qb.from_(table)
			.select(table.name)
			.where(table.lft < lft)
			.where(table.rgt > rgt)
			.orderby(table.lft, order=Order.desc)
			.run()
		)
	return result