fix: restrict '*' argument to COUNT function only
- Add STAR_ALLOWED_FUNCTIONS frozenset with COUNT - Pass function_name through parse chain to validate '*' usage - Prevents '*' in functions like SUM(*), AVG(*) where it's invalid
This commit is contained in:
parent
99039c23cb
commit
c40933dca7
1 changed files with 14 additions and 9 deletions
|
|
@ -144,6 +144,9 @@ FUNCTION_MAPPING = {
|
|||
"MONTH": Month,
|
||||
}
|
||||
|
||||
# Functions that accept '*' as an argument (e.g., COUNT(*))
|
||||
STAR_ALLOWED_FUNCTIONS = frozenset(("COUNT",))
|
||||
|
||||
# Mapping from operator names to pypika Arithmetic enum values
|
||||
# Operators use dict format: {"ADD": [left, right], "as": "alias"}
|
||||
# Supported: ADD (+), SUB (-), MUL (*), DIV (/)
|
||||
|
|
@ -1852,12 +1855,12 @@ class SQLFunctionParser:
|
|||
func_class = FUNCTION_MAPPING[function_name]
|
||||
|
||||
if isinstance(function_args, str):
|
||||
parsed_arg = self._parse_and_validate_argument(function_args)
|
||||
parsed_arg = self._parse_and_validate_argument(function_args, function_name=function_name)
|
||||
function_call = func_class(parsed_arg)
|
||||
elif isinstance(function_args, list):
|
||||
parsed_args = []
|
||||
for arg in function_args:
|
||||
parsed_arg = self._parse_and_validate_argument(arg)
|
||||
parsed_arg = self._parse_and_validate_argument(arg, function_name=function_name)
|
||||
parsed_args.append(parsed_arg)
|
||||
function_call = func_class(*parsed_args)
|
||||
elif isinstance(function_args, (int | float)):
|
||||
|
|
@ -1924,7 +1927,7 @@ class SQLFunctionParser:
|
|||
else:
|
||||
return expression
|
||||
|
||||
def _parse_and_validate_argument(self, arg):
|
||||
def _parse_and_validate_argument(self, arg, *, function_name: str | None = None):
|
||||
"""Parse and validate a single function/operator argument against SQL injection.
|
||||
|
||||
Supports:
|
||||
|
|
@ -1935,7 +1938,7 @@ class SQLFunctionParser:
|
|||
if isinstance(arg, (int | float)):
|
||||
return arg
|
||||
elif isinstance(arg, str):
|
||||
return self._validate_string_argument(arg)
|
||||
return self._validate_string_argument(arg, function_name=function_name)
|
||||
elif isinstance(arg, dict):
|
||||
# Recursively handle nested functions and operators
|
||||
if self.is_function_dict(arg):
|
||||
|
|
@ -1958,17 +1961,20 @@ class SQLFunctionParser:
|
|||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
def _validate_string_argument(self, arg: str):
|
||||
def _validate_string_argument(self, arg: str, *, function_name: str | None = None):
|
||||
"""Validate string arguments to prevent SQL injection."""
|
||||
arg = arg.strip()
|
||||
|
||||
if not arg:
|
||||
frappe.throw(_("Empty string arguments are not allowed"), frappe.ValidationError)
|
||||
|
||||
# Special case: allow '*' for COUNT(*) and similar aggregate functions
|
||||
# Special case: allow '*' only for specific functions like COUNT(*)
|
||||
if arg == "*":
|
||||
# Star() produces correct unquoted * for COUNT(*)
|
||||
# Column("*") would produce COUNT("*") which is wrong
|
||||
if function_name not in STAR_ALLOWED_FUNCTIONS:
|
||||
frappe.throw(
|
||||
_("'*' is only allowed in {0} SQL function(s)").format(", ".join(STAR_ALLOWED_FUNCTIONS)),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
return Star()
|
||||
|
||||
# Check for string literals (quoted strings)
|
||||
|
|
@ -1989,7 +1995,6 @@ class SQLFunctionParser:
|
|||
).format(arg),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
elif self._is_valid_field_name(arg):
|
||||
# Validate field name and check permissions
|
||||
self._validate_function_field_arg(arg)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue