fix(sanitize_fields): use sqlparse for function detection
Signed-off-by: Akhil Narang <me@akhilnarang.dev>
This commit is contained in:
parent
395af8aa04
commit
984c641bff
2 changed files with 90 additions and 16 deletions
|
|
@ -8,7 +8,11 @@ import json
|
|||
import re
|
||||
from collections import Counter
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from functools import cached_property, lru_cache
|
||||
|
||||
import sqlparse
|
||||
from sqlparse import tokens
|
||||
from sqlparse.sql import Function, Parenthesis, Statement
|
||||
|
||||
import frappe
|
||||
import frappe.defaults
|
||||
|
|
@ -33,6 +37,22 @@ from frappe.utils import (
|
|||
)
|
||||
from frappe.utils.data import DateTimeLikeObject, get_datetime, getdate, sbool
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
def _parse_sql(field: str) -> Statement | None:
    """
    Run `sqlparse.parse` on `field` and hand back the first statement.

    Results are memoised per input string by `lru_cache` (at most 128
    entries), so repeated calls with the same text return the same object.

    NOTE(review): only element 0 of the parse result is returned; if the
    input splits into several statements, the later ones are not visible to
    callers of this helper - confirm that is acceptable.

    Args:
        field (str): The SQL text to parse.

    Returns:
        Statement | None: The first `sqlparse.sql.Statement` produced for
        `field`, or `None` when `sqlparse.parse` yields nothing.
    """
    statements = sqlparse.parse(field)
    if not statements:
        return None
    return statements[0]
|
||||
|
||||
|
||||
LOCATE_PATTERN = re.compile(r"locate\([^,]+,\s*[`\"]?name[`\"]?\s*\)", flags=re.IGNORECASE)
|
||||
LOCATE_CAST_PATTERN = re.compile(r"locate\(([^,]+),\s*([`\"]?name[`\"]?)\s*\)", flags=re.IGNORECASE)
|
||||
FUNC_IFNULL_PATTERN = re.compile(r"(strpos|ifnull|coalesce)\(\s*[`\"]?name[`\"]?\s*,", flags=re.IGNORECASE)
|
||||
|
|
@ -456,6 +476,46 @@ from {tables}
|
|||
"sleep",
|
||||
]
|
||||
|
||||
def _find_subqueries(parsed: Statement) -> list:
    """
    Collect every parenthesised sub-query found anywhere below `parsed`.

    A `Parenthesis` token counts as a sub-query when one of its direct child
    tokens has the `tokens.DML` type. Every group token (parenthesised or
    not) is searched recursively so nested sub-queries are collected too.

    Args:
        parsed: A parsed statement or any grouped token exposing `.tokens`.

    Returns:
        list: The matching `Parenthesis` tokens; each match is listed before
        any matches nested beneath it. Empty when none are found.
    """
    found = []

    for node in parsed.tokens:
        in_parens = isinstance(node, Parenthesis)

        # A DML token directly inside the parentheses marks a sub-query.
        if in_parens and any(child.ttype is tokens.DML for child in node.tokens):
            found.append(node)

        # Descend into every grouped token, whether or not it matched above.
        if in_parens or node.is_group:
            found.extend(_find_subqueries(node))

    return found
