diff --git a/frappe/utils/safe_exec.py b/frappe/utils/safe_exec.py index 0d1574b49a..a494df0e74 100644 --- a/frappe/utils/safe_exec.py +++ b/frappe/utils/safe_exec.py @@ -14,6 +14,8 @@ from frappe.www.printview import get_visible_columns import frappe.exceptions import frappe.integrations.utils from frappe.frappeclient import FrappeClient +from frappe.query_builder.utils import get_attr +from typing import get_type_hints class ServerScriptNotEnabled(frappe.PermissionError): pass @@ -29,6 +31,13 @@ class NamespaceDict(frappe._dict): return ret +query_class = get_attr(str(frappe.qb).split("'")[1]) +class SafeQb(query_class): + def __init__(self, *args, **kwargs): + _builder = get_type_hints(super()._builder).get('return') + _builder.run = read_sql + + def safe_exec(script, _globals=None, _locals=None): # server scripts can be disabled via site_config.json # they are enabled by default @@ -52,6 +61,7 @@ def safe_exec(script, _globals=None, _locals=None): def get_safe_globals(): datautils = frappe._dict() + safe_qb = SafeQb() if frappe.db: date_format = frappe.db.get_default("date_format") or "yyyy-mm-dd" time_format = frappe.db.get_default("time_format") or "HH:mm:ss" @@ -69,8 +79,8 @@ def get_safe_globals(): out = NamespaceDict( # make available limited methods of frappe json=NamespaceDict( - loads = json.loads, - dumps = json.dumps), + loads=json.loads, + dumps=json.dumps), dict=dict, log=frappe.log, _dict=frappe._dict, @@ -85,6 +95,7 @@ def get_safe_globals(): bold=frappe.bold, copy_doc=frappe.copy_doc, errprint=frappe.errprint, + qb=safe_qb, get_meta=frappe.get_meta, get_doc=frappe.get_doc, @@ -99,9 +110,9 @@ def get_safe_globals(): render_template=frappe.render_template, msgprint=frappe.msgprint, throw=frappe.throw, - sendmail = frappe.sendmail, - get_print = frappe.get_print, - attach_print = frappe.attach_print, + sendmail=frappe.sendmail, + get_print=frappe.get_print, + attach_print=frappe.attach_print, user=user, get_fullname=frappe.utils.get_fullname, @@ -112,8 +123,8 @@ def get_safe_globals(): user=user, csrf_token=frappe.local.session.data.csrf_token if getattr(frappe.local, "session", None) else '' ), - make_get_request = frappe.integrations.utils.make_get_request, - make_post_request = frappe.integrations.utils.make_post_request, + make_get_request=frappe.integrations.utils.make_get_request, + make_post_request=frappe.integrations.utils.make_post_request, socketio_port=frappe.conf.socketio_port, get_hooks=frappe.get_hooks, sanitize_html=frappe.utils.sanitize_html, @@ -141,19 +152,19 @@ def get_safe_globals(): out.frappe.date_format = date_format out.frappe.time_format = time_format out.frappe.db = NamespaceDict( - get_list = frappe.get_list, - get_all = frappe.get_all, - get_value = frappe.db.get_value, - set_value = frappe.db.set_value, - get_single_value = frappe.db.get_single_value, - get_default = frappe.db.get_default, - count = frappe.db.count, - min = frappe.db.min, - max = frappe.db.max, - avg = frappe.db.avg, - sum = frappe.db.sum, - escape = frappe.db.escape, - sql = read_sql + get_list=frappe.get_list, + get_all=frappe.get_all, + get_value=frappe.db.get_value, + set_value=frappe.db.set_value, + get_single_value=frappe.db.get_single_value, + get_default=frappe.db.get_default, + count=frappe.db.count, + min=frappe.db.min, + max=frappe.db.max, + avg=frappe.db.avg, + sum=frappe.db.sum, + escape=frappe.db.escape, + sql=read_sql ) if frappe.response: @@ -175,6 +186,7 @@ def get_safe_globals(): def read_sql(query, *args, **kwargs): '''a wrapper for frappe.db.sql to allow reads''' + query = str(query) if query.strip().split(None, 1)[0].lower() == 'select': return frappe.db.sql(query, *args, **kwargs) else: