refactor: pythonic NamedParameterWrapper
This commit is contained in:
parent
15f3523c24
commit
0f75394720
2 changed files with 11 additions and 8 deletions
|
|
@ -5,18 +5,21 @@ from pypika.utils import format_alias_sql
|
|||
|
||||
|
||||
class NamedParameterWrapper():
|
||||
def __init__(self, parameters: Dict[str, Any]):
|
||||
self.parameters = parameters
|
||||
def __init__(self) -> None:
|
||||
self.parameters={}
|
||||
|
||||
def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
|
||||
def update_parameters(self, param_key: Any, param_value: Any, **kwargs)->None:
|
||||
self.parameters[param_key[2:-2]] = param_value
|
||||
|
||||
def get_sql(self, **kwargs):
|
||||
def get_sql(self, **kwargs)->str:
|
||||
return f'%(param{len(self.parameters) + 1})s'
|
||||
|
||||
def get_parameters(self)->Dict[str, Any]:
|
||||
return self.parameters
|
||||
|
||||
|
||||
class ParameterizedValueWrapper(ValueWrapper):
|
||||
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper= None, **kwargs: Any) -> str:
|
||||
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", param_wrapper:NamedParameterWrapper = None, **kwargs: Any) -> str:
|
||||
if param_wrapper is None:
|
||||
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
|
||||
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
|
||||
|
|
|
|||
|
|
@ -59,11 +59,11 @@ def patch_query_execute():
|
|||
return frappe.db.sql(query, params, *args, **kwargs) # nosemgrep
|
||||
|
||||
def prepare_query(query):
|
||||
params = {}
|
||||
query = query.get_sql(param_wrapper=NamedParameterWrapper(params))
|
||||
param_collector = NamedParameterWrapper()
|
||||
query = query.get_sql(param_wrapper=param_collector)
|
||||
if frappe.flags.in_safe_exec and not query.lower().strip().startswith("select"):
|
||||
raise frappe.PermissionError('Only SELECT SQL allowed in scripting')
|
||||
return query, params
|
||||
return query, param_collector.get_parameters()
|
||||
|
||||
query_class = get_attr(str(frappe.qb).split("'")[1])
|
||||
builder_class = get_type_hints(query_class._builder).get('return')
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue