perf: speedup QB field sanitization (#28818)

This commit is contained in:
Ankush Menat 2024-12-18 11:17:02 +05:30 committed by GitHub
parent 051cedb860
commit 23b5b0c7ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 12 deletions

View file

@ -1,5 +1,6 @@
import re
from ast import literal_eval
from functools import lru_cache
from types import BuiltinFunctionType
from typing import TYPE_CHECKING, TypeAlias
@ -275,19 +276,13 @@ class Engine:
return Function(func, *_args, alias=alias or None)
def sanitize_fields(self, fields: str | list | tuple):
def _sanitize_field(field: str):
if not isinstance(field, str):
return field
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
if self.is_mariadb:
return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
return stripped_field
if isinstance(fields, list | tuple):
return [_sanitize_field(field) for field in fields]
return [
_sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field
for field in fields
]
elif isinstance(fields, str):
return _sanitize_field(fields)
return _sanitize_field(fields, self.is_mariadb)
return fields
def parse_string_field(self, field: str):
@ -524,7 +519,7 @@ def literal_eval_(literal):
def has_function(field):
_field = field.casefold() if (isinstance(field, str) and "`" not in field) else field
if not issubclass(type(_field), Criterion):
if any([f"{func}(" in _field for func in SQL_FUNCTIONS]):
if any([f"{func}(" in _field for func in SQL_FUNCTIONS]): # ) <- ignore this comment.
return True
@ -558,3 +553,15 @@ def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) ->
.run(pluck=True)
)
return result
@lru_cache(maxsize=1024)
def _sanitize_field(field: str, is_mariadb):
if field == "*" or not SPECIAL_CHAR_PATTERN.match(field):
# Skip checking if there are no special characters
return field
stripped_field = sqlparse.format(field, strip_comments=True, keyword_case="lower")
if is_mariadb:
return MARIADB_SPECIFIC_COMMENT.sub("", stripped_field)
return stripped_field

View file

@ -5,7 +5,9 @@ from frappe import _
from frappe.utils import cint, cstr, flt
from frappe.utils.defaults import get_not_null_defaults
# This matches anything that isn't [a-zA-Z0-9_]
SPECIAL_CHAR_PATTERN = re.compile(r"[\W]", flags=re.UNICODE)
VARCHAR_CAST_PATTERN = re.compile(r"varchar\(([\d]+)\)")