diff --git a/frappe/database/query.py b/frappe/database/query.py index f135b3cd6c..37ab0404cf 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -32,6 +32,23 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") # to allow table names like __Auth TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII) +# Pattern to validate field names in SELECT: +# Allows: name, `name`, name as alias, `name` as alias, `table`.`name`, `table`.`name` as alias, table.name, table.name as alias +ALLOWED_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII) + +# Pattern to validate basic SQL function call syntax: word(...) [as alias] +FUNCTION_CALL_PATTERN = re.compile(r"^\w+\(.*\)(?:\s+as\s+\w+)?$", flags=re.IGNORECASE | re.ASCII) + +# Pattern to validate field names used in various SQL clauses (WHERE, GROUP BY, ORDER BY): +# Allows simple field names, backticked names, and table-qualified names (e.g., name, `name`, `table`.`name`, table.name) +# Does NOT allow aliases ('as alias') or functions. +ALLOWED_SQL_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)$", flags=re.ASCII) + +# Regex to parse field names: +# Group 1: Optional table name (e.g., `tabDocType` or tabDocType) +# Group 2: Field name (e.g., `field` or field) +FIELD_PARSE_REGEX = re.compile(r"^(?:[`\"]?(tab\w+)[`\"]?\.)?[`\"]?(\w+)[`\"]?$") + class Engine: def get_query( @@ -65,6 +82,7 @@ class Engine: self.validate_filters = validate_filters self.user = user or frappe.session.user self.parent_doctype = parent_doctype + self.apply_permissions = not ignore_permissions if isinstance(table, Table): self.table = table @@ -74,6 +92,9 @@ class Engine: self.validate_doctype() self.table = qb.DocType(table) + if self.apply_permissions: + self.check_read_permission() + if update: self.query = qb.update(self.table, immutable=False) elif into: @@ -82,18 +103,19 @@ class Engine: self.query = qb.from_(self.table, immutable=False).delete() else: self.query = qb.from_(self.table, immutable=False) - self.fields = self.parse_fields(fields) - if not ignore_permissions: - self.fields = self.apply_field_permissions() - self.apply_fields(self.fields) + self.apply_fields(fields) self.apply_filters(filters) self.apply_order_by(order_by) if limit: + if not isinstance(limit, int) or limit < 0: + frappe.throw(_("Limit must be a non-negative integer"), TypeError) self.query = self.query.limit(limit) if offset: + if not isinstance(offset, int) or offset < 0: + frappe.throw(_("Offset must be a non-negative integer"), TypeError) self.query = self.query.offset(offset) if distinct: @@ -103,10 +125,10 @@ class Engine: self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait) if group_by: + self._validate_group_by(group_by) self.query = self.query.groupby(group_by) - if not ignore_permissions: - self.check_read_permission() + if self.apply_permissions: self.add_permission_conditions() self.query.immutable = True @@ -117,6 +139,10 @@ class Engine: frappe.throw(_("Invalid DocType: {0}").format(self.doctype)) def apply_fields(self, fields): + self.fields = self.parse_fields(fields) + if self.apply_permissions: + self.fields = self.apply_field_permissions() + if not self.fields: self.fields = [self.table.name] @@ -180,31 +206,59 @@ class Engine: self._apply_filter(field, value, operator) + def _validate_and_prepare_filter_field(self, field: str | Field, doctype: str | None = None) -> Field: + """Validate field name for filters and return a pypika Field object. Handles dynamic fields.""" + _field = field + is_fieldname_safe = False + + if not isinstance(_field, str): + # Assume it's a pypika Field or similar, return as is. + return _field + + # Always validate field name if it contains special characters to prevent injection + if SPECIAL_CHAR_PATTERN.search(_field): + # First, try to parse as a dynamic field (contains '.') + dynamic_field = DynamicTableField.parse(_field, self.doctype) + if dynamic_field: + # Legitimate dynamic field (e.g., table.field), apply join + self.query = dynamic_field.apply_join(self.query) + _field = dynamic_field.field # _field is now a pypika Field object + # If not a dynamic field and doesn't match the allowed pattern, reject it + elif not ALLOWED_SQL_FIELD_PATTERN.match(_field): + frappe.throw( + _( + "Invalid filter field format: {0}. Field names cannot contain special characters or disallowed patterns." + ).format(_field), + frappe.PermissionError, + ) + # If it matched the pattern (e.g., `fieldname` with backticks), mark as safe + else: + is_fieldname_safe = True + # No special characters, treat as a standard field name, mark as safe + else: + is_fieldname_safe = True + + # Convert string field name to pypika Field object if needed + if is_fieldname_safe: + # Note: We are converting the original `field` string here, + # not the potentially modified `_field` + # if it became a dynamic field object earlier. + _field = frappe.qb.DocType(doctype or self.doctype)[field] + + return _field + def _apply_filter( self, - field: str, + field: str | Field, value: FilterValue | list | set | None, operator: str = "=", doctype: str | None = None, ): - _field = field + _field = self._validate_and_prepare_filter_field(field, doctype) _value = value _operator = operator - if not isinstance(_field, str): - pass - elif not self.validate_filters and (dynamic_field := DynamicTableField.parse(field, self.doctype)): - # apply implicit join if link field's field is referenced - self.query = dynamic_field.apply_join(self.query) - _field = dynamic_field.field - elif self.validate_filters and SPECIAL_CHAR_PATTERN.search(_field): - frappe.throw(_("Invalid filter: {0}").format(_field), frappe.PermissionError) - elif not doctype or doctype == self.doctype: - _field = self.table[field] - elif doctype: - _field = frappe.qb.DocType(doctype)[field] - - # apply implicit join if child table is referenced + # Apply implicit join if child table is referenced if doctype and doctype != self.doctype: meta = frappe.get_meta(doctype) table = frappe.qb.DocType(doctype) @@ -218,12 +272,14 @@ class Engine: if not _value and isinstance(_value, list | tuple | set): _value = ("",) - # Nested set + # Handle nested set operators if _operator in NESTED_SET_OPERATORS: hierarchy = _operator docname = _value - _df = frappe.get_meta(self.doctype).get_field(field) + # Use the original field name string for get_field if _field was converted + original_field_name = field if isinstance(field, str) else _field.name + _df = frappe.get_meta(self.doctype).get_field(original_field_name) ref_doctype = _df.options if _df else self.doctype nodes = get_nested_set_hierarchy_result(ref_doctype, docname, hierarchy) @@ -232,10 +288,7 @@ class Engine: if hierarchy in ("not ancestors of", "not descendants of") else OPERATOR_MAP["in"] ) - if nodes: - self.query = self.query.where(operator_fn(_field, nodes)) - else: - self.query = self.query.where(operator_fn(_field, ("",))) + self.query = self.query.where(operator_fn(_field, nodes or ("",))) return operator_fn = OPERATOR_MAP[_operator.casefold()] @@ -298,69 +351,138 @@ class Engine: # Fall back for functions not present in `SqlFunctions`` return Function(func, *_args, alias=alias or None) - def sanitize_fields(self, fields: str | list | tuple): - if isinstance(fields, list | tuple): - return [ - _sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field - for field in fields - ] - elif isinstance(fields, str): - return _sanitize_field(fields, self.is_mariadb) - return fields - def parse_string_field(self, field: str): + """ + Parses a field string into a pypika Field object. + + Handles: + - * + - simple_field + - `quoted_field` + - tabDocType.simple_field + - `tabDocType`.`quoted_field` + - Aliases for all above formats (e.g., field as alias) + """ if field == "*": return self.table.star - alias = None - if " as " in field: - field, alias = field.split(" as ") - if "`" in field: - if alias: - return PseudoColumnMapper(f"{field} {alias}") - return PseudoColumnMapper(field) - if alias: - return self.table[field].as_(alias) - return self.table[field] - def parse_fields(self, fields: str | list | tuple | None) -> list: + alias = None + field_part = field + if " as " in field.lower(): # Case-insensitive check for ' as ' + # Find the last occurrence of ' as ' to handle potential aliases named 'as' + parts = re.split(r"\s+as\s+", field, flags=re.IGNORECASE) + if len(parts) > 1: + field_part = parts[0].strip() + alias = parts[1].strip().strip('`"') # Remove potential quotes from alias + + match = FIELD_PARSE_REGEX.match(field_part) + + if not match: + frappe.throw(_("Could not parse field: {0}").format(field)) + + table_name, field_name = match.groups() + + if table_name: + # Table name specified (e.g., `tabX`.`y` or tabX.y) + table_obj = frappe.qb.DocType(table_name) + pypika_field = table_obj[field_name] + else: + # Simple field name (e.g., `y` or y) - use the main table + pypika_field = self.table[field_name] + + if alias: + return pypika_field.as_(alias) + else: + return pypika_field + + def _parse_single_field_item( + self, field: str | Criterion | dict + ) -> list | Criterion | Field | "DynamicTableField" | "ChildQuery" | None: + """Parses a single item from the fields list/tuple. Assumes comma-separated strings have already been split.""" + if isinstance(field, Criterion): + return field + elif isinstance(field, dict): + # Handle child queries defined as dicts {fieldname: [child_fields]} + _parsed_fields = [] + for child_field, child_fields_list in field.items(): + # Ensure child_fields_list is a list or tuple + if not isinstance(child_fields_list, list | tuple): + frappe.throw( + _("Child query fields for '{0}' must be a list or tuple.").format(child_field) + ) + _parsed_fields.append(ChildQuery(child_field, list(child_fields_list), self.doctype)) + # Return list as a dict entry might represent multiple child queries (though unlikely) + return _parsed_fields + + # At this point, field must be a string (already validated and sanitized) + if not isinstance(field, str): + frappe.throw(_("Invalid field type: {0}").format(type(field))) + + # Check for functions or dynamic fields first + if has_function(field): + return self.get_function_object(field) + elif parsed := DynamicTableField.parse(field, self.doctype): + return parsed + # Otherwise, parse as a standard field (simple, quoted, table-qualified, with/without alias) + else: + # Note: Comma handling is done in parse_fields before this method is called + return self.parse_string_field(field) + + def parse_fields( + self, fields: str | list | tuple | None + ) -> list[Field | Criterion | "DynamicTableField" | "ChildQuery"]: if not fields: return [] - fields = self.sanitize_fields(fields) - if not isinstance(fields, list | tuple): - fields = [fields] - - def parse_field(field: str): - if has_function(field): - return self.get_function_object(field) - elif parsed := DynamicTableField.parse(field, self.doctype): - return parsed - else: - return self.parse_string_field(field) + sanitized_field_list = [] + if isinstance(fields, str): + # Split comma-separated fields passed as a single string *before* sanitizing + sanitized_field_list.extend( + _sanitize_field(f.strip(), self.is_mariadb) for f in COMMA_PATTERN.split(fields) if f.strip() + ) + elif isinstance(fields, list | tuple): + # Sanitize fields if input is already a list/tuple + sanitized_field_list.extend( + _sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field + for field in fields + ) + else: + frappe.throw(_("Fields must be a string, list, or tuple")) _fields = [] - for field in fields: - if isinstance(field, Criterion): - _fields.append(field) - elif isinstance(field, dict): - for child_field, fields in field.items(): - _fields.append(ChildQuery(child_field, fields, self.doctype)) - elif isinstance(field, str): - if "," in field: - field = field.casefold() if "`" not in field else field - field_list = COMMA_PATTERN.split(field) - for field in field_list: - if _field := field.strip(): - _fields.append(parse_field(_field)) - else: - _fields.append(parse_field(field)) + # Iterate through the list where each item is a single field definition or criterion + for field_item in sanitized_field_list: + parsed = self._parse_single_field_item(field_item) + if isinstance(parsed, list): # Result from parsing a child query dict + _fields.extend(parsed) + elif parsed: + _fields.append(parsed) return _fields + def _validate_group_by(self, group_by: str): + """Validate the group_by string argument.""" + if not isinstance(group_by, str): + frappe.throw(_("Group By must be a string"), TypeError) + parts = COMMA_PATTERN.split(group_by) + for part in parts: + field_name = part.strip() + if not field_name: + continue + if field_name.isdigit(): + continue + if not ALLOWED_SQL_FIELD_PATTERN.match(field_name): + frappe.throw( + _("Invalid field format in Group By: {0}").format(field_name), + frappe.PermissionError, + ) + def apply_order_by(self, order_by: str | None): if not order_by or order_by == DefaultOrderBy: return + self._validate_order_by(order_by) + for declaration in order_by.split(","): if _order_by := declaration.strip(): parts = _order_by.split(" ") @@ -368,6 +490,35 @@ class Engine: order_direction = Order.asc if (len(parts) > 1 and parts[1].lower() == "asc") else Order.desc self.query = self.query.orderby(order_field, order=order_direction) + def _validate_order_by(self, order_by: str): + """Validate the order_by string argument.""" + if not isinstance(order_by, str): + frappe.throw(_("Order By must be a string"), TypeError) + + valid_directions = {"asc", "desc"} + + for declaration in order_by.split(","): + if _order_by := declaration.strip(): + parts = _order_by.split() + field_name = parts[0] + direction = None + if len(parts) > 1: + direction = parts[1].lower() + + if field_name.isdigit(): + pass + elif not ALLOWED_SQL_FIELD_PATTERN.match(field_name): + frappe.throw( + _("Invalid field format in Order By: {0}").format(field_name), + frappe.PermissionError, + ) + + if direction and direction not in valid_directions: + frappe.throw( + _("Invalid direction in Order By: {0}. Must be 'ASC' or 'DESC'.").format(parts[1]), + ValueError, + ) + def check_read_permission(self): """Check if user has read permission on the doctype""" @@ -385,41 +536,66 @@ class Engine: ) def apply_field_permissions(self): - """Allow fields that user has permission to read""" - permitted_fields = get_permitted_fields( - doctype=self.doctype, - parenttype=self.parent_doctype, - permission_type=self.get_permission_type(self.doctype), - ignore_virtual=True, - ) + """Filter the list of fields based on permlevel.""" allowed_fields = [] + permitted_fields_set = set( + get_permitted_fields( + doctype=self.doctype, + parenttype=self.parent_doctype, + permission_type=self.get_permission_type(self.doctype), + ignore_virtual=True, + ) + ) + for field in self.fields: if isinstance(field, ChildTableField): - permitted_child_fields = get_permitted_fields( - doctype=field.doctype, - parenttype=field.parent_doctype, - permission_type=self.get_permission_type(field.doctype), - ignore_virtual=True, + # Cache permitted fields for child doctypes if accessed multiple times + permitted_child_fields_set = set( + get_permitted_fields( + doctype=field.doctype, + parenttype=field.parent_doctype, + permission_type=self.get_permission_type(field.doctype), + ignore_virtual=True, + ) ) - if field.child_fieldname in permitted_child_fields: + # Check permission for the specific field in the child table + if field.fieldname in permitted_child_fields_set: allowed_fields.append(field) elif isinstance(field, LinkTableField): - if field.link_fieldname in permitted_fields: + # Check permission for the link field *in the parent doctype* + if field.link_fieldname in permitted_fields_set: allowed_fields.append(field) elif isinstance(field, ChildQuery): - permitted_child_fields = get_permitted_fields( - doctype=field.doctype, - parenttype=field.parent_doctype, - permission_type=self.get_permission_type(field.doctype), - ignore_virtual=True, + # Cache permitted fields for the child doctype of the query + permitted_child_fields_set = set( + get_permitted_fields( + doctype=field.doctype, + parenttype=field.parent_doctype, + permission_type=self.get_permission_type(field.doctype), + ignore_virtual=True, + ) ) - field.fields = [f for f in field.fields if f in permitted_child_fields] - allowed_fields.append(field) + # Filter the fields *within* the ChildQuery object based on permissions + field.fields = [f for f in field.fields if f in permitted_child_fields_set] + # Only add the child query if it still has fields after filtering + if field.fields: + allowed_fields.append(field) elif isinstance(field, Field): if field.name == "*": - allowed_fields.extend(self.parse_fields(permitted_fields)) - elif field.name in permitted_fields: + # Expand '*' to include all permitted fields + # Avoid reparsing '*' recursively by passing the actual list + allowed_fields.extend(self.parse_fields(list(permitted_fields_set))) + # Check if the field name (without alias) is permitted + elif field.name in permitted_fields_set: allowed_fields.append(field) + # Handle cases where the field might be aliased but the base name is permitted + elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set: + allowed_fields.append(field) + + elif isinstance(field, PseudoColumnMapper): + # Typically functions or complex terms + allowed_fields.append(field) + return allowed_fields def get_user_permission_conditions(self, role_permissions): @@ -769,16 +945,37 @@ def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) -> return result +@lru_cache(maxsize=1024) +def _validate_select_field(field: str): + """Validate a field string intended for use in a SELECT clause.""" + if field == "*": + return + + if field.isdigit(): + return + + if ALLOWED_FIELD_PATTERN.match(field) or FUNCTION_CALL_PATTERN.match(field): + return + + frappe.throw( + _( + "Invalid field format for SELECT: {0}. Field names must be simple, backticked, table-qualified, aliased, a valid function call, or '*'." + ).format(field), + frappe.PermissionError, + ) + + @lru_cache(maxsize=1024) def _sanitize_field(field: str, is_mariadb): - if field == "*" or not SPECIAL_CHAR_PATTERN.search(field): - # Skip checking if there are no special characters - return field + """Validate and sanitize a field string for SELECT clause by stripping comments.""" + _validate_select_field(field) stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") + if is_mariadb: - return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) - return stripped_field + stripped_field = MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) + + return stripped_field.strip() class RawCriterion(Term):