diff --git a/frappe/database/query.py b/frappe/database/query.py index 5726b760ee..2e1635101a 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -5,6 +5,7 @@ from functools import cached_property from types import BuiltinFunctionType from typing import TYPE_CHECKING, Callable +import sqlparse from pypika.dialects import MySQLQueryBuilder, PostgreSQLQueryBuilder import frappe @@ -14,6 +15,7 @@ from frappe.model.db_query import get_timespan_date_range from frappe.query_builder import Criterion, Field, Order, Table, functions from frappe.query_builder.functions import Function, SqlFunctions from frappe.query_builder.utils import PseudoColumn +from frappe.utils.data import MARIADB_SPECIFIC_COMMENT if TYPE_CHECKING: from frappe.query_builder import DocType @@ -492,7 +494,8 @@ class Engine: def get_fieldnames_from_child_table(self, doctype, fields): # Hacky and flaky implementation of implicit joins. # convert child_table.fieldname to `tabChild DocType`.`fieldname` - for idx, field in enumerate(fields, start=0): + _fields = [] + for field in fields: if "." in field and "tab" not in field: alias = None if " as " in field: @@ -506,12 +509,63 @@ class Engine: field = f"`tab{self.linked_doctype}`.`{linked_fieldname}`" if alias: field = f"{field} as {alias}" - fields[idx] = field + _fields.append(field) + + return _fields + + def sanitize_fields(self, fields: str | list | tuple): + is_mariadb = frappe.db.db_type == "mariadb" + + def _sanitize_field(field: str): + if not isinstance(field, str): + return field + stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") + if is_mariadb: + return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) + return stripped_field + + if isinstance(fields, (list, tuple)): + return [_sanitize_field(field) for field in fields] + elif isinstance(fields, str): + return _sanitize_field(fields) return fields - def set_fields(self, table, fields, **kwargs) -> list: + def get_list_fields(self, fields: list) -> list: + updated_fields = [] + if issubclass(type(fields), Criterion) or "*" in fields: + return fields + # fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields) + for field in fields: + if not isinstance(field, Criterion) and field: + if " as " in field: + field, reference = field.split(" as ") + if "`" in field: + updated_fields.append(PseudoColumn(f"{field} as {reference}")) + else: + updated_fields.append(Field(field.strip()).as_(reference)) + elif "`" in str(field): + updated_fields.append(PseudoColumn(field.strip())) + else: + updated_fields.append(Field(field)) + return updated_fields + + def get_string_fields(self, fields: str) -> Field: + if fields == "*": + return fields + if "`" in fields: + fields = PseudoColumn(fields) + if " as " in str(fields): + fields, reference = str(fields).split(" as ") + if "`" in str(fields): + fields = PseudoColumn(f"{fields} as {reference}") + else: + fields = Field(fields).as_(reference) + return fields + + def set_fields(self, fields, **kwargs) -> list: fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name" + fields = self.sanitize_fields(fields) if isinstance(fields, list) and None in fields and Field not in fields: return None function_objects = [] @@ -535,39 +589,9 @@ class Engine: is_list, is_str = True, False if is_str: - if fields == "*": - return fields - if "`" in fields: - fields = PseudoColumn(fields) - if " as " in str(fields): - fields, reference = str(fields).split(" as ") - if "`" in str(fields): - fields = PseudoColumn(f"{fields} as {reference}") - else: - fields = Field(fields).as_(reference) - + fields = self.get_string_fields(fields) if not is_str and fields: - if issubclass(type(fields), Criterion): - return fields - updated_fields = [] - if "*" in fields: - return fields - # fields = self.get_fieldnames_from_child_table(doctype=table, fields=fields) - for field in fields: - if not isinstance(field, Criterion) and field: - if " as " in field: - field, reference = field.split(" as ") - if "`" in field: - updated_fields.append(PseudoColumn(f"{field} as {reference}")) - else: - updated_fields.append(Field(field.strip()).as_(reference)) - - elif "`" in str(field): - updated_fields.append(PseudoColumn(field.strip())) - else: - updated_fields.append(Field(field)) - - fields = updated_fields + fields = self.get_list_fields(fields) # Need to check instance again since fields modified. if not isinstance(fields, (list, tuple, set)): @@ -599,15 +623,17 @@ class Engine: has_join = True if has_join: - for idx, field in enumerate(fields): + + def _update_pypika_fields(field): if not is_pypika_function_object(field): field = field if isinstance(field, str) else field.get_sql() if not TABLE_PATTERN.search(str(field)): - fields[idx] = getattr(frappe.qb.DocType(table), field) + return getattr(frappe.qb.DocType(table), field) else: field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args] - field.args[0] = getattr(frappe.qb.DocType(table), field.args[0].get_sql()) - fields[idx] = field + return field + + fields = [_update_pypika_fields(field) for field in fields] if len(self.tables) > 1: primary_table = self.tables.pop(table) @@ -631,7 +657,7 @@ class Engine: self.linked_doctype = None self.fieldname = None - fields = self.set_fields(table, kwargs.get("field_objects") or fields, **kwargs) + fields = self.set_fields(kwargs.get("field_objects") or fields, **kwargs) criterion = self.build_conditions(table, filters, **kwargs) join = kwargs.get("join").replace(" ", "_") if kwargs.get("join") else "left_join" criterion, fields = self.join_(criterion=criterion, fields=fields, table=table, join=join) diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index ae9513a302..1804f1672c 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -201,3 +201,9 @@ class TestQuery(FrappeTestCase): fields=["name", "`tabNote Seen By`.`user` as seen_by"], ), ) + + @run_only_if(db_type_is.MARIADB) + def test_comment_stripping(self): + self.assertNotIn( + "email", frappe.qb.engine.get_query("User", fields=["name", "#email"], filters={}).get_sql() + )