seitime-frappe/frappe/database/query.py

1253 lines
44 KiB
Python

import operator
import re
from ast import literal_eval
from functools import lru_cache
from types import BuiltinFunctionType
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeAlias, Union
import sqlparse
from pypika.queries import QueryBuilder, Table
from pypika.terms import AggregateFunction, Term
import frappe
from frappe import _
from frappe.database.operator_map import NESTED_SET_OPERATORS, OPERATOR_MAP
from frappe.database.schema import SPECIAL_CHAR_PATTERN
from frappe.database.utils import DefaultOrderBy, FilterValue, convert_to_value, get_doctype_name
from frappe.model import get_permitted_fields
from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumnMapper
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT
if TYPE_CHECKING:
from frappe.query_builder import DocType
TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
# less restrictive version of frappe.core.doctype.doctype.doctype.START_WITH_LETTERS_PATTERN
# to allow table names like __Auth
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
# Pattern to validate field names in SELECT:
# Allows: name, `name`, name as alias, `name` as alias, `table name`.`name`, `table name`.`name` as alias, table.name, table.name as alias
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:`?[\w\s-]+`?\.)?(`?\w+`?|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
# Pattern to validate basic SQL function call syntax: word(...) [as alias]
FUNCTION_CALL_PATTERN = re.compile(r"^\w+\(.*\)(?:\s+as\s+\w+)?$", flags=re.IGNORECASE | re.ASCII)
# Pattern to validate field names used in various SQL clauses (WHERE, GROUP BY, ORDER BY):
# Allows simple field names, backticked names, and table-qualified names (e.g., name, `name`, `table`.`name`, table.name)
# Does NOT allow aliases ('as alias') or functions.
ALLOWED_SQL_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)$", flags=re.ASCII)
# Pattern to validate characters allowed within function arguments that are not simple fields/literals.
# Allows alphanumeric, underscore, whitespace, +, -, *, /, ., (, ), quotes, and the keyword 'distinct'.
# Disallows characters like ; = < > etc. to prevent injection.
ALLOWED_ARGUMENT_CHARS_PATTERN = re.compile(
r"^(?:[\w\s\+\-\*\/\.\(\)\`\'\"]+|\bDISTINCT\b)+$", flags=re.IGNORECASE | re.ASCII
)
# Regex to parse field names:
# Group 1: Optional quote for table name
# Group 2: Optional table name (e.g., `tabDocType` or tabDocType or `tabNote Seen By`)
# Group 3: Optional quote for field name
# Group 4: Field name (e.g., `field` or field)
FIELD_PARSE_REGEX = re.compile(r"^(?:([`\"]?)(tab[\w\s-]+)\1\.)?([`\"]?)(\w+)\3$")
# Regex to capture: FunctionName(Arguments) [AS Alias]
# Group 1: Function Name (e.g., COUNT, SUM)
# Group 2: Arguments string (e.g., *, field1, 'literal', field2)
# Group 3: Optional Alias (e.g., average_price or `average_price`) - allows backticks
SQL_FUNCTION_PATTERN = re.compile(
r"^([a-zA-Z_]\w*)\s*\((.*?)\)(?:\s+as\s+(`?[\w\s-]+`?|\w+))?$", flags=re.IGNORECASE | re.ASCII
)
# Regex to split arguments, respecting potential quotes or nested parentheses
ARGS_SPLIT_PATTERN = re.compile(r",\s*(?![^()]*\))")
class Engine:
def get_query(
self,
table: str | Table,
fields: str | list | tuple | None = None,
filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None,
order_by: str | None = None,
group_by: str | None = None,
limit: int | None = None,
offset: int | None = None,
distinct: bool = False,
for_update: bool = False,
update: bool = False,
into: bool = False,
delete: bool = False,
*,
validate_filters: bool = False,
skip_locked: bool = False,
wait: bool = True,
ignore_permissions: bool = True,
user: str | None = None,
parent_doctype: str | None = None,
) -> QueryBuilder:
qb = frappe.local.qb
db_type = frappe.local.db.db_type
self.is_mariadb = db_type == "mariadb"
self.is_postgres = db_type == "postgres"
self.is_sqlite = db_type == "sqlite"
self.validate_filters = validate_filters
self.user = user or frappe.session.user
self.parent_doctype = parent_doctype
self.apply_permissions = not ignore_permissions
if isinstance(table, Table):
self.table = table
self.doctype = get_doctype_name(table.get_sql())
else:
self.doctype = table
self.validate_doctype()
self.table = qb.DocType(table)
if self.apply_permissions:
self.check_read_permission()
if update:
self.query = qb.update(self.table, immutable=False)
elif into:
self.query = qb.into(self.table, immutable=False)
elif delete:
self.query = qb.from_(self.table, immutable=False).delete()
else:
self.query = qb.from_(self.table, immutable=False)
self.apply_fields(fields)
self.apply_filters(filters)
self.apply_order_by(order_by)
if limit:
if not isinstance(limit, int) or limit < 0:
frappe.throw(_("Limit must be a non-negative integer"), TypeError)
self.query = self.query.limit(limit)
if offset:
if not isinstance(offset, int) or offset < 0:
frappe.throw(_("Offset must be a non-negative integer"), TypeError)
self.query = self.query.offset(offset)
if distinct:
self.query = self.query.distinct()
if for_update:
self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait)
if group_by:
self._validate_group_by(group_by)
self.query = self.query.groupby(group_by)
if self.apply_permissions:
self.add_permission_conditions()
self.query.immutable = True
return self.query
def validate_doctype(self):
if not TABLE_NAME_PATTERN.match(self.doctype):
frappe.throw(_("Invalid DocType: {0}").format(self.doctype))
def apply_fields(self, fields):
self.fields = self.parse_fields(fields)
if self.apply_permissions:
self.fields = self.apply_field_permissions()
if not self.fields:
self.fields = [self.table.name]
self.query._child_queries = []
for field in self.fields:
if isinstance(field, DynamicTableField):
self.query = field.apply_select(self.query)
elif isinstance(field, ChildQuery):
self.query._child_queries.append(field)
else:
self.query = self.query.select(field)
def apply_filters(
self,
filters: dict[str, FilterValue] | FilterValue | list[list | FilterValue] | None = None,
):
if filters is None:
return
if isinstance(filters, FilterValue):
filters = {"name": convert_to_value(filters)}
if isinstance(filters, Criterion):
self.query = self.query.where(filters)
elif isinstance(filters, dict):
self.apply_dict_filters(filters)
elif isinstance(filters, list | tuple):
if all(isinstance(d, FilterValue) for d in filters) and len(filters) > 0:
self.apply_dict_filters({"name": ("in", tuple(convert_to_value(f) for f in filters))})
else:
for filter in filters:
if isinstance(filter, FilterValue | Criterion | dict):
self.apply_filters(filter)
elif isinstance(filter, list | tuple):
self.apply_list_filters(filter)
else:
raise ValueError(f"Unknown filter type: {type(filters)}")
else:
raise ValueError(f"Unknown filter type: {type(filters)}")
def apply_list_filters(self, filter: list):
if len(filter) == 2:
field, value = filter
self._apply_filter(field, value)
elif len(filter) == 3:
field, operator, value = filter
self._apply_filter(field, value, operator)
elif len(filter) == 4:
doctype, field, operator, value = filter
self._apply_filter(field, value, operator, doctype)
else:
raise ValueError(f"Unknown filter format: {filter}")
def apply_dict_filters(self, filters: dict[str, FilterValue | list]):
for field, value in filters.items():
operator = "="
if isinstance(value, list | tuple):
operator, value = value
self._apply_filter(field, value, operator)
def _apply_filter(
self,
field: str | Field,
value: FilterValue | list | set | None,
operator: str = "=",
doctype: str | None = None,
):
_field = self._validate_and_prepare_filter_field(field, doctype)
_value = value
_operator = operator
# Apply implicit join if child table is referenced
if doctype and doctype != self.doctype:
meta = frappe.get_meta(doctype)
table = frappe.qb.DocType(doctype)
if meta.istable and not self.query.is_joined(table):
self.query = self.query.left_join(table).on(
(table.parent == self.table.name) & (table.parenttype == self.doctype)
)
_value = convert_to_value(_value)
if not _value and isinstance(_value, list | tuple | set):
_value = ("",)
# Handle nested set operators
if _operator in NESTED_SET_OPERATORS:
hierarchy = _operator
docname = _value
# Use the original field name string for get_field if _field was converted
# If _field is from a dynamic field, its name might be just the target fieldname.
# We need the original string ('link.target') or the fieldname from the main doctype.
original_field_name = field if isinstance(field, str) else _field.name
# Check if the original field name exists in the *main* doctype meta
main_meta = frappe.get_meta(self.doctype)
if main_meta.has_field(original_field_name):
_df = main_meta.get_field(original_field_name)
ref_doctype = _df.options if _df else self.doctype
else:
# If not in main doctype, assume it's a standard field like 'name' or refers to the main doctype itself
# This part might need refinement if nested set operators are used with dynamic fields.
ref_doctype = self.doctype
nodes = get_nested_set_hierarchy_result(ref_doctype, docname, hierarchy)
operator_fn = (
OPERATOR_MAP["not in"]
if hierarchy in ("not ancestors of", "not descendants of")
else OPERATOR_MAP["in"]
)
self.query = self.query.where(operator_fn(_field, nodes or ("",)))
return
operator_fn = OPERATOR_MAP[_operator.casefold()]
if _value is None and isinstance(_field, Field):
self.query = self.query.where(_field.isnull())
else:
self.query = self.query.where(operator_fn(_field, _value))
def _validate_and_prepare_filter_field(self, field: str | Field, doctype: str | None = None) -> Field:
"""Validate field name for filters and return a pypika Field object. Handles dynamic fields."""
if isinstance(field, Term):
# return if field is already a pypika Term
return field
# Reject backticks
if "`" in field:
frappe.throw(
_("Filter fields cannot contain backticks (`)."),
frappe.ValidationError,
title=_("Invalid Filter"),
)
# Handle dot notation (link_field.target_field or child_table_field.target_field)
if "." in field:
# Disallow tabDoc.field notation in filters.
dynamic_field = DynamicTableField.parse(field, self.doctype, allow_tab_notation=False)
if dynamic_field:
# Parsed successfully as link/child field access
target_doctype = dynamic_field.doctype
target_fieldname = dynamic_field.fieldname
parent_doctype_for_perm = (
dynamic_field.parent_doctype if isinstance(dynamic_field, ChildTableField) else None
)
self._check_filter_field_permission(target_doctype, target_fieldname, parent_doctype_for_perm)
self.query = dynamic_field.apply_join(self.query)
# Return the pypika Field object associated with the dynamic field
return dynamic_field.field
else:
# Contains '.' but is not a valid link/child field access pattern
# This rejects tabDoc.field and other invalid formats like a.b.c
frappe.throw(
_(
"Invalid filter field format: {0}. Use 'fieldname' or 'link_fieldname.target_fieldname'."
).format(field),
frappe.ValidationError,
title=_("Invalid Filter"),
)
else:
# No '.' and no '`'. Check if it's a simple field name (alphanumeric + underscore).
if not re.fullmatch(r"\w+", field):
frappe.throw(
_(
"Invalid characters in fieldname: {0}. Only letters, numbers, and underscores are allowed."
).format(field),
frappe.ValidationError,
title=_("Invalid Filter"),
)
# It's a simple, valid fieldname like 'name' or 'creation'
target_doctype = doctype or self.doctype
target_fieldname = field
parent_doctype_for_perm = self.parent_doctype if doctype else None
self._check_filter_field_permission(target_doctype, target_fieldname, parent_doctype_for_perm)
# Convert string field name to pypika Field object for the specified/current doctype
return frappe.qb.DocType(target_doctype)[target_fieldname]
def _check_filter_field_permission(self, doctype: str, fieldname: str, parent_doctype: str | None = None):
"""Check if the user has permission to access the given field for filtering."""
if not self.apply_permissions:
return
permission_type = self.get_permission_type(doctype)
permitted_fields = get_permitted_fields(
doctype=doctype,
parenttype=parent_doctype,
permission_type=permission_type,
ignore_virtual=True,
user=self.user,
)
if fieldname not in permitted_fields:
frappe.throw(
_("You do not have permission to filter on field: {0}").format(
frappe.bold(f"{doctype}.{fieldname}")
),
frappe.PermissionError,
title=_("Permission Error"),
)
def parse_string_field(self, field: str):
"""
Parses a field string into a pypika Field object.
Handles:
- *
- simple_field
- `quoted_field`
- tabDocType.simple_field
- `tabDocType`.`quoted_field`
- `tabTable Name`.`quoted_field`
- Aliases for all above formats (e.g., field as alias)
"""
if field == "*":
return self.table.star
alias = None
field_part = field
if " as " in field.lower(): # Case-insensitive check for ' as '
# Find the last occurrence of ' as ' to handle potential aliases named 'as'
parts = re.split(r"\s+as\s+", field, flags=re.IGNORECASE)
if len(parts) > 1:
field_part = parts[0].strip()
alias = parts[1].strip().strip('`"') # Remove potential quotes from alias
match = FIELD_PARSE_REGEX.match(field_part)
if not match:
frappe.throw(_("Could not parse field: {0}").format(field))
# Groups: 1: table_quote, 2: table_name_with_tab, 3: field_quote, 4: field_name
groups = match.groups()
table_name = groups[1] # This will be None if no table part (e.g., just 'field')
field_name = groups[3] # This will be the field name (e.g., 'field')
if table_name:
# Table name specified (e.g., `tabX`.`y` or tabX.y or `tabX Y`.`y`)
# Ensure the extracted table name is valid before creating DocType object
if not TABLE_NAME_PATTERN.match(table_name.lstrip("tab")):
frappe.throw(_("Invalid characters in table name: {0}").format(table_name))
table_obj = frappe.qb.DocType(table_name)
pypika_field = table_obj[field_name]
else:
# Simple field name (e.g., `y` or y) - use the main table
pypika_field = self.table[field_name]
if alias:
return pypika_field.as_(alias)
else:
return pypika_field
def parse_fields(
self, fields: str | list | tuple | Field | AggregateFunction | None
) -> "list[Field | AggregateFunction | Criterion | DynamicTableField | ChildQuery]":
if not fields:
return []
# return if fields is already a pypika Term
if isinstance(fields, Term):
return [fields]
initial_field_list = []
if isinstance(fields, str):
# Split comma-separated fields passed as a single string
initial_field_list.extend(f.strip() for f in COMMA_PATTERN.split(fields) if f.strip())
elif isinstance(fields, list | tuple):
for item in fields:
if isinstance(item, str) and "," in item:
# Split comma-separated strings within the list
initial_field_list.extend(f.strip() for f in COMMA_PATTERN.split(item) if f.strip())
else:
# Add non-comma-separated items directly
initial_field_list.append(item)
else:
frappe.throw(_("Fields must be a string, list, tuple, pypika Field, or pypika Function"))
_fields = []
# Iterate through the list where each item could be a single field, criterion, or a comma-separated string
for item in initial_field_list:
if isinstance(item, str):
# Sanitize and split potentially comma-separated strings within the list
sanitized_item = _sanitize_field(item.strip(), self.is_mariadb).strip()
if sanitized_item:
parsed = self._parse_single_field_item(sanitized_item)
if isinstance(parsed, list): # Result from parsing a child query dict
_fields.extend(parsed)
elif parsed:
_fields.append(parsed)
else:
# Handle non-string items (like dict for child query, or pre-parsed Field/Function)
parsed = self._parse_single_field_item(item)
if isinstance(parsed, list):
_fields.extend(parsed)
elif parsed:
_fields.append(parsed)
return _fields
def _parse_single_field_item(
self, field: str | Criterion | dict | Field | Function
) -> "list | Criterion | Field | Function | DynamicTableField | ChildQuery | None":
"""Parses a single item from the fields list/tuple. Assumes comma-separated strings have already been split."""
if isinstance(field, Criterion | Field | Function):
return field
elif isinstance(field, dict):
# Handle child queries defined as dicts {fieldname: [child_fields]}
_parsed_fields = []
for child_field, child_fields_list in field.items():
# Ensure child_fields_list is a list or tuple
if not isinstance(child_fields_list, list | tuple):
frappe.throw(
_("Child query fields for '{0}' must be a list or tuple.").format(child_field)
)
_parsed_fields.append(ChildQuery(child_field, list(child_fields_list), self.doctype))
# Return list as a dict entry might represent multiple child queries (though unlikely)
return _parsed_fields
# At this point, field must be a string (already validated and sanitized)
if not isinstance(field, str):
frappe.throw(_("Invalid field type: {0}").format(type(field)))
# Try parsing as SQL function first
if parsed_function := SqlFunctionParser.parse(field):
return parsed_function
# Then try parsing as dynamic field (link/child table access)
elif parsed := DynamicTableField.parse(field, self.doctype):
return parsed
# Otherwise, parse as a standard field (simple, quoted, table-qualified, with/without alias)
else:
# Note: Comma handling is done in parse_fields before this method is called
return self.parse_string_field(field)
def _validate_group_by(self, group_by: str):
"""Validate the group_by string argument."""
if not isinstance(group_by, str):
frappe.throw(_("Group By must be a string"), TypeError)
parts = COMMA_PATTERN.split(group_by)
for part in parts:
field_name = part.strip()
if not field_name:
continue
if field_name.isdigit():
continue
if not ALLOWED_SQL_FIELD_PATTERN.match(field_name):
frappe.throw(
_("Invalid field format in Group By: {0}").format(field_name),
frappe.PermissionError,
)
def apply_order_by(self, order_by: str | None):
if not order_by or order_by == DefaultOrderBy:
return
self._validate_order_by(order_by)
for declaration in order_by.split(","):
if _order_by := declaration.strip():
parts = _order_by.split(" ")
order_field = parts[0]
order_direction = Order.asc if (len(parts) > 1 and parts[1].lower() == "asc") else Order.desc
self.query = self.query.orderby(order_field, order=order_direction)
def _validate_order_by(self, order_by: str):
"""Validate the order_by string argument."""
if not isinstance(order_by, str):
frappe.throw(_("Order By must be a string"), TypeError)
valid_directions = {"asc", "desc"}
for declaration in order_by.split(","):
if _order_by := declaration.strip():
parts = _order_by.split()
field_name = parts[0]
direction = None
if len(parts) > 1:
direction = parts[1].lower()
if field_name.isdigit():
pass
elif not ALLOWED_SQL_FIELD_PATTERN.match(field_name):
frappe.throw(
_("Invalid field format in Order By: {0}").format(field_name),
frappe.PermissionError,
)
if direction and direction not in valid_directions:
frappe.throw(
_("Invalid direction in Order By: {0}. Must be 'ASC' or 'DESC'.").format(parts[1]),
ValueError,
)
def check_read_permission(self):
"""Check if user has read permission on the doctype"""
def has_permission(ptype):
return frappe.has_permission(
self.doctype,
ptype,
user=self.user,
parent_doctype=self.parent_doctype,
)
if not has_permission("select") and not has_permission("read"):
frappe.throw(
_("Insufficient Permission for {0}").format(frappe.bold(self.doctype)), frappe.PermissionError
)
def apply_field_permissions(self):
"""Filter the list of fields based on permlevel."""
allowed_fields = []
parent_permission_type = self.get_permission_type(self.doctype)
permitted_fields_cache = {}
def get_cached_permitted_fields(doctype, parenttype, permission_type):
cache_key = (doctype, parenttype, permission_type)
if cache_key not in permitted_fields_cache:
permitted_fields_cache[cache_key] = set(
get_permitted_fields(
doctype=doctype,
parenttype=parenttype,
permission_type=permission_type,
ignore_virtual=True,
)
)
return permitted_fields_cache[cache_key]
permitted_fields_set = get_cached_permitted_fields(
self.doctype, self.parent_doctype, parent_permission_type
)
for field in self.fields:
if isinstance(field, ChildTableField):
if parent_permission_type == "select":
# Skip child table fields if parent permission is only 'select'
continue
# Cache permitted fields for child doctypes if accessed multiple times
permitted_child_fields_set = get_cached_permitted_fields(
field.doctype, field.parent_doctype, self.get_permission_type(field.doctype)
)
# Check permission for the specific field in the child table
if field.fieldname in permitted_child_fields_set:
allowed_fields.append(field)
elif isinstance(field, LinkTableField):
# Check permission for the link field *in the parent doctype*
if field.link_fieldname in permitted_fields_set:
# Also check if user has permission to read/select the target doctype
target_doctype = field.doctype
has_target_perm = frappe.has_permission(
target_doctype, "select", user=self.user
) or frappe.has_permission(target_doctype, "read", user=self.user)
if has_target_perm:
# Finally, check if the specific field *in the target doctype* is permitted
permitted_target_fields_set = get_cached_permitted_fields(
target_doctype, None, self.get_permission_type(target_doctype)
)
if field.fieldname in permitted_target_fields_set:
allowed_fields.append(field)
elif isinstance(field, ChildQuery):
if parent_permission_type == "select":
# Skip child queries if parent permission is only 'select'
continue
# Cache permitted fields for the child doctype of the query
permitted_child_fields_set = get_cached_permitted_fields(
field.doctype, field.parent_doctype, self.get_permission_type(field.doctype)
)
# Filter the fields *within* the ChildQuery object based on permissions
field.fields = [f for f in field.fields if f in permitted_child_fields_set]
# Only add the child query if it still has fields after filtering
if field.fields:
allowed_fields.append(field)
elif isinstance(field, Field):
if field.name == "*":
# Expand '*' to include all permitted fields
# Avoid reparsing '*' recursively by passing the actual list
allowed_fields.extend(self.parse_fields(list(permitted_fields_set)))
# Check if the field name (without alias) is permitted
elif field.name in permitted_fields_set:
allowed_fields.append(field)
# Handle cases where the field might be aliased but the base name is permitted
elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set:
allowed_fields.append(field)
elif isinstance(field, PseudoColumnMapper | Function):
# Typically functions or complex terms
allowed_fields.append(field)
return allowed_fields
def get_user_permission_conditions(self, role_permissions):
"""Build conditions for user permissions and return tuple of (conditions, fetch_shared_docs)"""
conditions = []
fetch_shared_docs = False
# add user permission only if role has read perm
if not (role_permissions.get("read") or role_permissions.get("select")):
return conditions, fetch_shared_docs
user_permissions = frappe.permissions.get_user_permissions(self.user)
if not user_permissions:
return conditions, fetch_shared_docs
fetch_shared_docs = True
doctype_link_fields = self.get_doctype_link_fields()
for df in doctype_link_fields:
if df.get("ignore_user_permissions"):
continue
user_permission_values = user_permissions.get(df.get("options"), {})
if user_permission_values:
docs = []
for permission in user_permission_values:
if not permission.get("applicable_for"):
docs.append(permission.get("doc"))
# append docs based on user permission applicable on reference doctype
# this is useful when getting list of docs from a link field
# in this case parent doctype of the link
# will be the reference doctype
elif df.get("fieldname") == "name" and self.reference_doctype:
if permission.get("applicable_for") == self.reference_doctype:
docs.append(permission.get("doc"))
elif permission.get("applicable_for") == self.doctype:
docs.append(permission.get("doc"))
if docs:
field_name = df.get("fieldname")
strict_user_permissions = frappe.get_system_settings("apply_strict_user_permissions")
if strict_user_permissions:
conditions.append(self.table[field_name].isin(docs))
else:
empty_value_condition = self.table[field_name].isnull()
value_condition = self.table[field_name].isin(docs)
conditions.append(empty_value_condition | value_condition)
return conditions, fetch_shared_docs
def get_doctype_link_fields(self):
meta = frappe.get_meta(self.doctype)
# append current doctype with fieldname as 'name' as first link field
doctype_link_fields = [{"options": self.doctype, "fieldname": "name"}]
# append other link fields
doctype_link_fields.extend(meta.get_link_fields())
return doctype_link_fields
def add_permission_conditions(self):
conditions = []
role_permissions = frappe.permissions.get_role_permissions(self.doctype, user=self.user)
fetch_shared_docs = False
if self.requires_owner_constraint(role_permissions):
fetch_shared_docs = True
conditions.append(self.table.owner == self.user)
# skip user perm check if owner constraint is required
elif role_permissions.get("read") or role_permissions.get("select"):
user_perm_conditions, fetch_shared = self.get_user_permission_conditions(role_permissions)
conditions.extend(user_perm_conditions)
fetch_shared_docs = fetch_shared_docs or fetch_shared
permission_query_conditions = self.get_permission_query_conditions()
if permission_query_conditions:
conditions.extend(permission_query_conditions)
shared_docs = []
if fetch_shared_docs:
shared_docs = frappe.share.get_shared(self.doctype, self.user)
if shared_docs:
shared_condition = self.table.name.isin(shared_docs)
if conditions:
# (permission conditions) OR (shared condition)
self.query = self.query.where(Criterion.all(conditions) | shared_condition)
else:
self.query = self.query.where(shared_condition)
elif conditions:
# AND all permission conditions
self.query = self.query.where(Criterion.all(conditions))
def get_permission_query_conditions(self):
"""Add permission query conditions from hooks and server scripts"""
from frappe.core.doctype.server_script.server_script_utils import get_server_script_map
conditions = []
hooks = frappe.get_hooks("permission_query_conditions", {})
condition_methods = hooks.get(self.doctype, []) + hooks.get("*", [])
for method in condition_methods:
if c := frappe.call(frappe.get_attr(method), self.user, doctype=self.doctype):
conditions.append(RawCriterion(c))
# Get conditions from server scripts
if permission_script_name := get_server_script_map().get("permission_query", {}).get(self.doctype):
script = frappe.get_doc("Server Script", permission_script_name)
if condition := script.get_permission_query_conditions(self.user):
conditions.append(RawCriterion(condition))
return conditions
def get_permission_type(self, doctype) -> str:
"""Get permission type (select/read) based on user permissions"""
if frappe.only_has_select_perm(doctype, user=self.user):
return "select"
return "read"
def requires_owner_constraint(self, role_permissions):
"""Return True if "select" or "read" isn't available without being creator."""
if not role_permissions.get("has_if_owner_enabled"):
return
if_owner_perms = role_permissions.get("if_owner")
if not if_owner_perms:
return
# has select or read without if owner, no need for constraint
for perm_type in ("select", "read"):
if role_permissions.get(perm_type) and perm_type not in if_owner_perms:
return
# not checking if either select or read if present in if_owner_perms
# because either of those is required to perform a query
return True
class Permission:
@classmethod
def check_permissions(cls, query, **kwargs):
if not isinstance(query, str):
query = query.get_sql()
doctype = cls.get_tables_from_query(query)
if isinstance(doctype, str):
doctype = [doctype]
for dt in doctype:
dt = TAB_PATTERN.sub("", dt)
if not frappe.has_permission(
dt,
"select",
user=kwargs.get("user"),
parent_doctype=kwargs.get("parent_doctype"),
) and not frappe.has_permission(
dt,
"read",
user=kwargs.get("user"),
parent_doctype=kwargs.get("parent_doctype"),
):
frappe.throw(
_("Insufficient Permission for {0}").format(frappe.bold(dt)), frappe.PermissionError
)
@staticmethod
def get_tables_from_query(query: str):
return [table for table in WORDS_PATTERN.findall(query) if table.startswith("tab")]
class DynamicTableField:
def __init__(
self,
doctype: str,
fieldname: str,
parent_doctype: str,
alias: str | None = None,
) -> None:
self.doctype = doctype
self.fieldname = fieldname
self.alias = alias
self.parent_doctype = parent_doctype
def __str__(self) -> str:
table_name = f"`tab{self.doctype}`"
fieldname = f"`{self.fieldname}`"
if frappe.db.db_type == "postgres":
table_name = table_name.replace("`", '"')
fieldname = fieldname.replace("`", '"')
alias = f"AS {self.alias}" if self.alias else ""
return f"{table_name}.{fieldname} {alias}".strip()
@staticmethod
def parse(field: str, doctype: str, allow_tab_notation: bool = True):
if "." in field:
alias = None
# Handle 'as' alias, case-insensitive, taking the last occurrence
if " as " in field.lower():
parts = re.split(r"\s+as\s+", field, flags=re.IGNORECASE)
if len(parts) > 1:
field_part = parts[0].strip()
alias = parts[-1].strip().strip('`"') # Get last part as alias
field = field_part # Use the part before alias for further parsing
child_match = None
if allow_tab_notation:
# Regex to match `tabDoc`.`field`, "tabDoc"."field", tabDoc.field
# Group 1: Doctype name (without 'tab')
# Group 2: Optional quote for fieldname
# Group 3: Fieldname
# Ensures quotes are consistent or absent on fieldname using backreference \2
# Uses re.match to ensure the pattern matches the *entire* field string
# Allow spaces in doctype name (Group 1) and field name (Group 3)
child_match = re.match(r'[`"]?tab([\w\s]+)[`"]?\.([`"]?)([\w\s]+)\2$', field)
if child_match:
child_doctype_name = child_match.group(1)
child_field = child_match.group(3)
if child_doctype_name == doctype:
# Referencing a field in the main doctype using `tabDoctype.field` notation.
# This should be handled by the standard field parser, not as a DynamicTableField.
return None
# Found a child table reference like tabChildDoc.child_field
# Note: parent_fieldname is None here as it's directly specified via tab notation
return ChildTableField(child_doctype_name, child_field, doctype, alias=alias)
else:
# Try parsing as LinkTableField (link_field.target_field) or ChildTableField (child_field.target_field)
# This handles patterns not starting with 'tab' prefix
if "." not in field: # Should not happen due to outer check, but safety
return None
parts = field.split(".", 1)
if len(parts) != 2: # Ensure it splits into exactly two parts
return None
potential_parent_fieldname, target_fieldname = parts
# Basic validation for the parts to avoid unnecessary metadata lookups on invalid input
# We expect simple identifiers here. Quoted/complex names are handled elsewhere or by child_match.
if (
not potential_parent_fieldname.replace("_", "").isalnum()
or not target_fieldname.replace("_", "").isalnum()
):
return None
try:
meta = frappe.get_meta(doctype) # Get meta of the *parent* doctype
# Check if the first part is a valid fieldname in the parent doctype
if not meta.has_field(potential_parent_fieldname):
return None # Not a field in the parent, so not link/child access pattern
linked_field = meta.get_field(potential_parent_fieldname)
except Exception:
# Handle cases where doctype doesn't exist, etc.
print(f"Error getting metadata for {doctype} while parsing field {field}")
return None
if linked_field:
linked_doctype = linked_field.options
if linked_field.fieldtype == "Link":
# It's a Link field access: parent_doctype.link_fieldname.target_fieldname
return LinkTableField(
linked_doctype, target_fieldname, doctype, potential_parent_fieldname, alias=alias
)
elif linked_field.fieldtype in frappe.model.table_fields:
# It's a Child Table field access: parent_doctype.child_table_fieldname.target_fieldname
return ChildTableField(
linked_doctype, target_fieldname, doctype, potential_parent_fieldname, alias=alias
)
return None
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
raise NotImplementedError
class ChildTableField(DynamicTableField):
def __init__(
self,
doctype: str,
fieldname: str,
parent_doctype: str,
parent_fieldname: str | None = None,
alias: str | None = None,
) -> None:
self.doctype = doctype
self.fieldname = fieldname
self.alias = alias
self.parent_doctype = parent_doctype
self.parent_fieldname = parent_fieldname
self.table = frappe.qb.DocType(self.doctype)
self.field = self.table[self.fieldname]
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
query = self.apply_join(query)
return query.select(getattr(table, self.fieldname).as_(self.alias or None))
def apply_join(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
main_table = frappe.qb.DocType(self.parent_doctype)
if not query.is_joined(table):
query = query.left_join(table).on(
(table.parent == main_table.name) & (table.parenttype == self.parent_doctype)
)
return query
class LinkTableField(DynamicTableField):
def __init__(
self,
doctype: str,
fieldname: str,
parent_doctype: str,
link_fieldname: str,
alias: str | None = None,
) -> None:
super().__init__(doctype, fieldname, parent_doctype, alias=alias)
self.link_fieldname = link_fieldname
self.table = frappe.qb.DocType(self.doctype)
self.field = self.table[self.fieldname]
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
query = self.apply_join(query)
return query.select(getattr(table, self.fieldname).as_(self.alias or None))
def apply_join(self, query: QueryBuilder) -> QueryBuilder:
table = frappe.qb.DocType(self.doctype)
main_table = frappe.qb.DocType(self.parent_doctype)
if not query.is_joined(table):
query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname))
return query
class ChildQuery:
def __init__(
self,
fieldname: str,
fields: list,
parent_doctype: str,
) -> None:
field = frappe.get_meta(parent_doctype).get_field(fieldname)
if field.fieldtype not in frappe.model.table_fields:
return
self.fieldname = fieldname
self.fields = fields
self.parent_doctype = parent_doctype
self.doctype = field.options
def get_query(self, parent_names=None) -> QueryBuilder:
filters = {
"parenttype": self.parent_doctype,
"parentfield": self.fieldname,
"parent": ["in", parent_names],
}
return frappe.qb.get_query(
self.doctype,
fields=[*self.fields, "parent", "parentfield"],
filters=filters,
order_by="idx asc",
)
class SqlFunctionParser:
_supported_functions: ClassVar[dict[str, BuiltinFunctionType]] = {
f.value.lower(): getattr(functions, f.name) for f in SqlFunctions if hasattr(functions, f.name)
}
@staticmethod
def _parse_argument_expression(arg_str: str) -> Term | None:
"""Attempts to parse simple arithmetic expressions between fields."""
# Map symbols to pypika's expected operation methods if needed, or rely on overloading
# For +, -, *, / pypika Field overloading works directly
supported_operators = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}
for op_symbol, _op_func in supported_operators.items():
# Split only on the first occurrence of the operator
parts = arg_str.split(op_symbol, 1)
if len(parts) == 2:
left_str, right_str = parts[0].strip(), parts[1].strip()
# Validate both parts are valid field names (simple or quoted)
if ALLOWED_SQL_FIELD_PATTERN.match(left_str.strip('`"')) and ALLOWED_SQL_FIELD_PATTERN.match(
right_str.strip('`"')
):
# Create Field or PseudoColumnMapper objects
left_field = (
PseudoColumnMapper(left_str)
if "`" in left_str or '"' in left_str
else Field(left_str)
)
right_field = (
PseudoColumnMapper(right_str)
if "`" in right_str or '"' in right_str
else Field(right_str)
)
# Use pypika's operator overloading for Field objects
if op_symbol == "+":
return left_field + right_field
elif op_symbol == "-":
return left_field - right_field
elif op_symbol == "*":
return left_field * right_field
elif op_symbol == "/":
return left_field / right_field
# If no simple binary arithmetic expression is found
return None
@staticmethod
def parse(field_str: str) -> Function | None:
"""
Parses a string to see if it represents a *supported* SQL function call.
Returns a pypika Function object if valid and supported, otherwise None.
Handles simple arguments (field names, *), aliases, and simple expressions.
"""
match = SQL_FUNCTION_PATTERN.match(field_str.strip())
if not match:
return None
func_name, args_str, alias = match.groups()
func_name_lower = func_name.lower()
# Strip backticks from alias if present
if alias:
alias = alias.strip("`")
# Check if the function is in our supported list
pypika_func = SqlFunctionParser._supported_functions.get(func_name_lower)
if not pypika_func:
# Function name not found in SqlFunctions enum values
return None
# Handle NOW() specifically (often takes no arguments)
if func_name_lower == "now" and not args_str.strip():
return pypika_func(alias=alias or None)
# Parse arguments
parsed_args = []
if args_str.strip():
raw_args = ARGS_SPLIT_PATTERN.split(args_str.strip())
for arg in raw_args:
arg = arg.strip()
if not arg:
continue
if arg == "*":
# Only allow '*' for specific functions like COUNT
if func_name_lower != "count":
frappe.throw(_("Wildcard '*' argument is only supported for COUNT function."))
parsed_args.append(Term.wrap_constant("*"))
continue
evaluated_arg = literal_eval_(arg)
if evaluated_arg != arg: # Successfully evaluated to a literal
parsed_args.append(Term.wrap_constant(evaluated_arg))
else:
# Not '*' or a simple literal. Could be a field, quoted field, keyword, or expression.
# Check if it's a simple or quoted field name first.
if ALLOWED_SQL_FIELD_PATTERN.match(arg.strip('`"')):
# Pass the original arg (with quotes if present) to the mapper/field
if "`" in arg or '"' in arg:
parsed_args.append(PseudoColumnMapper(arg))
else:
parsed_args.append(Field(arg))
# Check if it's a valid expression/keyword based on allowed characters
elif ALLOWED_ARGUMENT_CHARS_PATTERN.match(arg):
# Attempt to parse as a simple arithmetic expression first
parsed_expr = SqlFunctionParser._parse_argument_expression(arg)
if parsed_expr:
parsed_args.append(parsed_expr)
else:
# Fallback: Pass the raw string argument for non-expression cases like 'distinct name'
parsed_args.append(arg)
else:
# Argument contains disallowed characters.
frappe.throw(
_("Invalid characters or format in function argument expression: {0}").format(
arg
),
frappe.ValidationError,
)
return pypika_func(*parsed_args, alias=alias or None)
def literal_eval_(literal):
try:
return literal_eval(literal)
except (ValueError, SyntaxError):
return literal
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) -> list[str]:
"""Get matching nodes based on operator."""
table = frappe.qb.DocType(doctype)
try:
lft, rgt = frappe.qb.from_(table).select("lft", "rgt").where(table.name == name).run()[0]
except IndexError:
lft, rgt = None, None
if hierarchy in ("descendants of", "not descendants of", "descendants of (inclusive)"):
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft > lft)
.where(table.rgt < rgt)
.orderby(table.lft, order=Order.asc)
.run(pluck=True)
)
if hierarchy == "descendants of (inclusive)":
result += [name]
else:
# Get ancestor elements of a DocType with a tree structure
result = (
frappe.qb.from_(table)
.select(table.name)
.where(table.lft < lft)
.where(table.rgt > rgt)
.orderby(table.lft, order=Order.desc)
.run(pluck=True)
)
return result
@lru_cache(maxsize=1024)
def _validate_select_field(field: str):
"""Validate a field string intended for use in a SELECT clause."""
if field == "*":
return
if field.isdigit():
return
if ALLOWED_FIELD_PATTERN.match(field) or SqlFunctionParser.parse(field):
return
frappe.throw(
_(
"Invalid field format for SELECT: {0}. Field names must be simple, backticked, table-qualified, aliased, a valid function call, or '*'."
).format(field),
frappe.PermissionError,
)
@lru_cache(maxsize=1024)
def _sanitize_field(field: str, is_mariadb):
"""Validate and sanitize a field string for SELECT clause by stripping comments."""
_validate_select_field(field)
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
if is_mariadb:
stripped_field = MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
return stripped_field.strip()
class RawCriterion(Term):
"""A class to represent raw SQL string as a criterion.
Allows using raw SQL strings in pypika queries:
frappe.qb.from_("DocType").where(RawCriterion("name like 'a%'"))
"""
def __init__(self, sql_string: str):
self.sql_string = sql_string
super().__init__()
def get_sql(self, **kwargs: Any) -> str:
return self.sql_string
def __and__(self, other):
return CombinedRawCriterion(self, other, "AND")
def __or__(self, other):
return CombinedRawCriterion(self, other, "OR")
def __invert__(self):
return RawCriterion(f"NOT ({self.sql_string})")
class CombinedRawCriterion(RawCriterion):
def __init__(self, left, right, operator):
self.left = left
self.right = right
self.operator = operator
super(RawCriterion, self).__init__()
def get_sql(self, **kwargs: Any) -> str:
left_sql = self.left.get_sql(**kwargs) if hasattr(self.left, "get_sql") else str(self.left)
right_sql = self.right.get_sql(**kwargs) if hasattr(self.right, "get_sql") else str(self.right)
return f"({left_sql}) {self.operator} ({right_sql})"