From ddca77429c05513d8a7238bb059fd4d80a37086f Mon Sep 17 00:00:00 2001 From: Faris Ansari Date: Thu, 1 May 2025 03:20:35 +0530 Subject: [PATCH] fix: secure query building Add strict validation using regex for fields in SELECT, filters, GROUP BY, and ORDER BY clauses to avoid potential SQL injection risks. Refactor field parsing and validation logic into dedicated functions. --- frappe/database/query.py | 403 +++++++++++++++++++++++++++++---------- 1 file changed, 300 insertions(+), 103 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index f135b3cd6c..37ab0404cf 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -32,6 +32,23 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") # to allow table names like __Auth TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII) +# Pattern to validate field names in SELECT: +# Allows: name, `name`, name as alias, `name` as alias, `table`.`name`, `table`.`name` as alias, table.name, table.name as alias; each backtick is individually optional (they are not required to balance) and the `as` keyword must be lowercase because only re.ASCII is set +ALLOWED_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII) + +# Pattern to validate basic SQL function call syntax: word(...) [as alias]; the text between the parentheses is matched by `.*` and is not inspected by this pattern +FUNCTION_CALL_PATTERN = re.compile(r"^\w+\(.*\)(?:\s+as\s+\w+)?$", flags=re.IGNORECASE | re.ASCII) + +# Pattern to validate field names used in various SQL clauses (WHERE, GROUP BY, ORDER BY): +# Allows simple field names, backticked names, and table-qualified names (e.g., name, `name`, `table`.`name`, table.name) +# Does NOT allow aliases ('as alias') or functions. 
+ALLOWED_SQL_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)$", flags=re.ASCII) + +# Regex to parse field names: +# Group 1: Optional table name (e.g., `tabDocType` or tabDocType); must start with `tab` and may be wrapped in backticks or double quotes +# Group 2: Field name (e.g., `field` or field) +FIELD_PARSE_REGEX = re.compile(r"^(?:[`\"]?(tab\w+)[`\"]?\.)?[`\"]?(\w+)[`\"]?$") + class Engine: def get_query( @@ -65,6 +82,7 @@ class Engine: self.validate_filters = validate_filters self.user = user or frappe.session.user self.parent_doctype = parent_doctype + self.apply_permissions = not ignore_permissions if isinstance(table, Table): self.table = table @@ -74,6 +92,9 @@ class Engine: self.validate_doctype() self.table = qb.DocType(table) + if self.apply_permissions: + self.check_read_permission() + if update: self.query = qb.update(self.table, immutable=False) elif into: @@ -82,18 +103,19 @@ class Engine: self.query = qb.from_(self.table, immutable=False).delete() else: self.query = qb.from_(self.table, immutable=False) - self.fields = self.parse_fields(fields) - if not ignore_permissions: - self.fields = self.apply_field_permissions() - self.apply_fields(self.fields) + self.apply_fields(fields) self.apply_filters(filters) self.apply_order_by(order_by) if limit: + if not isinstance(limit, int) or limit < 0: + frappe.throw(_("Limit must be a non-negative integer"), TypeError) self.query = self.query.limit(limit) if offset: + if not isinstance(offset, int) or offset < 0: + frappe.throw(_("Offset must be a non-negative integer"), TypeError) self.query = self.query.offset(offset) if distinct: @@ -103,10 +125,10 @@ class Engine: self.query = self.query.for_update(skip_locked=skip_locked, nowait=not wait) if group_by: + self._validate_group_by(group_by) self.query = self.query.groupby(group_by) - if not ignore_permissions: - self.check_read_permission() + if self.apply_permissions: self.add_permission_conditions() self.query.immutable = True @@ -117,6 +139,10 @@ class Engine: frappe.throw(_("Invalid DocType: 
{0}").format(self.doctype)) def apply_fields(self, fields): + self.fields = self.parse_fields(fields) + if self.apply_permissions: + self.fields = self.apply_field_permissions() + if not self.fields: self.fields = [self.table.name] @@ -180,31 +206,59 @@ class Engine: self._apply_filter(field, value, operator) + def _validate_and_prepare_filter_field(self, field: str | Field, doctype: str | None = None) -> Field: + """Validate field name for filters and return a pypika Field object. Handles dynamic fields.""" + _field = field + is_fieldname_safe = False + + if not isinstance(_field, str): + # Assume it's a pypika Field or similar, return as is. + return _field + + # Always validate field name if it contains special characters to prevent injection + if SPECIAL_CHAR_PATTERN.search(_field): + # First, try to parse as a dynamic field (contains '.') + dynamic_field = DynamicTableField.parse(_field, self.doctype) + if dynamic_field: + # Legitimate dynamic field (e.g., table.field), apply join + self.query = dynamic_field.apply_join(self.query) + _field = dynamic_field.field # _field is now a pypika Field object + # If not a dynamic field and doesn't match the allowed pattern, reject it + elif not ALLOWED_SQL_FIELD_PATTERN.match(_field): + frappe.throw( + _( + "Invalid filter field format: {0}. Field names cannot contain special characters or disallowed patterns." + ).format(_field), + frappe.PermissionError, + ) + # If it matched the pattern (e.g., `fieldname` with backticks), mark as safe + else: + is_fieldname_safe = True + # No special characters, treat as a standard field name, mark as safe + else: + is_fieldname_safe = True + + # Convert string field name to pypika Field object if needed + if is_fieldname_safe: + # Note: We are converting the original `field` string here, + # not the potentially modified `_field` + # if it became a dynamic field object earlier. 
+ _field = frappe.qb.DocType(doctype or self.doctype)[field] + + return _field + def _apply_filter( self, - field: str, + field: str | Field, value: FilterValue | list | set | None, operator: str = "=", doctype: str | None = None, ): - _field = field + _field = self._validate_and_prepare_filter_field(field, doctype) _value = value _operator = operator - if not isinstance(_field, str): - pass - elif not self.validate_filters and (dynamic_field := DynamicTableField.parse(field, self.doctype)): - # apply implicit join if link field's field is referenced - self.query = dynamic_field.apply_join(self.query) - _field = dynamic_field.field - elif self.validate_filters and SPECIAL_CHAR_PATTERN.search(_field): - frappe.throw(_("Invalid filter: {0}").format(_field), frappe.PermissionError) - elif not doctype or doctype == self.doctype: - _field = self.table[field] - elif doctype: - _field = frappe.qb.DocType(doctype)[field] - - # apply implicit join if child table is referenced + # Apply implicit join if child table is referenced if doctype and doctype != self.doctype: meta = frappe.get_meta(doctype) table = frappe.qb.DocType(doctype) @@ -218,12 +272,14 @@ class Engine: if not _value and isinstance(_value, list | tuple | set): _value = ("",) - # Nested set + # Handle nested set operators if _operator in NESTED_SET_OPERATORS: hierarchy = _operator docname = _value - _df = frappe.get_meta(self.doctype).get_field(field) + # Use the original field name string for get_field if _field was converted + original_field_name = field if isinstance(field, str) else _field.name + _df = frappe.get_meta(self.doctype).get_field(original_field_name) ref_doctype = _df.options if _df else self.doctype nodes = get_nested_set_hierarchy_result(ref_doctype, docname, hierarchy) @@ -232,10 +288,7 @@ class Engine: if hierarchy in ("not ancestors of", "not descendants of") else OPERATOR_MAP["in"] ) - if nodes: - self.query = self.query.where(operator_fn(_field, nodes)) - else: - self.query = 
self.query.where(operator_fn(_field, ("",))) + self.query = self.query.where(operator_fn(_field, nodes or ("",))) return operator_fn = OPERATOR_MAP[_operator.casefold()] @@ -298,69 +351,138 @@ class Engine: # Fall back for functions not present in `SqlFunctions`` return Function(func, *_args, alias=alias or None) - def sanitize_fields(self, fields: str | list | tuple): - if isinstance(fields, list | tuple): - return [ - _sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field - for field in fields - ] - elif isinstance(fields, str): - return _sanitize_field(fields, self.is_mariadb) - return fields - def parse_string_field(self, field: str): + """ + Parses a field string into a pypika Field object. + + Handles: + - * + - simple_field + - `quoted_field` + - tabDocType.simple_field + - `tabDocType`.`quoted_field` + - Aliases for all above formats (e.g., field as alias) + """ if field == "*": return self.table.star - alias = None - if " as " in field: - field, alias = field.split(" as ") - if "`" in field: - if alias: - return PseudoColumnMapper(f"{field} {alias}") - return PseudoColumnMapper(field) - if alias: - return self.table[field].as_(alias) - return self.table[field] - def parse_fields(self, fields: str | list | tuple | None) -> list: + alias = None + field_part = field + if " as " in field.lower(): # Case-insensitive check for ' as ' + # Split on every whitespace-delimited 'as' (case-insensitive): parts[0] is the field and parts[1] the alias; any further parts are ignored + parts = re.split(r"\s+as\s+", field, flags=re.IGNORECASE) + if len(parts) > 1: + field_part = parts[0].strip() + alias = parts[1].strip().strip('`"') # Remove potential quotes from alias + + match = FIELD_PARSE_REGEX.match(field_part) + + if not match: + frappe.throw(_("Could not parse field: {0}").format(field)) + + table_name, field_name = match.groups() + + if table_name: + # Table name specified (e.g., `tabX`.`y` or tabX.y) + table_obj = frappe.qb.DocType(table_name) + pypika_field = table_obj[field_name] + else: + # Simple 
field name (e.g., `y` or y) - use the main table + pypika_field = self.table[field_name] + + if alias: + return pypika_field.as_(alias) + else: + return pypika_field + + def _parse_single_field_item( + self, field: str | Criterion | dict + ) -> list | Criterion | Field | "DynamicTableField" | "ChildQuery" | None: + """Parses a single item from the fields list/tuple. Assumes comma-separated strings have already been split.""" + if isinstance(field, Criterion): + return field + elif isinstance(field, dict): + # Handle child queries defined as dicts {fieldname: [child_fields]} + _parsed_fields = [] + for child_field, child_fields_list in field.items(): + # Ensure child_fields_list is a list or tuple + if not isinstance(child_fields_list, list | tuple): + frappe.throw( + _("Child query fields for '{0}' must be a list or tuple.").format(child_field) + ) + _parsed_fields.append(ChildQuery(child_field, list(child_fields_list), self.doctype)) + # Return list as a dict entry might represent multiple child queries (though unlikely) + return _parsed_fields + + # At this point, field must be a string (already validated and sanitized) + if not isinstance(field, str): + frappe.throw(_("Invalid field type: {0}").format(type(field))) + + # Check for functions or dynamic fields first + if has_function(field): + return self.get_function_object(field) + elif parsed := DynamicTableField.parse(field, self.doctype): + return parsed + # Otherwise, parse as a standard field (simple, quoted, table-qualified, with/without alias) + else: + # Note: Comma handling is done in parse_fields before this method is called + return self.parse_string_field(field) + + def parse_fields( + self, fields: str | list | tuple | None + ) -> list[Field | Criterion | "DynamicTableField" | "ChildQuery"]: if not fields: return [] - fields = self.sanitize_fields(fields) - if not isinstance(fields, list | tuple): - fields = [fields] - - def parse_field(field: str): - if has_function(field): - return 
self.get_function_object(field) - elif parsed := DynamicTableField.parse(field, self.doctype): - return parsed - else: - return self.parse_string_field(field) + sanitized_field_list = [] + if isinstance(fields, str): + # Split comma-separated fields passed as a single string *before* sanitizing + sanitized_field_list.extend( + _sanitize_field(f.strip(), self.is_mariadb) for f in COMMA_PATTERN.split(fields) if f.strip() + ) + elif isinstance(fields, list | tuple): + # Sanitize fields if input is already a list/tuple + sanitized_field_list.extend( + _sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field + for field in fields + ) + else: + frappe.throw(_("Fields must be a string, list, or tuple")) _fields = [] - for field in fields: - if isinstance(field, Criterion): - _fields.append(field) - elif isinstance(field, dict): - for child_field, fields in field.items(): - _fields.append(ChildQuery(child_field, fields, self.doctype)) - elif isinstance(field, str): - if "," in field: - field = field.casefold() if "`" not in field else field - field_list = COMMA_PATTERN.split(field) - for field in field_list: - if _field := field.strip(): - _fields.append(parse_field(_field)) - else: - _fields.append(parse_field(field)) + # Iterate through the list where each item is a single field definition or criterion + for field_item in sanitized_field_list: + parsed = self._parse_single_field_item(field_item) + if isinstance(parsed, list): # Result from parsing a child query dict + _fields.extend(parsed) + elif parsed: + _fields.append(parsed) return _fields + def _validate_group_by(self, group_by: str): + """Validate the group_by string argument.""" + if not isinstance(group_by, str): + frappe.throw(_("Group By must be a string"), TypeError) + parts = COMMA_PATTERN.split(group_by) + for part in parts: + field_name = part.strip() + if not field_name: + continue + if field_name.isdigit(): + continue + if not ALLOWED_SQL_FIELD_PATTERN.match(field_name): + 
frappe.throw( + _("Invalid field format in Group By: {0}").format(field_name), + frappe.PermissionError, + ) + def apply_order_by(self, order_by: str | None): if not order_by or order_by == DefaultOrderBy: return + self._validate_order_by(order_by) + for declaration in order_by.split(","): if _order_by := declaration.strip(): parts = _order_by.split(" ") @@ -368,6 +490,35 @@ class Engine: order_direction = Order.asc if (len(parts) > 1 and parts[1].lower() == "asc") else Order.desc self.query = self.query.orderby(order_field, order=order_direction) + def _validate_order_by(self, order_by: str): + """Validate the order_by string argument.""" + if not isinstance(order_by, str): + frappe.throw(_("Order By must be a string"), TypeError) + + valid_directions = {"asc", "desc"} + + for declaration in order_by.split(","): + if _order_by := declaration.strip(): + parts = _order_by.split() + field_name = parts[0] + direction = None + if len(parts) > 1: + direction = parts[1].lower() + + if field_name.isdigit(): + pass + elif not ALLOWED_SQL_FIELD_PATTERN.match(field_name): + frappe.throw( + _("Invalid field format in Order By: {0}").format(field_name), + frappe.PermissionError, + ) + + if direction and direction not in valid_directions: + frappe.throw( + _("Invalid direction in Order By: {0}. 
Must be 'ASC' or 'DESC'.").format(parts[1]), + ValueError, + ) + def check_read_permission(self): """Check if user has read permission on the doctype""" @@ -385,41 +536,66 @@ class Engine: ) def apply_field_permissions(self): - """Allow fields that user has permission to read""" - permitted_fields = get_permitted_fields( - doctype=self.doctype, - parenttype=self.parent_doctype, - permission_type=self.get_permission_type(self.doctype), - ignore_virtual=True, - ) + """Filter the list of fields based on permlevel.""" allowed_fields = [] + permitted_fields_set = set( + get_permitted_fields( + doctype=self.doctype, + parenttype=self.parent_doctype, + permission_type=self.get_permission_type(self.doctype), + ignore_virtual=True, + ) + ) + for field in self.fields: if isinstance(field, ChildTableField): - permitted_child_fields = get_permitted_fields( - doctype=field.doctype, - parenttype=field.parent_doctype, - permission_type=self.get_permission_type(field.doctype), - ignore_virtual=True, + # Permitted fields of the child doctype; looked up again for every ChildTableField in the loop (nothing is cached here) + permitted_child_fields_set = set( + get_permitted_fields( + doctype=field.doctype, + parenttype=field.parent_doctype, + permission_type=self.get_permission_type(field.doctype), + ignore_virtual=True, + ) ) - if field.child_fieldname in permitted_child_fields: + # Check permission for the specific field in the child table + if field.fieldname in permitted_child_fields_set: allowed_fields.append(field) elif isinstance(field, LinkTableField): - if field.link_fieldname in permitted_fields: + # Check permission for the link field *in the parent doctype* + if field.link_fieldname in permitted_fields_set: allowed_fields.append(field) elif isinstance(field, ChildQuery): - permitted_child_fields = get_permitted_fields( - doctype=field.doctype, - parenttype=field.parent_doctype, - permission_type=self.get_permission_type(field.doctype), - ignore_virtual=True, + # Permitted fields of the child doctype of the query (looked up per ChildQuery, nothing is cached here) + 
permitted_child_fields_set = set( + get_permitted_fields( + doctype=field.doctype, + parenttype=field.parent_doctype, + permission_type=self.get_permission_type(field.doctype), + ignore_virtual=True, + ) ) - field.fields = [f for f in field.fields if f in permitted_child_fields] - allowed_fields.append(field) + # Filter the fields *within* the ChildQuery object based on permissions + field.fields = [f for f in field.fields if f in permitted_child_fields_set] + # Only add the child query if it still has fields after filtering + if field.fields: + allowed_fields.append(field) elif isinstance(field, Field): if field.name == "*": - allowed_fields.extend(self.parse_fields(permitted_fields)) - elif field.name in permitted_fields: + # Expand '*' to include all permitted fields + # Avoid reparsing '*' recursively by passing the actual list + allowed_fields.extend(self.parse_fields(list(permitted_fields_set))) + # Check if the field name (without alias) is permitted + elif field.name in permitted_fields_set: allowed_fields.append(field) + # NOTE(review): this branch is unreachable, the preceding elif already accepts every field whose name is in permitted_fields_set + elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set: + allowed_fields.append(field) + + elif isinstance(field, PseudoColumnMapper): + # Typically functions or complex terms + allowed_fields.append(field) + return allowed_fields def get_user_permission_conditions(self, role_permissions): @@ -769,16 +945,37 @@ def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) -> return result +@lru_cache(maxsize=1024) +def _validate_select_field(field: str): + """Validate a field string intended for use in a SELECT clause.""" + if field == "*": + return + + if field.isdigit(): + return + + if ALLOWED_FIELD_PATTERN.match(field) or FUNCTION_CALL_PATTERN.match(field): + return + + frappe.throw( + _( + "Invalid field format for SELECT: {0}. 
Field names must be simple, backticked, table-qualified, aliased, a valid function call, or '*'." + ).format(field), + frappe.PermissionError, + ) + + @lru_cache(maxsize=1024) def _sanitize_field(field: str, is_mariadb): - if field == "*" or not SPECIAL_CHAR_PATTERN.search(field): - # Skip checking if there are no special characters - return field + """Validate and sanitize a field string for SELECT clause by stripping comments.""" + _validate_select_field(field) stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") + if is_mariadb: - return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) - return stripped_field + stripped_field = MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) + + return stripped_field.strip() class RawCriterion(Term):