diff --git a/frappe/database/query.py b/frappe/database/query.py index 0cdb09533c..f2797f3775 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -1,5 +1,6 @@ import re from ast import literal_eval +from functools import lru_cache from types import BuiltinFunctionType from typing import TYPE_CHECKING, TypeAlias @@ -275,19 +276,13 @@ class Engine: return Function(func, *_args, alias=alias or None) def sanitize_fields(self, fields: str | list | tuple): - def _sanitize_field(field: str): - if not isinstance(field, str): - return field - stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") - if self.is_mariadb: - return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) - return stripped_field - if isinstance(fields, list | tuple): - return [_sanitize_field(field) for field in fields] + return [ + _sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field + for field in fields + ] elif isinstance(fields, str): - return _sanitize_field(fields) - + return _sanitize_field(fields, self.is_mariadb) return fields def parse_string_field(self, field: str): @@ -524,7 +519,7 @@ def literal_eval_(literal): def has_function(field): _field = field.casefold() if (isinstance(field, str) and "`" not in field) else field if not issubclass(type(_field), Criterion): - if any([f"{func}(" in _field for func in SQL_FUNCTIONS]): + if any([f"{func}(" in _field for func in SQL_FUNCTIONS]): # ) <- ignore this comment. return True @@ -558,3 +553,15 @@ def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) -> .run(pluck=True) ) return result + + +@lru_cache(maxsize=1024) +def _sanitize_field(field: str, is_mariadb): + if field == "*" or not SPECIAL_CHAR_PATTERN.match(field): + # Skip checking if there are no special characters + return field + + stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower") + if is_mariadb: + return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field) + return stripped_field diff --git a/frappe/database/schema.py b/frappe/database/schema.py index da2dba1924..1988b6abac 100644 --- a/frappe/database/schema.py +++ b/frappe/database/schema.py @@ -5,7 +5,9 @@ from frappe import _ from frappe.utils import cint, cstr, flt from frappe.utils.defaults import get_not_null_defaults +# This matches anything that isn't [a-zA-Z0-9_] SPECIAL_CHAR_PATTERN = re.compile(r"[\W]", flags=re.UNICODE) + VARCHAR_CAST_PATTERN = re.compile(r"varchar\(([\d]+)\)")