From c0cf13b8e863c0538d7dee2a782d254d70e624e4 Mon Sep 17 00:00:00 2001 From: Ankush Menat Date: Wed, 17 Jan 2024 12:44:20 +0530 Subject: [PATCH] Revert "fix: search_link fails when txt contains parentheses (#22892)" This reverts commit 642e9f4ec1cf4d90e4484bc39b279e8c28f9f8a2. --- frappe/desk/reportview.py | 40 +++++++++++---------- frappe/model/db_query.py | 39 ++++++++++++--------- frappe/tests/test_reportview.py | 62 +-------------------------------- pyproject.toml | 1 - 4 files changed, 45 insertions(+), 97 deletions(-) diff --git a/frappe/desk/reportview.py b/frappe/desk/reportview.py index ee04d4cc97..746c6b299f 100644 --- a/frappe/desk/reportview.py +++ b/frappe/desk/reportview.py @@ -5,8 +5,6 @@ import json -from sql_metadata import Parser - import frappe import frappe.permissions from frappe import _ @@ -93,10 +91,7 @@ def validate_fields(data): wildcard = update_wildcard_field_param(data) for field in list(data.fields or []): - fieldname = extract_fieldnames(field)[0] - if not fieldname: - raise_invalid_field(fieldname) - + fieldname = extract_fieldname(field) if is_standard(fieldname): continue @@ -178,16 +173,23 @@ def is_standard(fieldname): ) -def extract_fieldnames(field): - parser = Parser(f"select {field}, _frappe_dummy from _dummy") - columns = [col for col in parser.columns if col != "_frappe_dummy"] +def extract_fieldname(field): + for text in (",", "/*", "#"): + if text in field: + raise_invalid_field(field) - if not columns: - f = field.lower() - if "count(" in f or "sum(" in f or "avg(" in f: - return ["*"] + fieldname = field + for sep in (" as ", " AS "): + if sep in fieldname: + fieldname = fieldname.split(sep, 1)[0] - return columns + # certain functions allowed, extract the fieldname from the function + if fieldname.startswith("count(") or fieldname.startswith("sum(") or fieldname.startswith("avg("): + if not fieldname.strip().endswith(")"): + raise_invalid_field(field) + fieldname = fieldname.split("(", 1)[1][:-1] + + return fieldname def get_meta_and_docfield(fieldname, data): @@ -234,13 +236,13 @@ def get_parenttype_and_fieldname(field, data): parts = field.split(".") parenttype = parts[0] fieldname = parts[1] - df = frappe.get_meta(data.doctype).get_field(parenttype) - if not df: - # tabChild DocType.fieldname - parenttype = parenttype[3:] + if parenttype.startswith("`tab"): + # `tabChild DocType`.`fieldname` + parenttype = parenttype[4:-1] + fieldname = fieldname.strip("`") else: # tablefield.fieldname - parenttype = df.options + parenttype = frappe.get_meta(data.doctype).get_field(parenttype).options else: parenttype = data.doctype fieldname = field.strip("`") diff --git a/frappe/model/db_query.py b/frappe/model/db_query.py index 3f5b5ccf79..106172a508 100644 --- a/frappe/model/db_query.py +++ b/frappe/model/db_query.py @@ -52,6 +52,7 @@ FIELD_COMMA_PATTERN = re.compile(r"[0-9a-zA-Z]+\s*,") STRICT_FIELD_PATTERN = re.compile(r".*/\*.*") STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s") ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*") +FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*") SPECIAL_FIELD_CHARS = frozenset(("(", "`", ".", "'", '"', "*")) @@ -625,8 +626,6 @@ class DatabaseQuery: - Query: fields=["*"] - Result: fields=["title", ...] // will also include Frappe's meta field like `name`, `owner`, etc. """ - from frappe.desk.reportview import extract_fieldnames - if self.flags.ignore_permissions: return @@ -639,18 +638,23 @@ class DatabaseQuery: ) for i, field in enumerate(self.fields): - # field: 'count(distinct `tabPhoto`.name) as total_count' - # column: 'tabPhoto.name' - # field: 'count(`tabPhoto`.name) as total_count' - # column: 'tabPhoto.name' - columns = extract_fieldnames(field) - if not columns: - continue + if "distinct" in field.lower(): + # field: 'count(distinct `tabPhoto`.name) as total_count' + # column: 'tabPhoto.name' + if _fn := FN_PARAMS_PATTERN.findall(field): + column = _fn[0].replace("distinct ", "").replace("DISTINCT ", "").replace("`", "") + # field: 'distinct name' + # column: 'name' + else: + column = field.split(" ", 2)[1].replace("`", "") + else: + # field: 'count(`tabPhoto`.name) as total_count' + # column: 'tabPhoto.name' + column = field.split("(")[-1].split(")", 1)[0] + column = strip_alias(column).replace("`", "") - column = columns[0] - if column == "*": - if "*" in field and not in_function("*", field): - asterisk_fields.append(i) + if column == "*" and not in_function("*", field): + asterisk_fields.append(i) continue # handle pseudo columns @@ -689,9 +693,12 @@ class DatabaseQuery: elif "(" in field: if "*" in field: continue - else: - for column in columns: - if not column in permitted_fields: + elif _params := FN_PARAMS_PATTERN.findall(field): + params = (x.strip() for x in _params[0].split(",")) + for param in params: + if not ( + not param or param in permitted_fields or param.isnumeric() or "'" in param or '"' in param + ): self.remove_field(i) break continue diff --git a/frappe/tests/test_reportview.py b/frappe/tests/test_reportview.py index 7581dc42fb..4a24514acf 100644 --- a/frappe/tests/test_reportview.py +++ b/frappe/tests/test_reportview.py @@ -2,7 +2,7 @@ # License: MIT. See LICENSE import frappe -from frappe.desk.reportview import export_query, extract_fieldnames +from frappe.desk.reportview import export_query from frappe.tests.utils import FrappeTestCase @@ -32,63 +32,3 @@ class TestReportview(FrappeTestCase): for row in reader: self.assertEqual(int(row["Is Single"]), 1) self.assertEqual(row["Module"], "Core") - - def test_extract_fieldname(self): - self.assertEqual( - extract_fieldnames("count(distinct `tabPhoto`.name) as total_count")[0], "tabPhoto.name" - ) - - self.assertEqual(extract_fieldnames("owner")[0], "owner") - - self.assertEqual(extract_fieldnames("module")[0], "module") - - self.assertEqual(extract_fieldnames("count(`tabPhoto`.name) as total_count")[0], "tabPhoto.name") - - self.assertEqual(extract_fieldnames("count(distinct `tabPhoto`.name)")[0], "tabPhoto.name") - - self.assertEqual(extract_fieldnames("count(`tabPhoto`.name)")[0], "tabPhoto.name") - - self.assertEqual( - extract_fieldnames("count(distinct `tabJob Applicant`.name) as total_count")[0], - "tabJob Applicant.name", - ) - - self.assertEqual( - extract_fieldnames("(1 / nullif(locate('a', `tabAddress`.`name`), 0)) as `_relevance`")[0], - "tabAddress.name", - ) - - self.assertEqual( - extract_fieldnames("(1 / nullif(locate('(a)', `tabAddress`.`name`), 0)) as `_relevance`")[0], - "tabAddress.name", - ) - - self.assertEqual( - extract_fieldnames("EXTRACT(MONTH FROM date_column) AS month")[0], "date_column" - ) - - self.assertEqual(extract_fieldnames("COUNT(*) AS count")[0], "*") - - self.assertEqual(extract_fieldnames("COUNT(1) AS count")[0], "*") - - self.assertEqual(extract_fieldnames("COUNT(1) AS count, SUM(1) AS sum")[0], "*") - - self.assertEqual( - extract_fieldnames("first_name + ' ' + last_name AS full_name"), ["first_name", "last_name"] - ) - - self.assertEqual( - extract_fieldnames("CONCAT(first_name, ' ', last_name) AS full_name"), - ["first_name", "last_name"], - ) - - self.assertEqual( - extract_fieldnames("CONCAT(id, '/', name, '/', age, '/', marks) AS student"), - ["id", "name", "age", "marks"], - ) - - self.assertEqual(extract_fieldnames("tablefield.fiedname")[0], "tablefield.fiedname") - - self.assertEqual( - extract_fieldnames("`tabChild DocType`.`fiedname`")[0], "tabChild DocType.fiedname" - ) diff --git a/pyproject.toml b/pyproject.toml index e612d9226b..423de6e3ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,6 @@ dependencies = [ "semantic-version~=2.10.0", "sentry-sdk~=1.37.1", "sqlparse~=0.4.4", - "sql_metadata~=2.9.0", "tenacity~=8.2.2", "terminaltables~=3.1.10", "traceback-with-variables~=2.0.4",