diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index e96e032784..bf3f09acac 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -8,7 +8,11 @@ import json import re from collections import Counter from collections.abc import Mapping, Sequence -from functools import cached_property +from functools import cached_property, lru_cache + +import sqlparse +from sqlparse import tokens +from sqlparse.sql import Function, Parenthesis, Statement import frappe import frappe.defaults @@ -33,6 +37,22 @@ from frappe.utils import ( ) from frappe.utils.data import DateTimeLikeObject, get_datetime, getdate, sbool + +@lru_cache(maxsize=128) +def _parse_sql(field: str) -> Statement | None: + """ + Parse a given SQL statement using `sqlparse`. + + Args: + field (str): The SQL statement string to parse. + + Returns: + Statement | None: A `sqlparse.sql.Statement` object if parsing succeeds, otherwise `None`. + """ + if parsed := sqlparse.parse(field): + return parsed[0] + + LOCATE_PATTERN = re.compile(r"locate\([^,]+,\s*[`\"]?name[`\"]?\s*\)", flags=re.IGNORECASE) LOCATE_CAST_PATTERN = re.compile(r"locate\(([^,]+),\s*([`\"]?name[`\"]?)\s*\)", flags=re.IGNORECASE) FUNC_IFNULL_PATTERN = re.compile(r"(strpos|ifnull|coalesce)\(\s*[`\"]?name[`\"]?\s*,", flags=re.IGNORECASE) @@ -456,6 +476,46 @@ from {tables} "sleep", ] + def _find_subqueries(parsed: Statement) -> list: + """ + Recursively find all subqueries in a parsed SQL statement. + """ + subqueries = [] + + for token in parsed.tokens: + if isinstance(token, Parenthesis): + # Check for DML token for subquery check + is_subquery = False + for sub_token in token.tokens: + if sub_token.ttype is tokens.DML: + is_subquery = True + break + if is_subquery: + subqueries.append(token) + # Recursively check for nested subqueries + subqueries.extend(_find_subqueries(token)) + elif token.is_group: + subqueries.extend(_find_subqueries(token)) + + return subqueries + + def _check_sql_token(statement: Statement) -> None: + """ + Checks the output of `sqlparse.parse()` to detect blocked functions and subqueries. + """ + if _find_subqueries(statement): + _raise_exception() + + for token in statement.tokens: + if isinstance(token, Function): + if (name := (token.get_name())) and name.lower() in blacklisted_functions: + _raise_exception() + if token.ttype == tokens.Keyword: + if token.value.lower() in blacklisted_keywords: + _raise_exception() + if token.is_group: + _check_sql_token(token) + def _raise_exception(): frappe.throw(_("Use of sub-query or function is restricted"), frappe.DataError) @@ -470,21 +530,8 @@ from {tables} lower_field = field.lower().strip() if SUB_QUERY_PATTERN.match(field): - # Check for subquery anywhere in the field, not just at the beginning - if "(" in lower_field: - # Check all parentheses pairs, not just the first one - paren_start = 0 - while True: - location = lower_field.find("(", paren_start) - if location == -1: - break - token = lower_field[location + 1 :].lstrip().split(" ", 1)[0] - if any( - re.search(r"\b" + re.escape(keyword) + r"\b", token) - for keyword in blacklisted_keywords + blacklisted_functions - ): - _raise_exception() - paren_start = location + 1 + # Check all tokens for subquery detection + _check_sql_token(_parse_sql(field)) if "@" in lower_field: # prevent access to global variables diff --git a/frappe/tests/test_db_query.py b/frappe/tests/test_db_query.py index 26306dab41..3f196d7214 100644 --- a/frappe/tests/test_db_query.py +++ b/frappe/tests/test_db_query.py @@ -489,6 +489,23 @@ class TestDBQuery(IntegrationTestCase): ) self.assertTrue("_relevance" in data[0]) + # Test that fields with keywords in strings are allowed + data = DatabaseQuery("DocType").execute( + fields=["name", "locate('select', name)"], + limit_start=0, + limit_page_length=1, + ) + self.assertTrue(data) + + # Test that subqueries with other DML are blocked + self.assertRaises( + frappe.DataError, + DatabaseQuery("DocType").execute, + fields=["name", "issingle", "(insert into tabUser values (1))"], + limit_start=0, + limit_page_length=1, + ) + data = DatabaseQuery("DocType").execute( fields=["name", "issingle", "date(creation) as creation"], limit_start=0, @@ -554,6 +571,16 @@ class TestDBQuery(IntegrationTestCase): limit_page_length=1, ) + # Ensure search terms aren't blocked as functions + from frappe.desk.search import search_link + + search_terms = ("global", "user") + + for term in search_terms: + with self.subTest(term=term): + result = search_link("ToDo", term) + self.assertIsInstance(result, list) + def test_nested_permission(self): frappe.set_user("Administrator") create_nested_doctype()