fix: search_link fails when txt contains parentheses (#22892)

* fix: search_link fails when txt contains parentheses

* fix: updating regex to replace number params also

* chore: replacing regex with sqlparse

* chore: not including fields like count(1) in asterisk_fields

* fix: owner/module not identified as column

* chore: lint fix and removing exception

* refactor: better function name

---------

Co-authored-by: Ankush Menat <ankush@frappe.io>
This commit is contained in:
Sambasiva Suda 2024-01-16 21:36:45 +05:30 committed by GitHub
parent 7123f50912
commit 642e9f4ec1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 45 deletions

View file

@ -5,6 +5,8 @@
import json
from sql_metadata import Parser
import frappe
import frappe.permissions
from frappe import _
@ -91,7 +93,10 @@ def validate_fields(data):
wildcard = update_wildcard_field_param(data)
for field in list(data.fields or []):
fieldname = extract_fieldname(field)
fieldname = extract_fieldnames(field)[0]
if not fieldname:
raise_invalid_field(fieldname)
if is_standard(fieldname):
continue
@ -173,23 +178,16 @@ def is_standard(fieldname):
)
def extract_fieldname(field):
for text in (",", "/*", "#"):
if text in field:
raise_invalid_field(field)
def extract_fieldnames(field):
parser = Parser(f"select {field}, _frappe_dummy from _dummy")
columns = [col for col in parser.columns if col != "_frappe_dummy"]
fieldname = field
for sep in (" as ", " AS "):
if sep in fieldname:
fieldname = fieldname.split(sep, 1)[0]
if not columns:
f = field.lower()
if "count(" in f or "sum(" in f or "avg(" in f:
return ["*"]
# certain functions allowed, extract the fieldname from the function
if fieldname.startswith("count(") or fieldname.startswith("sum(") or fieldname.startswith("avg("):
if not fieldname.strip().endswith(")"):
raise_invalid_field(field)
fieldname = fieldname.split("(", 1)[1][:-1]
return fieldname
return columns
def get_meta_and_docfield(fieldname, data):
@ -236,13 +234,13 @@ def get_parenttype_and_fieldname(field, data):
parts = field.split(".")
parenttype = parts[0]
fieldname = parts[1]
if parenttype.startswith("`tab"):
# `tabChild DocType`.`fieldname`
parenttype = parenttype[4:-1]
fieldname = fieldname.strip("`")
df = frappe.get_meta(data.doctype).get_field(parenttype)
if not df:
# tabChild DocType.fieldname
parenttype = parenttype[3:]
else:
# tablefield.fieldname
parenttype = frappe.get_meta(data.doctype).get_field(parenttype).options
parenttype = df.options
else:
parenttype = data.doctype
fieldname = field.strip("`")

View file

@ -52,7 +52,6 @@ FIELD_COMMA_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*,")
STRICT_FIELD_PATTERN = re.compile(r".*/\*.*")
STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s")
ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*")
FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*")
SPECIAL_FIELD_CHARS = frozenset(("(", "`", ".", "'", '"', "*"))
@ -626,6 +625,8 @@ class DatabaseQuery:
- Query: fields=["*"]
- Result: fields=["title", ...] // will also include Frappe's meta field like `name`, `owner`, etc.
"""
from frappe.desk.reportview import extract_fieldnames
if self.flags.ignore_permissions:
return
@ -638,23 +639,18 @@ class DatabaseQuery:
)
for i, field in enumerate(self.fields):
if "distinct" in field.lower():
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
if _fn := FN_PARAMS_PATTERN.findall(field):
column = _fn[0].replace("distinct ", "").replace("DISTINCT ", "").replace("`", "")
# field: 'distinct name'
# column: 'name'
else:
column = field.split(" ", 2)[1].replace("`", "")
else:
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
column = field.split("(")[-1].split(")", 1)[0]
column = strip_alias(column).replace("`", "")
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
columns = extract_fieldnames(field)
if not columns:
continue
if column == "*" and not in_function("*", field):
asterisk_fields.append(i)
column = columns[0]
if column == "*":
if "*" in field and not in_function("*", field):
asterisk_fields.append(i)
continue
# handle pseudo columns
@ -693,12 +689,9 @@ class DatabaseQuery:
elif "(" in field:
if "*" in field:
continue
elif _params := FN_PARAMS_PATTERN.findall(field):
params = (x.strip() for x in _params[0].split(","))
for param in params:
if not (
not param or param in permitted_fields or param.isnumeric() or "'" in param or '"' in param
):
else:
for column in columns:
if not column in permitted_fields:
self.remove_field(i)
break
continue

View file

@ -2,7 +2,7 @@
# License: MIT. See LICENSE
import frappe
from frappe.desk.reportview import export_query
from frappe.desk.reportview import export_query, extract_fieldnames
from frappe.tests.utils import FrappeTestCase
@ -32,3 +32,63 @@ class TestReportview(FrappeTestCase):
for row in reader:
self.assertEqual(int(row["Is Single"]), 1)
self.assertEqual(row["Module"], "Core")
def test_extract_fieldname(self):
self.assertEqual(
extract_fieldnames("count(distinct `tabPhoto`.name) as total_count")[0], "tabPhoto.name"
)
self.assertEqual(extract_fieldnames("owner")[0], "owner")
self.assertEqual(extract_fieldnames("module")[0], "module")
self.assertEqual(extract_fieldnames("count(`tabPhoto`.name) as total_count")[0], "tabPhoto.name")
self.assertEqual(extract_fieldnames("count(distinct `tabPhoto`.name)")[0], "tabPhoto.name")
self.assertEqual(extract_fieldnames("count(`tabPhoto`.name)")[0], "tabPhoto.name")
self.assertEqual(
extract_fieldnames("count(distinct `tabJob Applicant`.name) as total_count")[0],
"tabJob Applicant.name",
)
self.assertEqual(
extract_fieldnames("(1 / nullif(locate('a', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)
self.assertEqual(
extract_fieldnames("(1 / nullif(locate('(a)', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)
self.assertEqual(
extract_fieldnames("EXTRACT(MONTH FROM date_column) AS month")[0], "date_column"
)
self.assertEqual(extract_fieldnames("COUNT(*) AS count")[0], "*")
self.assertEqual(extract_fieldnames("COUNT(1) AS count")[0], "*")
self.assertEqual(extract_fieldnames("COUNT(1) AS count, SUM(1) AS sum")[0], "*")
self.assertEqual(
extract_fieldnames("first_name + ' ' + last_name AS full_name"), ["first_name", "last_name"]
)
self.assertEqual(
extract_fieldnames("CONCAT(first_name, ' ', last_name) AS full_name"),
["first_name", "last_name"],
)
self.assertEqual(
extract_fieldnames("CONCAT(id, '/', name, '/', age, '/', marks) AS student"),
["id", "name", "age", "marks"],
)
self.assertEqual(extract_fieldnames("tablefield.fiedname")[0], "tablefield.fiedname")
self.assertEqual(
extract_fieldnames("`tabChild DocType`.`fiedname`")[0], "tabChild DocType.fiedname"
)

View file

@ -68,6 +68,7 @@ dependencies = [
"semantic-version~=2.10.0",
"sentry-sdk~=1.37.1",
"sqlparse~=0.4.4",
"sql_metadata~=2.9.0",
"tenacity~=8.2.2",
"terminaltables~=3.1.10",
"traceback-with-variables~=2.0.4",