|
||||
|
||||
def _check_sql_token(statement: Statement | None) -> None:
    """
    Walk a parsed SQL fragment and reject blocked functions, keywords and sub-queries.

    The check calls `_raise_exception()` when any of the following is found:

    - `_find_subqueries` reports at least one sub-query,
    - a `Function` token whose name (case-insensitive) is in `blacklisted_functions`,
    - a token of type `tokens.Keyword` whose value (case-insensitive) is in
      `blacklisted_keywords`.

    Grouped tokens are checked recursively with this same function.

    Args:
        statement: The output of `sqlparse.parse()` (a statement or any grouped
            token), or `None` when there was nothing to parse. `None` is
            treated as "nothing to check" instead of failing with an
            `AttributeError` on `.tokens`.

    Returns:
        None
    """
    # An empty field parses to no statement at all; there is nothing to inspect.
    if statement is None:
        return

    if _find_subqueries(statement):
        _raise_exception()

    for token in statement.tokens:
        if isinstance(token, Function):
            # get_name() may be falsy, so guard before lower-casing.
            if (name := token.get_name()) and name.lower() in blacklisted_functions:
                _raise_exception()
        # NOTE(review): this is an exact token-type comparison; sqlparse
        # sub-types such as Keyword.DML presumably do not match here - confirm
        # that is intended for the entries in `blacklisted_keywords`.
        if token.ttype == tokens.Keyword:
            if token.value.lower() in blacklisted_keywords:
                _raise_exception()
        if token.is_group:
            _check_sql_token(token)
|
||||
|
||||
def _raise_exception():
    """
    Report a restricted sub-query or function through `frappe.throw`.

    The translated message is passed to `frappe.throw` together with
    `frappe.DataError` as the exception class.
    """
    message = _("Use of sub-query or function is restricted")
    frappe.throw(message, frappe.DataError)
|
||||
|
||||
|
|
@ -470,21 +530,8 @@ from {tables}
|
|||
lower_field = field.lower().strip()
|
||||
|
||||
if SUB_QUERY_PATTERN.match(field):
|
||||
# Check for subquery anywhere in the field, not just at the beginning
|
||||
if "(" in lower_field:
|
||||
# Check all parentheses pairs, not just the first one
|
||||
paren_start = 0
|
||||
while True:
|
||||
location = lower_field.find("(", paren_start)
|
||||
if location == -1:
|
||||
break
|
||||
token = lower_field[location + 1 :].lstrip().split(" ", 1)[0]
|
||||
if any(
|
||||
re.search(r"\b" + re.escape(keyword) + r"\b", token)
|
||||
for keyword in blacklisted_keywords + blacklisted_functions
|
||||
):
|
||||
_raise_exception()
|
||||
paren_start = location + 1
|
||||
# Check all tokens for subquery detection
|
||||
_check_sql_token(_parse_sql(field))
|
||||
|
||||
if "@" in lower_field:
|
||||
# prevent access to global variables
|
||||
|
|
|
|||
|
|
@ -489,6 +489,23 @@ class TestDBQuery(IntegrationTestCase):
|
|||
)
|
||||
self.assertTrue("_relevance" in data[0])
|
||||
|
||||
# Test that fields with keywords in strings are allowed
|
||||
data = DatabaseQuery("DocType").execute(
|
||||
fields=["name", "locate('select', name)"],
|
||||
limit_start=0,
|
||||
limit_page_length=1,
|
||||
)
|
||||
self.assertTrue(data)
|
||||
|
||||
# Test that subqueries with other DML are blocked
|
||||
self.assertRaises(
|
||||
frappe.DataError,
|
||||
DatabaseQuery("DocType").execute,
|
||||
fields=["name", "issingle", "(insert into tabUser values (1))"],
|
||||
limit_start=0,
|
||||
limit_page_length=1,
|
||||
)
|
||||
|
||||
data = DatabaseQuery("DocType").execute(
|
||||
fields=["name", "issingle", "date(creation) as creation"],
|
||||
limit_start=0,
|
||||
|
|
@ -554,6 +571,16 @@ class TestDBQuery(IntegrationTestCase):
|
|||
limit_page_length=1,
|
||||
)
|
||||
|
||||
# Ensure search terms aren't blocked as functions
|
||||
from frappe.desk.search import search_link
|
||||
|
||||
search_terms = ("global", "user")
|
||||
|
||||
for term in search_terms:
|
||||
with self.subTest(term=term):
|
||||
result = search_link("ToDo", term)
|
||||
self.assertIsInstance(result, list)
|
||||
|
||||
def test_nested_permission(self):
|
||||
frappe.set_user("Administrator")
|
||||
create_nested_doctype()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue