fix(sanitize_fields): use sqlparse for function detection

Signed-off-by: Akhil Narang <me@akhilnarang.dev>
This commit is contained in:
Akhil Narang 2025-10-13 15:49:43 +05:30
parent 395af8aa04
commit 984c641bff
No known key found for this signature in database
GPG key ID: 9DCC61E211BF645F
2 changed files with 90 additions and 16 deletions

View file

@ -8,7 +8,11 @@ import json
import re
from collections import Counter
from collections.abc import Mapping, Sequence
from functools import cached_property
from functools import cached_property, lru_cache
import sqlparse
from sqlparse import tokens
from sqlparse.sql import Function, Parenthesis, Statement
import frappe
import frappe.defaults
@ -33,6 +37,22 @@ from frappe.utils import (
)
from frappe.utils.data import DateTimeLikeObject, get_datetime, getdate, sbool
@lru_cache(maxsize=128)
def _parse_sql(field: str) -> Statement | None:
    """
    Return the first statement that `sqlparse` extracts from `field`.

    Results are memoised per input string (up to 128 entries), so repeated calls
    with the same `field` hand back the same `Statement` object; callers should
    treat it as read-only.

    Args:
        field (str): The SQL text to parse.

    Returns:
        Statement | None: The first `sqlparse.sql.Statement` produced by
        `sqlparse.parse(field)`, or `None` when it produces no statements.
        Any further statements in the parse result are not returned.
    """
    statements = sqlparse.parse(field)
    return statements[0] if statements else None
# Matches a `locate(<first argument>, name)` call, case-insensitively. The first
# argument is any run of non-comma characters; `name` may be wrapped in backticks
# or double quotes.
LOCATE_PATTERN = re.compile(r"locate\([^,]+,\s*[`\"]?name[`\"]?\s*\)", flags=re.IGNORECASE)
# Same shape as LOCATE_PATTERN, but captures the first argument (group 1) and the
# optionally quoted `name` operand (group 2).
LOCATE_CAST_PATTERN = re.compile(r"locate\(([^,]+),\s*([`\"]?name[`\"]?)\s*\)", flags=re.IGNORECASE)
# Matches the opening of a `strpos(`, `ifnull(` or `coalesce(` call whose first
# argument is the optionally quoted `name`, up to and including the following
# comma; group 1 is the function name.
FUNC_IFNULL_PATTERN = re.compile(r"(strpos|ifnull|coalesce)\(\s*[`\"]?name[`\"]?\s*,", flags=re.IGNORECASE)
@ -456,6 +476,46 @@ from {tables}
"sleep",
]
def _find_subqueries(parsed: Statement) -> list:
    """
    Collect every parenthesised group in `parsed` that looks like a subquery.

    A `Parenthesis` counts as a subquery when one of its direct child tokens has
    the DML token type. The whole token tree is walked, so subqueries nested
    inside other groups are reported too, each parent before its descendants.

    Args:
        parsed (Statement): A parsed statement, or any token group exposing `.tokens`.

    Returns:
        list: The matching `Parenthesis` tokens; empty when none are found.
    """
    found = []
    for node in parsed.tokens:
        in_parens = isinstance(node, Parenthesis)
        # Only the direct children are inspected here; deeper levels are
        # handled by the recursive call below.
        if in_parens and any(child.ttype is tokens.DML for child in node.tokens):
            found.append(node)
        if in_parens or node.is_group:
            found.extend(_find_subqueries(node))
    return found
def _check_sql_token(statement: Statement | None) -> None:
    """
    Inspect a parsed statement and reject blocked functions, keywords and subqueries.

    `_raise_exception()` is invoked when any of the following is found anywhere
    in the token tree:

    * a subquery, as reported by `_find_subqueries`;
    * a `Function` token whose name (lower-cased) is in `blacklisted_functions`;
    * a token of type `tokens.Keyword` whose value (lower-cased) is in
      `blacklisted_keywords`.

    `blacklisted_functions` and `blacklisted_keywords` come from the enclosing
    scope.

    Args:
        statement (Statement | None): The result of `_parse_sql`. `None` (nothing
            was parsed) is accepted and treated as having nothing to check.
    """
    if statement is None:
        # `_parse_sql` returns None when sqlparse yields no statement; without
        # this guard the attribute access below would fail with AttributeError.
        return

    # `_find_subqueries` already walks the complete tree, so it is run once here
    # rather than again for every nested group.
    if _find_subqueries(statement):
        _raise_exception()

    # Visit every group in the tree; order does not matter because each
    # violation is reported the same way.
    pending = [statement]
    while pending:
        group = pending.pop()
        for token in group.tokens:
            if isinstance(token, Function):
                name = token.get_name()
                if name and name.lower() in blacklisted_functions:
                    _raise_exception()

            # Compared by equality against the generic keyword token type.
            if token.ttype == tokens.Keyword and token.value.lower() in blacklisted_keywords:
                _raise_exception()

            if token.is_group:
                pending.append(token)
def _raise_exception():
    """Report a restricted sub-query or function via `frappe.throw`, using `frappe.DataError`."""
    message = _("Use of sub-query or function is restricted")
    frappe.throw(message, frappe.DataError)
@ -470,21 +530,8 @@ from {tables}
lower_field = field.lower().strip()
if SUB_QUERY_PATTERN.match(field):
# Check for subquery anywhere in the field, not just at the beginning
if "(" in lower_field:
# Check all parentheses pairs, not just the first one
paren_start = 0
while True:
location = lower_field.find("(", paren_start)
if location == -1:
break
token = lower_field[location + 1 :].lstrip().split(" ", 1)[0]
if any(
re.search(r"\b" + re.escape(keyword) + r"\b", token)
for keyword in blacklisted_keywords + blacklisted_functions
):
_raise_exception()
paren_start = location + 1
# Check all tokens for subquery detection
_check_sql_token(_parse_sql(field))
if "@" in lower_field:
# prevent access to global variables

View file

@ -489,6 +489,23 @@ class TestDBQuery(IntegrationTestCase):
)
self.assertTrue("_relevance" in data[0])
# Test that fields with keywords in strings are allowed
data = DatabaseQuery("DocType").execute(
fields=["name", "locate('select', name)"],
limit_start=0,
limit_page_length=1,
)
self.assertTrue(data)
# Test that subqueries with other DML are blocked
self.assertRaises(
frappe.DataError,
DatabaseQuery("DocType").execute,
fields=["name", "issingle", "(insert into tabUser values (1))"],
limit_start=0,
limit_page_length=1,
)
data = DatabaseQuery("DocType").execute(
fields=["name", "issingle", "date(creation) as creation"],
limit_start=0,
@ -554,6 +571,16 @@ class TestDBQuery(IntegrationTestCase):
limit_page_length=1,
)
# Ensure search terms aren't blocked as functions
from frappe.desk.search import search_link
search_terms = ("global", "user")
for term in search_terms:
with self.subTest(term=term):
result = search_link("ToDo", term)
self.assertIsInstance(result, list)
def test_nested_permission(self):
frappe.set_user("Administrator")
create_nested_doctype()