Merge pull request #21173 from resilient-tech/safe-filters

fix!: improved filter validation in `Engine.get_query`
This commit is contained in:
Ankush Menat 2023-06-02 11:45:06 +05:30 committed by GitHub
commit a09e29cfa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 18 deletions

View file

@ -826,6 +826,7 @@ class Database:
fields=fields,
distinct=distinct,
limit=limit,
validate_filters=True,
)
if isinstance(fields, str) and fields == "*":
as_dict = True
@ -854,6 +855,7 @@ class Database:
order_by=order_by,
distinct=distinct,
limit=limit,
validate_filters=True,
).run(debug=debug, run=run, as_dict=as_dict, pluck=pluck)
return {}
@ -903,7 +905,12 @@ class Database:
field, val, modified=modified, modified_by=modified_by, update_modified=update_modified
)
query = frappe.qb.get_query(table=dt, filters=dn, update=True)
query = frappe.qb.get_query(
table=dt,
filters=dn,
update=True,
validate_filters=True,
)
if isinstance(dn, str):
frappe.clear_document_cache(dt, dn)
@ -1071,9 +1078,13 @@ class Database:
cache_count = frappe.cache().get_value(f"doctype:count:{dt}")
if cache_count is not None:
return cache_count
count = frappe.qb.get_query(table=dt, filters=filters, fields=Count("*"), distinct=distinct).run(
debug=debug
)[0][0]
count = frappe.qb.get_query(
table=dt,
filters=filters,
fields=Count("*"),
distinct=distinct,
validate_filters=True,
).run(debug=debug)[0][0]
if not filters and cache:
frappe.cache().set_value(f"doctype:count:{dt}", count, expires_in_sec=86400)
return count
@ -1193,7 +1204,12 @@ class Database:
Doctype name can be passed directly, it will be pre-pended with `tab`.
"""
filters = filters or kwargs.get("conditions")
query = frappe.qb.get_query(table=doctype, filters=filters, delete=True)
query = frappe.qb.get_query(
table=doctype,
filters=filters,
delete=True,
validate_filters=True,
)
if "debug" not in kwargs:
kwargs["debug"] = debug
return query.run(**kwargs)

View file

@ -9,6 +9,7 @@ from pypika.queries import QueryBuilder, Table
import frappe
from frappe import _
from frappe.database.operator_map import OPERATOR_MAP
from frappe.database.schema import SPECIAL_CHAR_PATTERN
from frappe.database.utils import DefaultOrderBy, get_doctype_name
from frappe.query_builder import Criterion, Field, Order, functions
from frappe.query_builder.functions import Function, SqlFunctions
@ -44,9 +45,12 @@ class Engine:
update: bool = False,
into: bool = False,
delete: bool = False,
*,
validate_filters: bool = False,
) -> QueryBuilder:
self.is_mariadb = frappe.db.db_type == "mariadb"
self.is_postgres = frappe.db.db_type == "postgres"
self.validate_filters = validate_filters
if isinstance(table, Table):
self.table = table
@ -157,14 +161,16 @@ class Engine:
_value = value
_operator = operator
if isinstance(_field, Field):
if not isinstance(_field, str):
pass
elif dynamic_field := DynamicTableField.parse(field, self.doctype):
elif not self.validate_filters and (
dynamic_field := DynamicTableField.parse(field, self.doctype)
):
# apply implicit join if link field's field is referenced
self.query = dynamic_field.apply_join(self.query)
_field = dynamic_field.field
elif has_function(field):
_field = self.get_function_object(field)
elif self.validate_filters and SPECIAL_CHAR_PATTERN.search(_field):
frappe.throw(_("Invalid filter: {0}").format(_field))
elif not doctype or doctype == self.doctype:
_field = self.table[field]
elif doctype:

View file

@ -202,7 +202,11 @@ def get_cards_for_user(doctype, txt, searchfield, start, page_len, filters):
if txt:
search_conditions = [numberCard[field].like(f"%{txt}%") for field in searchfields]
condition_query = frappe.qb.get_query(doctype, filters=filters)
condition_query = frappe.qb.get_query(
doctype,
filters=filters,
validate_filters=True,
)
return (
condition_query.select(numberCard.name, numberCard.label, numberCard.document_type)

View file

@ -36,7 +36,12 @@ def get_group_by_count(doctype: str, current_filters: str, field: str) -> list[d
ToDo = DocType("ToDo")
User = DocType("User")
count = Count("*").as_("count")
filtered_records = frappe.qb.get_query(doctype, filters=current_filters, fields=["name"])
filtered_records = frappe.qb.get_query(
doctype,
filters=current_filters,
fields=["name"],
validate_filters=True,
)
return (
frappe.qb.from_(ToDo)

View file

@ -218,13 +218,6 @@ class TestQuery(FrappeTestCase):
@run_only_if(db_type_is.MARIADB)
def test_filters(self):
self.assertEqual(
frappe.qb.get_query(
"User", filters={"IfNull(name, " ")": ("<", Now())}, fields=["Max(name)"]
).run(),
frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(),
)
self.assertEqual(
frappe.qb.get_query(
"DocType",
@ -258,6 +251,17 @@ class TestQuery(FrappeTestCase):
),
)
self.assertRaisesRegex(
frappe.ValidationError,
"Invalid filter",
lambda: frappe.qb.get_query(
"DocType",
fields=["name"],
filters={"permissions.role": "System Manager"},
validate_filters=True,
),
)
self.assertEqual(
frappe.qb.get_query(
"DocType",

View file

@ -31,6 +31,7 @@ def get_monthly_results(
Function(aggregation, goal_field),
],
filters=filters,
validate_filters=True,
)
.groupby("month_year")
.run()