Merge branch 'set_value-refactor' of github.com:gavindsouza/frappe into set_value-refactor
This commit is contained in:
commit
4e5f02aea5
3 changed files with 88 additions and 25 deletions
|
|
@ -4,33 +4,64 @@ from pypika.terms import Function, ValueWrapper
|
|||
from pypika.utils import format_alias_sql
|
||||
|
||||
|
||||
class NamedParameterWrapper():
|
||||
def __init__(self, parameters: Dict[str, Any]):
|
||||
self.parameters = parameters
|
||||
class NamedParameterWrapper:
|
||||
"""Utility class to hold parameter values and keys"""
|
||||
|
||||
def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
|
||||
def __init__(self) -> None:
|
||||
self.parameters = {}
|
||||
|
||||
def get_sql(self, param_value: Any, **kwargs) -> str:
|
||||
"""returns SQL for a parameter, while adding the real value in a dict
|
||||
|
||||
Args:
|
||||
param_value (Any): Value of the parameter
|
||||
|
||||
Returns:
|
||||
str: parameter used in the SQL query
|
||||
"""
|
||||
param_key = f"%(param{len(self.parameters) + 1})s"
|
||||
self.parameters[param_key[2:-2]] = param_value
|
||||
return param_key
|
||||
|
||||
def get_sql(self, **kwargs):
|
||||
return f'%(param{len(self.parameters) + 1})s'
|
||||
def get_parameters(self) -> Dict[str, Any]:
|
||||
"""get dict with parameters and values
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: parameter dict
|
||||
"""
|
||||
return self.parameters
|
||||
|
||||
|
||||
class ParameterizedValueWrapper(ValueWrapper):
|
||||
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str:
|
||||
"""
|
||||
Class to monkey patch ValueWrapper
|
||||
|
||||
Adds functionality to parameterize queries when a `param wrapper` is passed in get_sql()
|
||||
"""
|
||||
|
||||
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper: Optional[NamedParameterWrapper] = None, **kwargs: Any) -> str:
|
||||
if param_wrapper is None:
|
||||
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
|
||||
sql = self.get_value_sql(
|
||||
quote_char=quote_char,
|
||||
secondary_quote_char=secondary_quote_char,
|
||||
**kwargs,
|
||||
)
|
||||
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
|
||||
else:
|
||||
value_sql = self.value
|
||||
if isinstance(self.value, str):
|
||||
# add quotes if it's a string value
|
||||
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
|
||||
else:
|
||||
value_sql = self.value
|
||||
param_sql = param_wrapper.get_sql(**kwargs)
|
||||
param_wrapper.update_parameters(param_key=param_sql, param_value=value_sql, **kwargs)
|
||||
param_sql = param_wrapper.get_sql(param_value=value_sql, **kwargs)
|
||||
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
|
||||
|
||||
|
||||
class ParameterizedFunction(Function):
|
||||
"""
|
||||
Class to monkey patch pypika.terms.Functions
|
||||
|
||||
Only to pass `param_wrapper` in `get_function_sql`.
|
||||
"""
|
||||
def get_sql(self, **kwargs: Any) -> str:
|
||||
with_alias = kwargs.pop("with_alias", False)
|
||||
with_namespace = kwargs.pop("with_namespace", False)
|
||||
|
|
@ -38,15 +69,24 @@ class ParameterizedFunction(Function):
|
|||
dialect = kwargs.pop("dialect", None)
|
||||
param_wrapper = kwargs.pop("param_wrapper", None)
|
||||
|
||||
function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, param_wrapper=param_wrapper, dialect=dialect)
|
||||
function_sql = self.get_function_sql(
|
||||
with_namespace=with_namespace,
|
||||
quote_char=quote_char,
|
||||
param_wrapper=param_wrapper,
|
||||
dialect=dialect,
|
||||
)
|
||||
|
||||
if self.schema is not None:
|
||||
function_sql = "{schema}.{function}".format(
|
||||
schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs),
|
||||
schema=self.schema.get_sql(
|
||||
quote_char=quote_char, dialect=dialect, **kwargs
|
||||
),
|
||||
function=function_sql,
|
||||
)
|
||||
|
||||
if with_alias:
|
||||
return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs)
|
||||
return format_alias_sql(
|
||||
function_sql, self.alias, quote_char=quote_char, **kwargs
|
||||
)
|
||||
|
||||
return function_sql
|
||||
|
|
|
|||
|
|
@ -59,11 +59,11 @@ def patch_query_execute():
|
|||
return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep
|
||||
|
||||
def prepare_query(query):
|
||||
params = {}
|
||||
query = query.get_sql(param_wrapper=NamedParameterWrapper(params))
|
||||
param_collector = NamedParameterWrapper()
|
||||
query = query.get_sql(param_wrapper=param_collector)
|
||||
if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"):
|
||||
raise frappe.PermissionError('Only SELECT SQL allowed in scripting')
|
||||
return query, params
|
||||
return query, param_collector.get_parameters()
|
||||
|
||||
query_class = get_attr(str(frappe.qb).split("'")[1])
|
||||
builder_class = get_type_hints(query_class._builder).get('return')
|
||||
|
|
|
|||
|
|
@ -66,24 +66,47 @@ class TestBuilderBase(object):
|
|||
self.assertIsInstance(query.run, Callable)
|
||||
self.assertIsInstance(data, list)
|
||||
|
||||
def test_walk(self):
|
||||
|
||||
class TestParameterization(unittest.TestCase):
|
||||
def test_where_conditions(self):
|
||||
DocType = frappe.qb.DocType("DocType")
|
||||
query = (
|
||||
frappe.qb.from_(DocType)
|
||||
.select(DocType.name)
|
||||
.where(
|
||||
(DocType.owner == "Administrator' --")
|
||||
& (Coalesce(DocType.search_fields == "subject"))
|
||||
)
|
||||
.where((DocType.owner == "Administrator' --"))
|
||||
)
|
||||
self.assertTrue("walk" in dir(query))
|
||||
query, params = query.walk()
|
||||
|
||||
self.assertIn("%(param1)s", query)
|
||||
self.assertIn("%(param2)s", query)
|
||||
self.assertIn("param1", params)
|
||||
self.assertEqual(params["param1"], "Administrator' --")
|
||||
self.assertEqual(params["param2"], "subject")
|
||||
|
||||
def test_set_cnoditions(self):
|
||||
DocType = frappe.qb.DocType("DocType")
|
||||
query = frappe.qb.update(DocType).set(DocType.value, "some_value")
|
||||
|
||||
self.assertTrue("walk" in dir(query))
|
||||
query, params = query.walk()
|
||||
|
||||
self.assertIn("%(param1)s", query)
|
||||
self.assertIn("param1", params)
|
||||
self.assertEqual(params["param1"], "some_value")
|
||||
|
||||
def test_where_conditions_functions(self):
|
||||
DocType = frappe.qb.DocType("DocType")
|
||||
query = (
|
||||
frappe.qb.from_(DocType)
|
||||
.select(DocType.name)
|
||||
.where(Coalesce(DocType.search_fields == "subject"))
|
||||
)
|
||||
|
||||
self.assertTrue("walk" in dir(query))
|
||||
query, params = query.walk()
|
||||
|
||||
self.assertIn("%(param1)s", query)
|
||||
self.assertIn("param1", params)
|
||||
self.assertEqual(params["param1"], "subject")
|
||||
|
||||
|
||||
@run_only_if(db_type_is.MARIADB)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue