diff --git a/frappe/database/query.py b/frappe/database/query.py index ef3e42f941..9e6d4611a2 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -144,6 +144,9 @@ FUNCTION_MAPPING = { "MONTH": Month, } +# Functions that accept '*' as an argument (e.g., COUNT(*)) +STAR_ALLOWED_FUNCTIONS = frozenset(("COUNT",)) + # Mapping from operator names to pypika Arithmetic enum values # Operators use dict format: {"ADD": [left, right], "as": "alias"} # Supported: ADD (+), SUB (-), MUL (*), DIV (/) @@ -1852,12 +1855,12 @@ class SQLFunctionParser: func_class = FUNCTION_MAPPING[function_name] if isinstance(function_args, str): - parsed_arg = self._parse_and_validate_argument(function_args) + parsed_arg = self._parse_and_validate_argument(function_args, function_name=function_name) function_call = func_class(parsed_arg) elif isinstance(function_args, list): parsed_args = [] for arg in function_args: - parsed_arg = self._parse_and_validate_argument(arg) + parsed_arg = self._parse_and_validate_argument(arg, function_name=function_name) parsed_args.append(parsed_arg) function_call = func_class(*parsed_args) elif isinstance(function_args, (int | float)): @@ -1924,7 +1927,7 @@ class SQLFunctionParser: else: return expression - def _parse_and_validate_argument(self, arg): + def _parse_and_validate_argument(self, arg, *, function_name: str | None = None): """Parse and validate a single function/operator argument against SQL injection. Supports: @@ -1935,7 +1938,7 @@ class SQLFunctionParser: if isinstance(arg, (int | float)): return arg elif isinstance(arg, str): - return self._validate_string_argument(arg) + return self._validate_string_argument(arg, function_name=function_name) elif isinstance(arg, dict): # Recursively handle nested functions and operators if self.is_function_dict(arg): @@ -1958,17 +1961,20 @@ class SQLFunctionParser: frappe.ValidationError, ) - def _validate_string_argument(self, arg: str): + def _validate_string_argument(self, arg: str, *, function_name: str | None = None): """Validate string arguments to prevent SQL injection.""" arg = arg.strip() if not arg: frappe.throw(_("Empty string arguments are not allowed"), frappe.ValidationError) - # Special case: allow '*' for COUNT(*) and similar aggregate functions + # Special case: allow '*' only for specific functions like COUNT(*) if arg == "*": - # Star() produces correct unquoted * for COUNT(*) - # Column("*") would produce COUNT("*") which is wrong + if function_name not in STAR_ALLOWED_FUNCTIONS: + frappe.throw( + _("'*' is only allowed in {0} SQL function(s)").format(", ".join(STAR_ALLOWED_FUNCTIONS)), + frappe.ValidationError, + ) return Star() # Check for string literals (quoted strings) @@ -1989,7 +1995,6 @@ class SQLFunctionParser: ).format(arg), frappe.ValidationError, ) - elif self._is_valid_field_name(arg): # Validate field name and check permissions self._validate_function_field_arg(arg)