feat: Added safe_qb for server scripts

This commit is contained in:
Aradhya-Tripathi 2021-10-05 12:39:22 +05:30
parent 22876d9803
commit 9c00a28869

View file

@ -14,6 +14,8 @@ from frappe.www.printview import get_visible_columns
import frappe.exceptions
import frappe.integrations.utils
from frappe.frappeclient import FrappeClient
from frappe.query_builder.utils import get_attr
from typing import get_type_hints
class ServerScriptNotEnabled(frappe.PermissionError):
pass
@ -29,6 +31,13 @@ class NamespaceDict(frappe._dict):
return ret
query_class = get_attr(str(frappe.qb).split("'")[1])
class SafeQb(query_class):
def __init__(self, *args, **kwargs):
_builder = get_type_hints(super()._builder).get('return')
_builder.run = read_sql
def safe_exec(script, _globals=None, _locals=None):
# server scripts can be disabled via site_config.json
# they are enabled by default
@ -52,6 +61,7 @@ def safe_exec(script, _globals=None, _locals=None):
def get_safe_globals():
datautils = frappe._dict()
safe_qb = SafeQb()
if frappe.db:
date_format = frappe.db.get_default("date_format") or "yyyy-mm-dd"
time_format = frappe.db.get_default("time_format") or "HH:mm:ss"
@ -69,8 +79,8 @@ def get_safe_globals():
out = NamespaceDict(
# make available limited methods of frappe
json=NamespaceDict(
loads = json.loads,
dumps = json.dumps),
loads=json.loads,
dumps=json.dumps),
dict=dict,
log=frappe.log,
_dict=frappe._dict,
@ -85,6 +95,7 @@ def get_safe_globals():
bold=frappe.bold,
copy_doc=frappe.copy_doc,
errprint=frappe.errprint,
qb=safe_qb,
get_meta=frappe.get_meta,
get_doc=frappe.get_doc,
@ -99,9 +110,9 @@ def get_safe_globals():
render_template=frappe.render_template,
msgprint=frappe.msgprint,
throw=frappe.throw,
sendmail = frappe.sendmail,
get_print = frappe.get_print,
attach_print = frappe.attach_print,
sendmail=frappe.sendmail,
get_print=frappe.get_print,
attach_print=frappe.attach_print,
user=user,
get_fullname=frappe.utils.get_fullname,
@ -112,8 +123,8 @@ def get_safe_globals():
user=user,
csrf_token=frappe.local.session.data.csrf_token if getattr(frappe.local, "session", None) else ''
),
make_get_request = frappe.integrations.utils.make_get_request,
make_post_request = frappe.integrations.utils.make_post_request,
make_get_request=frappe.integrations.utils.make_get_request,
make_post_request=frappe.integrations.utils.make_post_request,
socketio_port=frappe.conf.socketio_port,
get_hooks=frappe.get_hooks,
sanitize_html=frappe.utils.sanitize_html,
@ -141,19 +152,19 @@ def get_safe_globals():
out.frappe.date_format = date_format
out.frappe.time_format = time_format
out.frappe.db = NamespaceDict(
get_list = frappe.get_list,
get_all = frappe.get_all,
get_value = frappe.db.get_value,
set_value = frappe.db.set_value,
get_single_value = frappe.db.get_single_value,
get_default = frappe.db.get_default,
count = frappe.db.count,
min = frappe.db.min,
max = frappe.db.max,
avg = frappe.db.avg,
sum = frappe.db.sum,
escape = frappe.db.escape,
sql = read_sql
get_list=frappe.get_list,
get_all=frappe.get_all,
get_value=frappe.db.get_value,
set_value=frappe.db.set_value,
get_single_value=frappe.db.get_single_value,
get_default=frappe.db.get_default,
count=frappe.db.count,
min=frappe.db.min,
max=frappe.db.max,
avg=frappe.db.avg,
sum=frappe.db.sum,
escape=frappe.db.escape,
sql=read_sql
)
if frappe.response:
@ -175,6 +186,7 @@ def get_safe_globals():
def read_sql(query, *args, **kwargs):
'''a wrapper for frappe.db.sql to allow reads'''
query = str(query)
if query.strip().split(None, 1)[0].lower() == 'select':
return frappe.db.sql(query, *args, **kwargs)
else: