fix: restrict '*' argument to COUNT function only

- Add STAR_ALLOWED_FUNCTIONS frozenset with COUNT
- Pass function_name through parse chain to validate '*' usage
- Prevents '*' in functions like SUM(*), AVG(*) where it's invalid
This commit is contained in:
Sagar Vora 2025-12-02 01:00:16 +05:30
parent 99039c23cb
commit c40933dca7

View file

@ -144,6 +144,9 @@ FUNCTION_MAPPING = {
"MONTH": Month,
}
# Functions that accept '*' as an argument (e.g., COUNT(*))
STAR_ALLOWED_FUNCTIONS = frozenset(("COUNT",))
# Mapping from operator names to pypika Arithmetic enum values
# Operators use dict format: {"ADD": [left, right], "as": "alias"}
# Supported: ADD (+), SUB (-), MUL (*), DIV (/)
@ -1852,12 +1855,12 @@ class SQLFunctionParser:
func_class = FUNCTION_MAPPING[function_name]
if isinstance(function_args, str):
parsed_arg = self._parse_and_validate_argument(function_args)
parsed_arg = self._parse_and_validate_argument(function_args, function_name=function_name)
function_call = func_class(parsed_arg)
elif isinstance(function_args, list):
parsed_args = []
for arg in function_args:
parsed_arg = self._parse_and_validate_argument(arg)
parsed_arg = self._parse_and_validate_argument(arg, function_name=function_name)
parsed_args.append(parsed_arg)
function_call = func_class(*parsed_args)
elif isinstance(function_args, (int | float)):
@ -1924,7 +1927,7 @@ class SQLFunctionParser:
else:
return expression
def _parse_and_validate_argument(self, arg):
def _parse_and_validate_argument(self, arg, *, function_name: str | None = None):
"""Parse and validate a single function/operator argument against SQL injection.
Supports:
@ -1935,7 +1938,7 @@ class SQLFunctionParser:
if isinstance(arg, (int | float)):
return arg
elif isinstance(arg, str):
return self._validate_string_argument(arg)
return self._validate_string_argument(arg, function_name=function_name)
elif isinstance(arg, dict):
# Recursively handle nested functions and operators
if self.is_function_dict(arg):
@ -1958,17 +1961,20 @@ class SQLFunctionParser:
frappe.ValidationError,
)
def _validate_string_argument(self, arg: str):
def _validate_string_argument(self, arg: str, *, function_name: str | None = None):
"""Validate string arguments to prevent SQL injection."""
arg = arg.strip()
if not arg:
frappe.throw(_("Empty string arguments are not allowed"), frappe.ValidationError)
# Special case: allow '*' for COUNT(*) and similar aggregate functions
# Special case: allow '*' only for specific functions like COUNT(*)
if arg == "*":
# Star() produces correct unquoted * for COUNT(*)
# Column("*") would produce COUNT("*") which is wrong
if function_name not in STAR_ALLOWED_FUNCTIONS:
frappe.throw(
_("'*' is only allowed in {0} SQL function(s)").format(", ".join(STAR_ALLOWED_FUNCTIONS)),
frappe.ValidationError,
)
return Star()
# Check for string literals (quoted strings)
@ -1989,7 +1995,6 @@ class SQLFunctionParser:
).format(arg),
frappe.ValidationError,
)
elif self._is_valid_field_name(arg):
# Validate field name and check permissions
self._validate_function_field_arg(arg)