Merge branch 'set_value-refactor' of github.com:gavindsouza/frappe into set_value-refactor

This commit is contained in:
Gavin D'souza 2022-01-11 13:00:40 +05:30
commit 4e5f02aea5
3 changed files with 88 additions and 25 deletions

View file

@ -4,33 +4,64 @@ from pypika.terms import Function, ValueWrapper
from pypika.utils import format_alias_sql
class NamedParameterWrapper():
def __init__(self, parameters: Dict[str, Any]):
self.parameters = parameters
class NamedParameterWrapper:
"""Utility class to hold parameter values and keys"""
def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
def __init__(self) -> None:
self.parameters = {}
def get_sql(self, param_value: Any, **kwargs) -> str:
"""returns SQL for a parameter, while adding the real value in a dict
Args:
param_value (Any): Value of the parameter
Returns:
str: parameter used in the SQL query
"""
param_key = f"%(param{len(self.parameters) + 1})s"
self.parameters[param_key[2:-2]] = param_value
return param_key
def get_sql(self, **kwargs):
return f'%(param{len(self.parameters) + 1})s'
def get_parameters(self) -> Dict[str, Any]:
"""get dict with parameters and values
Returns:
Dict[str, Any]: parameter dict
"""
return self.parameters
class ParameterizedValueWrapper(ValueWrapper):
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str:
"""
Class to monkey patch ValueWrapper
Adds functionality to parameterize queries when a `param wrapper` is passed in get_sql()
"""
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper: Optional[NamedParameterWrapper] = None, **kwargs: Any) -> str:
if param_wrapper is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
sql = self.get_value_sql(
quote_char=quote_char,
secondary_quote_char=secondary_quote_char,
**kwargs,
)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
else:
value_sql = self.value
if isinstance(self.value, str):
# add quotes if it's a string value
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
else:
value_sql = self.value
param_sql = param_wrapper.get_sql(**kwargs)
param_wrapper.update_parameters(param_key=param_sql, param_value=value_sql, **kwargs)
param_sql = param_wrapper.get_sql(param_value=value_sql, **kwargs)
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
class ParameterizedFunction(Function):
"""
Class to monkey patch pypika.terms.Functions
Only to pass `param_wrapper` in `get_function_sql`.
"""
def get_sql(self, **kwargs: Any) -> str:
with_alias = kwargs.pop("with_alias", False)
with_namespace = kwargs.pop("with_namespace", False)
@ -38,15 +69,24 @@ class ParameterizedFunction(Function):
dialect = kwargs.pop("dialect", None)
param_wrapper = kwargs.pop("param_wrapper", None)
function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, param_wrapper=param_wrapper, dialect=dialect)
function_sql = self.get_function_sql(
with_namespace=with_namespace,
quote_char=quote_char,
param_wrapper=param_wrapper,
dialect=dialect,
)
if self.schema is not None:
function_sql = "{schema}.{function}".format(
schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs),
schema=self.schema.get_sql(
quote_char=quote_char, dialect=dialect, **kwargs
),
function=function_sql,
)
if with_alias:
return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs)
return format_alias_sql(
function_sql, self.alias, quote_char=quote_char, **kwargs
)
return function_sql

View file

@ -59,11 +59,11 @@ def patch_query_execute():
return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep
def prepare_query(query):
params = {}
query = query.get_sql(param_wrapper=NamedParameterWrapper(params))
param_collector = NamedParameterWrapper()
query = query.get_sql(param_wrapper=param_collector)
if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"):
raise frappe.PermissionError('Only SELECT SQL allowed in scripting')
return query, params
return query, param_collector.get_parameters()
query_class = get_attr(str(frappe.qb).split("'")[1])
builder_class = get_type_hints(query_class._builder).get('return')

View file

@ -66,24 +66,47 @@ class TestBuilderBase(object):
self.assertIsInstance(query.run, Callable)
self.assertIsInstance(data, list)
def test_walk(self):
class TestParameterization(unittest.TestCase):
def test_where_conditions(self):
DocType = frappe.qb.DocType("DocType")
query = (
frappe.qb.from_(DocType)
.select(DocType.name)
.where(
(DocType.owner == "Administrator' --")
& (Coalesce(DocType.search_fields == "subject"))
)
.where((DocType.owner == "Administrator' --"))
)
self.assertTrue("walk" in dir(query))
query, params = query.walk()
self.assertIn("%(param1)s", query)
self.assertIn("%(param2)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "Administrator' --")
self.assertEqual(params["param2"], "subject")
def test_set_cnoditions(self):
DocType = frappe.qb.DocType("DocType")
query = frappe.qb.update(DocType).set(DocType.value, "some_value")
self.assertTrue("walk" in dir(query))
query, params = query.walk()
self.assertIn("%(param1)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "some_value")
def test_where_conditions_functions(self):
DocType = frappe.qb.DocType("DocType")
query = (
frappe.qb.from_(DocType)
.select(DocType.name)
.where(Coalesce(DocType.search_fields == "subject"))
)
self.assertTrue("walk" in dir(query))
query, params = query.walk()
self.assertIn("%(param1)s", query)
self.assertIn("param1", params)
self.assertEqual(params["param1"], "subject")
@run_only_if(db_type_is.MARIADB)