fix: function detection

Signed-off-by: Akhil Narang <me@akhilnarang.dev>
This commit is contained in:
Akhil Narang 2025-11-13 14:15:36 +05:30
parent b407fe8093
commit ddcda11d67
No known key found for this signature in database
GPG key ID: 9DCC61E211BF645F

View file

@ -33,6 +33,16 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
# to allow table names like __Auth
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
def _is_function_call(field_str: str) -> bool:
"""Check if a string is a SQL function call using sqlparse."""
parsed = sqlparse.parse(field_str.strip())
if not parsed:
return False
return any(isinstance(token, sqlparse.sql.Function) for token in parsed[0].tokens)
# Pattern to validate field names in SELECT:
# Allows: name, `name`, name as alias, `name` as alias, `table name`.`name`, `table name`.`name` as alias, table.name, table.name as alias
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:(`[\w\s-]+`|\w+)\.)?(`\w+`|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
@ -732,8 +742,7 @@ class Engine:
for item in initial_field_list:
if isinstance(item, str):
# Sanitize and split potentially comma-separated strings within the list
sanitized_item = _sanitize_field(item.strip(), self.is_mariadb).strip()
if sanitized_item:
if sanitized_item := _sanitize_field(item.strip(), self.is_mariadb).strip():
parsed = self._parse_single_field_item(sanitized_item)
if isinstance(parsed, list): # Result from parsing a child query dict
_fields.extend(parsed)
@ -1028,8 +1037,8 @@ class Engine:
elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set:
allowed_fields.append(field)
elif isinstance(field, AggregateFunction | PseudoColumnMapper):
# Typically functions or complex terms
elif isinstance(field, Term):
# Allow any Term subclass, like LiteralValue (raw SQL expressions), AggregateFunction, PseudoColumnMapper (functions or complex terms)
allowed_fields.append(field)
return allowed_fields
@ -1555,6 +1564,15 @@ def _validate_select_field(field: str):
if field.isdigit():
return
# Reject SQL functions
if _is_function_call(field):
frappe.throw(
_(
"SQL functions are not allowed in SELECT fields: {0}. Use the query builder API with functions instead."
).format(field),
frappe.ValidationError,
)
if ALLOWED_FIELD_PATTERN.match(field):
return