From 642e9f4ec1cf4d90e4484bc39b279e8c28f9f8a2 Mon Sep 17 00:00:00 2001 From: Sambasiva Suda Date: Tue, 16 Jan 2024 21:36:45 +0530 Subject: [PATCH] fix: search_link fails when txt contains parentheses (#22892) * fix: search_link fails when txt contains parentheses * fix: updating regex to replace number params also * chore: replacing regex with sqlparse * chore: not including fields like count(1) in asterisk_fields * fix: owner/module not identified as column * chore: lint fix and removing exception * refactor: better function name --------- Co-authored-by: Ankush Menat --- frappe/desk/reportview.py | 40 ++++++++++----------- frappe/model/db_query.py | 39 +++++++++------------ frappe/tests/test_reportview.py | 62 ++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 4 files changed, 97 insertions(+), 45 deletions(-) diff --git a/frappe/desk/reportview.py b/frappe/desk/reportview.py index 746c6b299f..ee04d4cc97 100644 --- a/frappe/desk/reportview.py +++ b/frappe/desk/reportview.py @@ -5,6 +5,8 @@ import json +from sql_metadata import Parser + import frappe import frappe.permissions from frappe import _ @@ -91,7 +93,10 @@ def validate_fields(data): wildcard = update_wildcard_field_param(data) for field in list(data.fields or []): - fieldname = extract_fieldname(field) + fieldname = extract_fieldnames(field)[0] + if not fieldname: + raise_invalid_field(fieldname) + if is_standard(fieldname): continue @@ -173,23 +178,16 @@ def is_standard(fieldname): ) -def extract_fieldname(field): - for text in (",", "/*", "#"): - if text in field: - raise_invalid_field(field) +def extract_fieldnames(field): + parser = Parser(f"select {field}, _frappe_dummy from _dummy") + columns = [col for col in parser.columns if col != "_frappe_dummy"] - fieldname = field - for sep in (" as ", " AS "): - if sep in fieldname: - fieldname = fieldname.split(sep, 1)[0] + if not columns: + f = field.lower() + if "count(" in f or "sum(" in f or "avg(" in f: + return ["*"] - # certain functions allowed, extract the fieldname from the function - if fieldname.startswith("count(") or fieldname.startswith("sum(") or fieldname.startswith("avg("): - if not fieldname.strip().endswith(")"): - raise_invalid_field(field) - fieldname = fieldname.split("(", 1)[1][:-1] - - return fieldname + return columns def get_meta_and_docfield(fieldname, data): @@ -236,13 +234,13 @@ def get_parenttype_and_fieldname(field, data): parts = field.split(".") parenttype = parts[0] fieldname = parts[1] - if parenttype.startswith("`tab"): - # `tabChild DocType`.`fieldname` - parenttype = parenttype[4:-1] - fieldname = fieldname.strip("`") + df = frappe.get_meta(data.doctype).get_field(parenttype) + if not df: + # tabChild DocType.fieldname + parenttype = parenttype[3:] else: # tablefield.fieldname - parenttype = frappe.get_meta(data.doctype).get_field(parenttype).options + parenttype = df.options else: parenttype = data.doctype fieldname = field.strip("`") diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index 106172a508..3f5b5ccf79 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -52,7 +52,6 @@ FIELD_COMMA_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*,") STRICT_FIELD_PATTERN = re.compile(r".*/\*.*") STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s") ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*") -FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*") SPECIAL_FIELD_CHARS = frozenset(("(", "`", ".", "'", '"', "*")) @@ -626,6 +625,8 @@ class DatabaseQuery: - Query: fields=["*"] - Result: fields=["title", ...] // will also include Frappe's meta field like `name`, `owner`, etc. """ + from frappe.desk.reportview import extract_fieldnames + if self.flags.ignore_permissions: return @@ -638,23 +639,18 @@ class DatabaseQuery: ) for i, field in enumerate(self.fields): - if "distinct" in field.lower(): - # field: 'count(distinct `tabPhoto`.name) as total_count' - # column: 'tabPhoto.name' - if _fn := FN_PARAMS_PATTERN.findall(field): - column = _fn[0].replace("distinct ", "").replace("DISTINCT ", "").replace("`", "") - # field: 'distinct name' - # column: 'name' - else: - column = field.split(" ", 2)[1].replace("`", "") - else: - # field: 'count(`tabPhoto`.name) as total_count' - # column: 'tabPhoto.name' - column = field.split("(")[-1].split(")", 1)[0] - column = strip_alias(column).replace("`", "") + # field: 'count(distinct `tabPhoto`.name) as total_count' + # column: 'tabPhoto.name' + # field: 'count(`tabPhoto`.name) as total_count' + # column: 'tabPhoto.name' + columns = extract_fieldnames(field) + if not columns: + continue - if column == "*" and not in_function("*", field): - asterisk_fields.append(i) + column = columns[0] + if column == "*": + if "*" in field and not in_function("*", field): + asterisk_fields.append(i) continue # handle pseudo columns @@ -693,12 +689,9 @@ class DatabaseQuery: elif "(" in field: if "*" in field: continue - elif _params := FN_PARAMS_PATTERN.findall(field): - params = (x.strip() for x in _params[0].split(",")) - for param in params: - if not ( - not param or param in permitted_fields or param.isnumeric() or "'" in param or '"' in param - ): + else: + for column in columns: + if not column in permitted_fields: self.remove_field(i) break continue diff --git a/frappe/tests/test_reportview.py b/frappe/tests/test_reportview.py index 4a24514acf..7581dc42fb 100644 --- a/frappe/tests/test_reportview.py +++ b/frappe/tests/test_reportview.py @@ -2,7 +2,7 @@ # License: MIT. See LICENSE import frappe -from frappe.desk.reportview import export_query +from frappe.desk.reportview import export_query, extract_fieldnames from frappe.tests.utils import FrappeTestCase @@ -32,3 +32,63 @@ class TestReportview(FrappeTestCase): for row in reader: self.assertEqual(int(row["Is Single"]), 1) self.assertEqual(row["Module"], "Core") + + def test_extract_fieldname(self): + self.assertEqual( + extract_fieldnames("count(distinct `tabPhoto`.name) as total_count")[0], "tabPhoto.name" + ) + + self.assertEqual(extract_fieldnames("owner")[0], "owner") + + self.assertEqual(extract_fieldnames("module")[0], "module") + + self.assertEqual(extract_fieldnames("count(`tabPhoto`.name) as total_count")[0], "tabPhoto.name") + + self.assertEqual(extract_fieldnames("count(distinct `tabPhoto`.name)")[0], "tabPhoto.name") + + self.assertEqual(extract_fieldnames("count(`tabPhoto`.name)")[0], "tabPhoto.name") + + self.assertEqual( + extract_fieldnames("count(distinct `tabJob Applicant`.name) as total_count")[0], + "tabJob Applicant.name", + ) + + self.assertEqual( + extract_fieldnames("(1 / nullif(locate('a', `tabAddress`.`name`), 0)) as `_relevance`")[0], + "tabAddress.name", + ) + + self.assertEqual( + extract_fieldnames("(1 / nullif(locate('(a)', `tabAddress`.`name`), 0)) as `_relevance`")[0], + "tabAddress.name", + ) + + self.assertEqual( + extract_fieldnames("EXTRACT(MONTH FROM date_column) AS month")[0], "date_column" + ) + + self.assertEqual(extract_fieldnames("COUNT(*) AS count")[0], "*") + + self.assertEqual(extract_fieldnames("COUNT(1) AS count")[0], "*") + + self.assertEqual(extract_fieldnames("COUNT(1) AS count, SUM(1) AS sum")[0], "*") + + self.assertEqual( + extract_fieldnames("first_name + ' ' + last_name AS full_name"), ["first_name", "last_name"] + ) + + self.assertEqual( + extract_fieldnames("CONCAT(first_name, ' ', last_name) AS full_name"), + ["first_name", "last_name"], + ) + + self.assertEqual( + extract_fieldnames("CONCAT(id, '/', name, '/', age, '/', marks) AS student"), + ["id", "name", "age", "marks"], + ) + + self.assertEqual(extract_fieldnames("tablefield.fiedname")[0], "tablefield.fiedname") + + self.assertEqual( + extract_fieldnames("`tabChild DocType`.`fiedname`")[0], "tabChild DocType.fiedname" + ) diff --git a/pyproject.toml b/pyproject.toml index 423de6e3ee..e612d9226b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies = [ "semantic-version~=2.10.0", "sentry-sdk~=1.37.1", "sqlparse~=0.4.4", + "sql_metadata~=2.9.0", "tenacity~=8.2.2", "terminaltables~=3.1.10", "traceback-with-variables~=2.0.4",