fix: function detection
Signed-off-by: Akhil Narang <me@akhilnarang.dev>
This commit is contained in:
parent
b407fe8093
commit
ddcda11d67
1 changed files with 22 additions and 4 deletions
|
|
@ -33,6 +33,16 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
|
|||
# to allow table names like __Auth
|
||||
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
|
||||
|
||||
|
||||
def _is_function_call(field_str: str) -> bool:
|
||||
"""Check if a string is a SQL function call using sqlparse."""
|
||||
parsed = sqlparse.parse(field_str.strip())
|
||||
if not parsed:
|
||||
return False
|
||||
|
||||
return any(isinstance(token, sqlparse.sql.Function) for token in parsed[0].tokens)
|
||||
|
||||
|
||||
# Pattern to validate field names in SELECT:
|
||||
# Allows: name, `name`, name as alias, `name` as alias, `table name`.`name`, `table name`.`name` as alias, table.name, table.name as alias
|
||||
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:(`[\w\s-]+`|\w+)\.)?(`\w+`|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
|
||||
|
|
@ -732,8 +742,7 @@ class Engine:
|
|||
for item in initial_field_list:
|
||||
if isinstance(item, str):
|
||||
# Sanitize and split potentially comma-separated strings within the list
|
||||
sanitized_item = _sanitize_field(item.strip(), self.is_mariadb).strip()
|
||||
if sanitized_item:
|
||||
if sanitized_item := _sanitize_field(item.strip(), self.is_mariadb).strip():
|
||||
parsed = self._parse_single_field_item(sanitized_item)
|
||||
if isinstance(parsed, list): # Result from parsing a child query dict
|
||||
_fields.extend(parsed)
|
||||
|
|
@ -1028,8 +1037,8 @@ class Engine:
|
|||
elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set:
|
||||
allowed_fields.append(field)
|
||||
|
||||
elif isinstance(field, AggregateFunction | PseudoColumnMapper):
|
||||
# Typically functions or complex terms
|
||||
elif isinstance(field, Term):
|
||||
# Allow any Term subclass, like LiteralValue (raw SQL expressions), AggregateFunction, PseudoColumnMapper (functions or complex terms)
|
||||
allowed_fields.append(field)
|
||||
|
||||
return allowed_fields
|
||||
|
|
@ -1555,6 +1564,15 @@ def _validate_select_field(field: str):
|
|||
if field.isdigit():
|
||||
return
|
||||
|
||||
# Reject SQL functions
|
||||
if _is_function_call(field):
|
||||
frappe.throw(
|
||||
_(
|
||||
"SQL functions are not allowed in SELECT fields: {0}. Use the query builder API with functions instead."
|
||||
).format(field),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
if ALLOWED_FIELD_PATTERN.match(field):
|
||||
return
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue