diff --git a/frappe/query_builder/terms.py b/frappe/query_builder/terms.py index 0e0e2c4800..b0be40e0d2 100644 --- a/frappe/query_builder/terms.py +++ b/frappe/query_builder/terms.py @@ -4,33 +4,64 @@ from pypika.terms import Function, ValueWrapper from pypika.utils import format_alias_sql -class NamedParameterWrapper(): - def __init__(self, parameters: Dict[str, Any]): - self.parameters = parameters +class NamedParameterWrapper: + """Utility class to hold parameter values and keys""" - def update_parameters(self, param_key: Any, param_value: Any, **kwargs): + def __init__(self) -> None: + self.parameters = {} + + def get_sql(self, param_value: Any, **kwargs) -> str: + """returns SQL for a parameter, while adding the real value in a dict + + Args: + param_value (Any): Value of the parameter + + Returns: + str: parameter used in the SQL query + """ + param_key = f"%(param{len(self.parameters) + 1})s" self.parameters[param_key[2:-2]] = param_value + return param_key - def get_sql(self, **kwargs): - return f'%(param{len(self.parameters) + 1})s' + def get_parameters(self) -> Dict[str, Any]: + """get dict with parameters and values + + Returns: + Dict[str, Any]: parameter dict + """ + return self.parameters class ParameterizedValueWrapper(ValueWrapper): - def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str: + """ + Class to monkey patch ValueWrapper + + Adds functionality to parameterize queries when a `param wrapper` is passed in get_sql() + """ + + def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper: Optional[NamedParameterWrapper] = None, **kwargs: Any) -> str: if param_wrapper is None: - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) + sql = self.get_value_sql( + quote_char=quote_char, + secondary_quote_char=secondary_quote_char, + **kwargs, + ) return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) else: + value_sql = self.value if isinstance(self.value, str): + # add quotes if it's a string value value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) - else: - value_sql = self.value - param_sql = param_wrapper.get_sql(**kwargs) - param_wrapper.update_parameters(param_key=param_sql, param_value=value_sql, **kwargs) + param_sql = param_wrapper.get_sql(param_value=value_sql, **kwargs) return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) class ParameterizedFunction(Function): + """ + Class to monkey patch pypika.terms.Functions + + Only to pass `param_wrapper` in `get_function_sql`. + """ def get_sql(self, **kwargs: Any) -> str: with_alias = kwargs.pop("with_alias", False) with_namespace = kwargs.pop("with_namespace", False) @@ -38,15 +69,24 @@ class ParameterizedFunction(Function): dialect = kwargs.pop("dialect", None) param_wrapper = kwargs.pop("param_wrapper", None) - function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, param_wrapper=param_wrapper, dialect=dialect) + function_sql = self.get_function_sql( + with_namespace=with_namespace, + quote_char=quote_char, + param_wrapper=param_wrapper, + dialect=dialect, + ) if self.schema is not None: function_sql = "{schema}.{function}".format( - schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs), + schema=self.schema.get_sql( + quote_char=quote_char, dialect=dialect, **kwargs + ), function=function_sql, ) if with_alias: - return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs) + return format_alias_sql( + function_sql, self.alias, quote_char=quote_char, **kwargs + ) return function_sql diff --git a/frappe/query_builder/utils.py b/frappe/query_builder/utils.py index 7797ce856c..cbd6147e01 100644 --- a/frappe/query_builder/utils.py +++ b/frappe/query_builder/utils.py @@ -59,11 +59,11 @@ def patch_query_execute(): return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep def prepare_query(query): - params = {} - query = query.get_sql(param_wrapper=NamedParameterWrapper(params)) + param_collector = NamedParameterWrapper() + query = query.get_sql(param_wrapper=param_collector) if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"): raise frappe.PermissionError('Only SELECT SQL allowed in scripting') - return query, params + return query, param_collector.get_parameters() query_class = get_attr(str(frappe.qb).split("'")[1]) builder_class = get_type_hints(query_class._builder).get('return') diff --git a/frappe/tests/test_query_builder.py b/frappe/tests/test_query_builder.py index ab1d8292b8..bc98e166ea 100644 --- a/frappe/tests/test_query_builder.py +++ b/frappe/tests/test_query_builder.py @@ -66,24 +66,47 @@ class TestBuilderBase(object): self.assertIsInstance(query.run, Callable) self.assertIsInstance(data, list) - def test_walk(self): + +class TestParameterization(unittest.TestCase): + def test_where_conditions(self): DocType = frappe.qb.DocType("DocType") query = ( frappe.qb.from_(DocType) .select(DocType.name) - .where( - (DocType.owner == "Administrator' --") - & (Coalesce(DocType.search_fields == "subject")) - ) + .where((DocType.owner == "Administrator' --")) ) self.assertTrue("walk" in dir(query)) query, params = query.walk() self.assertIn("%(param1)s", query) - self.assertIn("%(param2)s", query) self.assertIn("param1", params) self.assertEqual(params["param1"], "Administrator' --") - self.assertEqual(params["param2"], "subject") + + def test_set_cnoditions(self): + DocType = frappe.qb.DocType("DocType") + query = frappe.qb.update(DocType).set(DocType.value, "some_value") + + self.assertTrue("walk" in dir(query)) + query, params = query.walk() + + self.assertIn("%(param1)s", query) + self.assertIn("param1", params) + self.assertEqual(params["param1"], "some_value") + + def test_where_conditions_functions(self): + DocType = frappe.qb.DocType("DocType") + query = ( + frappe.qb.from_(DocType) + .select(DocType.name) + .where(Coalesce(DocType.search_fields == "subject")) + ) + + self.assertTrue("walk" in dir(query)) + query, params = query.walk() + + self.assertIn("%(param1)s", query) + self.assertIn("param1", params) + self.assertEqual(params["param1"], "subject") @run_only_if(db_type_is.MARIADB)