fix: stripping comments sent to the database API (#18229)

* feat: stripping comments sent to the database API

* test: Added tests for comment stripping

* fix: only stripping comments in string fields

* refactor: removing on the fly mutations

* refactor: added helper to avoid mutations

* refactor: simplify sanitization

* refactor: removing indexing from everywhere

* refactor: readable functions

* test: only run mdb test on mdb

Co-authored-by: Ankush Menat <ankush@frappe.io>
This commit is contained in:
Aradhya Tripathi 2022-10-12 16:36:55 +05:30 committed by GitHub
parent 1da6ca6731
commit 457de5c6b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 40 deletions

View file

@ -5,6 +5,7 @@ from functools import cached_property
from types import BuiltinFunctionType
from typing import TYPE_CHECKING, Callable
import sqlparse
from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder
import frappe
@ -14,6 +15,7 @@ from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumn
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT
if TYPE_CHECKING:
from frappe.query_builder import DocType
@ -492,7 +494,8 @@ class Engine:
def get_fieldnames_from_child_table(self, doctype, fields):
# Hacky and flaky implementation of implicit joins.
# convert child_table.fieldname to `tabChild DocType`.`fieldname`
for idx, field in enumerate(fields, start=0):
_fields = []
for field in fields:
if "." in field and "tab" not in field:
alias = None
if " as " in field:
@ -506,12 +509,63 @@ class Engine:
field = f"`tab{self.linked_doctype}`.`{linked_fieldname}`"
if alias:
field = f"{field} as {alias}"
fields[idx] = field
_fields.append(field)
return _fields
def sanitize_fields(self, fields: str | list | tuple):
is_mariadb = frappe.db.db_type == "mariadb"
def _sanitize_field(field: str):
if not isinstance(field, str):
return field
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
if is_mariadb:
return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
return stripped_field
if isinstance(fields, (list, tuple)):
return [_sanitize_field(field) for field in fields]
elif isinstance(fields, str):
return _sanitize_field(fields)
return fields
def set_fields(self, table, fields, **kwargs) -> list:
def get_list_fields(self, fields: list) -> list:
updated_fields = []
if issubclass(type(fields), Criterion) or "*" in fields:
return fields
# fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields)
for field in fields:
if not isinstance(field, Criterion) and field:
if " as " in field:
field, reference = field.split(" as ")
if "`" in field:
updated_fields.append(PseudoColumn(f"{field} as {reference}"))
else:
updated_fields.append(Field(field.strip()).as_(reference))
elif "`" in str(field):
updated_fields.append(PseudoColumn(field.strip()))
else:
updated_fields.append(Field(field))
return updated_fields
def get_string_fields(self, fields: str) -> Field:
if fields == "*":
return fields
if "`" in fields:
fields = PseudoColumn(fields)
if " as " in str(fields):
fields, reference = str(fields).split(" as ")
if "`" in str(fields):
fields = PseudoColumn(f"{fields} as {reference}")
else:
fields = Field(fields).as_(reference)
return fields
def set_fields(self, fields, **kwargs) -> list:
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
fields = self.sanitize_fields(fields)
if isinstance(fields, list) and None in fields and Field not in fields:
return None
function_objects = []
@ -535,39 +589,9 @@ class Engine:
is_list, is_str = True, False
if is_str:
if fields == "*":
return fields
if "`" in fields:
fields = PseudoColumn(fields)
if " as " in str(fields):
fields, reference = str(fields).split(" as ")
if "`" in str(fields):
fields = PseudoColumn(f"{fields} as {reference}")
else:
fields = Field(fields).as_(reference)
fields = self.get_string_fields(fields)
if not is_str and fields:
if issubclass(type(fields), Criterion):
return fields
updated_fields = []
if "*" in fields:
return fields
# fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields)
for field in fields:
if not isinstance(field, Criterion) and field:
if " as " in field:
field, reference = field.split(" as ")
if "`" in field:
updated_fields.append(PseudoColumn(f"{field} as {reference}"))
else:
updated_fields.append(Field(field.strip()).as_(reference))
elif "`" in str(field):
updated_fields.append(PseudoColumn(field.strip()))
else:
updated_fields.append(Field(field))
fields = updated_fields
fields = self.get_list_fields(fields)
# Need to check instance again since fields modified.
if not isinstance(fields, (list, tuple, set)):
@ -599,15 +623,17 @@ class Engine:
has_join = True
if has_join:
for idx, field in enumerate(fields):
def _update_pypika_fields(field):
if not is_pypika_function_object(field):
field = field if isinstance(field, str) else field.get_sql()
if not TABLE_PATTERN.search(str(field)):
fields[idx] = getattr(frappe.qb.DocType(table), field)
return getattr(frappe.qb.DocType(table), field)
else:
field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args]
field.args[0] = getattr(frappe.qb.DocType(table), field.args[0].get_sql())
fields[idx] = field
return field
fields = [_update_pypika_fields(field) for field in fields]
if len(self.tables) > 1:
primary_table = self.tables.pop(table)
@ -631,7 +657,7 @@ class Engine:
self.linked_doctype = None
self.fieldname = None
fields = self.set_fields(table, kwargs.get("field_objects") or fields, **kwargs)
fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs)
criterion = self.build_conditions(table, filters, **kwargs)
join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join"
criterion, fields = self.join_(criterion=criterion, fields=fields, table=table, join=join)

View file

@ -201,3 +201,9 @@ class TestQuery(FrappeTestCase):
fields=["name", "`tabNote Seen By`.`user` as seen_by"],
),
)
@run_only_if(db_type_is.MARIADB)
def test_comment_stripping(self):
self.assertNotIn(
"email", frappe.qb.engine.get_query("User", fields=["name", "#email"], filters={}).get_sql()
)