perf(db_query): Maintain compiled pattern globally

This commit is contained in:
Gavin D'souza 2022-06-14 17:17:22 +05:30
parent 678eebe4fb
commit 9b4db43b84

View file

@ -40,6 +40,16 @@ CAST_VARCHAR_PATTERN = re.compile(
r"([`\"]?tab[\w`\" -]+\.[`\"]?name[`\"]?)(?!\w)", flags=re.IGNORECASE
)
ORDER_BY_PATTERN = re.compile(r"\ order\ by\ |\ asc|\ ASC|\ desc|\ DESC", flags=re.IGNORECASE)
SUB_QUERY_PATTERN = re.compile("^.*[,();@].*")
IS_QUERY_PATTERN = re.compile(r"^(select|delete|update|drop|create)\s")
IS_QUERY_PREDICATE_PATTERN = re.compile(
r"\s*[0-9a-zA-z]*\s*( from | group by | order by | where | join )"
)
FIELD_QUOTE_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*'")
FIELD_COMMA_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*,")
STRICT_FIELD_PATTERN = re.compile(r".*/\*.*")
STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s")
ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*")
class DatabaseQuery(object):
@ -343,8 +353,6 @@ class DatabaseQuery(object):
As field contains `,` and mysql function `version()`, with the help of regex
the system will filter out this field.
"""
sub_query_regex = re.compile("^.*[,();@].*")
blacklisted_keywords = ["select", "create", "insert", "delete", "drop", "update", "case", "show"]
blacklisted_functions = [
"concat",
@ -368,16 +376,14 @@ class DatabaseQuery(object):
frappe.throw(_("Use of sub-query or function is restricted"), frappe.DataError)
def _is_query(field):
if re.compile(r"^(select|delete|update|drop|create)\s").match(field):
if IS_QUERY_PATTERN.match(field):
_raise_exception()
elif re.compile(r"\s*[0-9a-zA-z]*\s*( from | group by | order by | where | join )").match(
field
):
elif IS_QUERY_PREDICATE_PATTERN.match(field):
_raise_exception()
for field in self.fields:
if sub_query_regex.match(field):
if SUB_QUERY_PATTERN.match(field):
if any(f"({keyword}" in field.lower() for keyword in blacklisted_keywords):
_raise_exception()
@ -388,19 +394,19 @@ class DatabaseQuery(object):
# prevent access to global variables
_raise_exception()
if re.compile(r"[0-9a-zA-Z]+\s*'").match(field):
if FIELD_QUOTE_PATTERN.match(field):
_raise_exception()
if re.compile(r"[0-9a-zA-Z]+\s*,").match(field):
if FIELD_COMMA_PATTERN.match(field):
_raise_exception()
_is_query(field)
if self.strict:
if re.compile(r".*/\*.*").match(field):
if STRICT_FIELD_PATTERN.match(field):
frappe.throw(_("Illegal SQL Query"))
if re.compile(r".*\s(union).*\s").match(field.lower()):
if STRICT_UNION_PATTERN.match(field.lower()):
frappe.throw(_("Illegal SQL Query"))
def extract_tables(self):
@ -907,7 +913,7 @@ class DatabaseQuery(object):
if "select" in _lower and "from" in _lower:
frappe.throw(_("Cannot use sub-query in order by"))
if re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*").match(_lower):
if ORDER_GROUP_PATTERN.match(_lower):
frappe.throw(_("Illegal SQL Query"))
for field in parameters.split(","):