fix(db_query): Parse SQL function calls to check if field is accessible

This commit is contained in:
Gavin D'souza 2023-01-09 18:19:03 +05:30
parent 058c49f439
commit bb9763def7

View file

@ -50,6 +50,7 @@ FIELD_COMMA_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*,")
STRICT_FIELD_PATTERN = re.compile(r".*/\*.*")
STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s")
ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*")
FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*")
class DatabaseQuery:
@ -563,11 +564,13 @@ class DatabaseQuery:
available_fields = get_available_fields(doctype=self.doctype)
for i, field in enumerate(self.fields):
if field == "*":
column = field.split(" ", 1)[0].replace("`", "")
if column == "*":
self.fields[i : i + 1] = available_fields
# labels / pseudo columns or frappe internals
elif field[0] in {"'", '"', "_"} or field in available_fields:
elif column[0] in {"'", '"', "_"} or column in available_fields:
continue
# handle child / joined table fields
@ -586,8 +589,18 @@ class DatabaseQuery:
self.fields.remove(field)
# field inside function calls / * handles things like count(*)
elif "(" in field and ("*" in field or any(x for x in available_fields if x in field)):
continue
elif "(" in field:
if "*" in field:
continue
elif any(x for x in available_fields if x in field):
continue
elif _params := FN_PARAMS_PATTERN.findall(column):
params = (x for x in _params[0].split(","))
for param in params:
if param in available_fields or param.isnumeric() or "'" in param or '"' in param:
continue
else:
self.fields.remove(field)
# remove if access not allowed
else: