feat: support certain backticked expressions
Signed-off-by: Akhil Narang <me@akhilnarang.dev>
This commit is contained in:
parent
7ad6f7e2c6
commit
943df998d6
4 changed files with 144 additions and 28 deletions
|
|
@ -370,12 +370,11 @@ class DataExporter:
|
|||
order_by = None
|
||||
table_columns = frappe.db.get_table_columns(self.parent_doctype)
|
||||
if "lft" in table_columns and "rgt" in table_columns:
|
||||
order_by = DocType(self.parent_doctype).lft
|
||||
|
||||
order_by = f"`tab{self.parent_doctype}`.`lft` asc"
|
||||
# get permitted data only
|
||||
self.data = frappe.qb.get_query(
|
||||
self.doctype, fields=["*"], filters=self.filters, order_by=order_by
|
||||
).run(as_dict=True)
|
||||
self.data = frappe.get_list(
|
||||
self.doctype, fields=["*"], filters=self.filters, limit_page_length=None, order_by=order_by
|
||||
)
|
||||
|
||||
for doc in self.data:
|
||||
op = self.docs_to_export.get("op")
|
||||
|
|
|
|||
|
|
@ -291,7 +291,7 @@ class Report(Document):
|
|||
@staticmethod
|
||||
def _format(parts):
|
||||
# sort by is saved as DocType.fieldname, convert it to sql
|
||||
return parts[1]
|
||||
return "`tab{}`.`{}`".format(*parts)
|
||||
|
||||
def get_standard_report_columns(self, params):
|
||||
if params.get("fields"):
|
||||
|
|
|
|||
|
|
@ -49,8 +49,15 @@ def _is_function_call(field_str: str) -> bool:
|
|||
|
||||
|
||||
# Pattern to validate field names in SELECT:
|
||||
# Allows: name, `name`, name as alias, `name` as alias, `table name`.`name`, `table name`.`name` as alias, table.name, table.name as alias
|
||||
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:(`[\w\s-]+`|\w+)\.)?(`\w+`|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
|
||||
# Allows: name, `name`, name as alias, `name` as alias, table.name, table.name as alias
|
||||
# Also allows backtick-qualified identifiers with spaces/hyphens:
|
||||
# - `tabTable`.`field`
|
||||
# - `tabTable Name`.`field` (spaces in table name)
|
||||
# - `tabTable-Field`.`field` (hyphens in table name)
|
||||
# - Any of above with aliases: ... as alias
|
||||
ALLOWED_FIELD_PATTERN = re.compile(
|
||||
r"^(?:(`[\w\s-]+`|\w+)\.)?(`[\w\s-]+`|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII | re.IGNORECASE
|
||||
)
|
||||
|
||||
# Regex to parse field names:
|
||||
# Group 1: Optional quote for table name
|
||||
|
|
@ -551,10 +558,17 @@ class Engine:
|
|||
# return if field is already a pypika Term
|
||||
return field
|
||||
|
||||
# Reject backticks
|
||||
# Parse backtick table.field notation: `tabDocType`.`fieldname`
|
||||
if "`" in field:
|
||||
if parsed := self._parse_backtick_field_notation(field):
|
||||
table_name, field_name = parsed
|
||||
|
||||
# Return query builder field reference
|
||||
return frappe.qb.DocType(table_name)[field_name]
|
||||
|
||||
# If parsing failed, fall through to error handling below
|
||||
frappe.throw(
|
||||
_("Filter fields cannot contain backticks (`)."),
|
||||
_("Filter fields have invalid backtick notation: {0}").format(field),
|
||||
frappe.ValidationError,
|
||||
title=_("Invalid Filter"),
|
||||
)
|
||||
|
|
@ -743,7 +757,9 @@ class Engine:
|
|||
# Ensure the extracted table name is valid before creating DocType object
|
||||
if not TABLE_NAME_PATTERN.match(table_name.lstrip("tab")):
|
||||
frappe.throw(_("Invalid characters in table name: {0}").format(table_name))
|
||||
table_obj = frappe.qb.DocType(table_name)
|
||||
|
||||
doctype_name = table_name[3:] if table_name.startswith("tab") else table_name
|
||||
table_obj = frappe.qb.DocType(doctype_name)
|
||||
pypika_field = table_obj[field_name]
|
||||
else:
|
||||
# Simple field name (e.g., `y` or y) - use the main table
|
||||
|
|
@ -802,10 +818,11 @@ class Engine:
|
|||
return _fields
|
||||
|
||||
def _parse_single_field_item(
|
||||
self, field: str | Criterion | dict | Field
|
||||
self, field: str | Criterion | dict | Field | Term
|
||||
) -> "list | Criterion | Field | DynamicTableField | ChildQuery | None":
|
||||
"""Parses a single item from the fields list/tuple. Assumes comma-separated strings have already been split."""
|
||||
if isinstance(field, Criterion | Field):
|
||||
if isinstance(field, Term):
|
||||
# Accept any pypika Term (Field, Criterion, ArithmeticExpression, AggregateFunction, etc.)
|
||||
return field
|
||||
elif isinstance(field, dict):
|
||||
# Check if it's a SQL function or operator dictionary
|
||||
|
|
@ -879,6 +896,76 @@ class Engine:
|
|||
order_direction = Order.desc if sort_order.lower() == "desc" else Order.asc
|
||||
self.query = self.query.orderby(field, order=order_direction)
|
||||
|
||||
def _parse_backtick_field_notation(self, field_name: str) -> tuple[str, str] | None:
    """Split a backtick-qualified column reference into its doctype and field.

    Accepted shapes are ```tabDocType`.`fieldname``` and ```tabDocType`.fieldname``;
    the table part must be backtick-quoted and carry the ``tab`` prefix.

    Args:
        field_name: Raw field expression to inspect.

    Returns:
        ``(doctype_name, field_name)`` where ``doctype_name`` is the table name
        with the leading ``tab`` removed and ``field_name`` has any enclosing
        backticks stripped, or ``None`` when the expression does not have the
        expected shape or ``frappe.db.table_exists`` reports no such table.
    """
    import sqlparse
    from sqlparse.sql import Identifier

    statements = sqlparse.parse(field_name.strip())
    if not statements or not statements[0].tokens:
        return None

    # The whole expression must collapse into exactly one Identifier token;
    # anything else (extra keywords, lists, operators) is rejected.
    top_level = [tok for tok in statements[0].tokens if not tok.is_whitespace]
    if len(top_level) != 1 or not isinstance(top_level[0], Identifier):
        return None

    # Inside the identifier we expect exactly: <table> <.> <field>
    pieces = [tok for tok in top_level[0].tokens if not tok.is_whitespace]
    if len(pieces) != 3:
        return None

    quoted_table, separator, column = (str(tok).strip() for tok in pieces)

    if separator != ".":
        return None

    # The table part is only accepted when it is backtick-quoted.
    if not quoted_table.startswith("`") or not quoted_table.endswith("`"):
        return None

    # The field part may be quoted or bare; unwrap it when quoted.
    if column.startswith("`") and column.endswith("`"):
        column = column[1:-1]

    physical_table = quoted_table[1:-1]
    if not physical_table.startswith("tab"):
        return None

    doctype_name = physical_table.removeprefix("tab")

    # Reject an empty doctype and any table the database does not know about.
    if doctype_name and frappe.db.table_exists(doctype_name):
        return (doctype_name, column)
    return None
|
||||
|
||||
def _validate_and_parse_field_for_clause(self, field_name: str, clause_name: str) -> Field:
|
||||
"""
|
||||
Common helper to validate and parse field names for GROUP BY and ORDER BY clauses.
|
||||
|
|
@ -898,10 +985,15 @@ class Engine:
|
|||
if field_name in self.function_aliases:
|
||||
return Field(field_name)
|
||||
|
||||
# Reject backticks
|
||||
# Parse backtick table.field notation: `tabDocType`.`fieldname`
|
||||
if "`" in field_name:
|
||||
if parsed := self._parse_backtick_field_notation(field_name):
|
||||
table_name, field_name = parsed
|
||||
return frappe.qb.DocType(table_name)[field_name]
|
||||
|
||||
# If parsing failed, fall through to error handling below
|
||||
frappe.throw(
|
||||
_("{0} fields cannot contain backticks (`): {1}").format(clause_name, field_name),
|
||||
_("{0} has invalid backtick notation: {1}").format(clause_name, field_name),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
|
|
@ -967,11 +1059,16 @@ class Engine:
|
|||
|
||||
for declaration in order_by.split(","):
|
||||
if _order_by := declaration.strip():
|
||||
# Extract direction from end of declaration (handles backtick identifiers with spaces)
|
||||
# Check if the last word is a valid direction
|
||||
parts = _order_by.split()
|
||||
field_name = parts[0]
|
||||
direction = None
|
||||
if len(parts) > 1:
|
||||
direction = parts[1].lower()
|
||||
field_name = _order_by
|
||||
|
||||
if len(parts) > 1 and parts[-1].lower() in valid_directions:
|
||||
# Last part is a direction, so field_name is everything before it
|
||||
direction = parts[-1].lower()
|
||||
field_name = " ".join(parts[:-1])
|
||||
|
||||
order_direction = Order.desc if direction == "desc" else Order.asc
|
||||
|
||||
|
|
@ -980,7 +1077,7 @@ class Engine:
|
|||
|
||||
if direction and direction not in valid_directions:
|
||||
frappe.throw(
|
||||
_("Invalid direction in Order By: {0}. Must be 'ASC' or 'DESC'.").format(parts[1]),
|
||||
_("Invalid direction in Order By: {0}. Must be 'ASC' or 'DESC'.").format(direction),
|
||||
ValueError,
|
||||
)
|
||||
|
||||
|
|
@ -1852,6 +1949,20 @@ class SQLFunctionParser:
|
|||
if self._is_string_literal(arg):
|
||||
return self._validate_string_literal(arg)
|
||||
|
||||
# Check for backtick notation: `tabDocType`.`fieldname`
|
||||
# Parse and return as Field object to preserve field reference in operators
|
||||
elif "`" in arg:
|
||||
if parsed := self.engine._parse_backtick_field_notation(arg):
|
||||
table_name, field_name = parsed
|
||||
return Table(f"tab{table_name}")[field_name]
|
||||
else:
|
||||
frappe.throw(
|
||||
_(
|
||||
"Invalid argument format: {0}. Only quoted string literals or simple field names are allowed."
|
||||
).format(arg),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
elif self._is_valid_field_name(arg):
|
||||
# Validate field name and check permissions
|
||||
self._validate_function_field_arg(arg)
|
||||
|
|
|
|||
|
|
@ -197,11 +197,10 @@ class TestQuery(IntegrationTestCase):
|
|||
|
||||
def test_field_validation_filters(self):
|
||||
"""Test validation for fields used in filters (WHERE clause)."""
|
||||
valid_fields = ["name", "creation", "language.name"]
|
||||
valid_fields = ["name", "creation", "language.name", "`tabUser`.`name`"]
|
||||
# Filters should not allow aliases or functions directly as field names
|
||||
invalid_fields = [
|
||||
"tabUser.name",
|
||||
"`tabUser`.`name`",
|
||||
"name as alias",
|
||||
"`name` as alias",
|
||||
"tabUser.name as alias",
|
||||
|
|
@ -248,6 +247,7 @@ class TestQuery(IntegrationTestCase):
|
|||
"1", # Allow numeric indices
|
||||
"name, email",
|
||||
"1, 2",
|
||||
"`tabUser`.`name`",
|
||||
]
|
||||
# GROUP BY should not allow aliases or functions
|
||||
invalid_fields = [
|
||||
|
|
@ -262,7 +262,6 @@ class TestQuery(IntegrationTestCase):
|
|||
"table.invalid-field",
|
||||
"tabUser.name",
|
||||
"`name`",
|
||||
"`tabUser`.`name`",
|
||||
"`name`, `tabUser`.`email`",
|
||||
"`table`.`invalid-field`",
|
||||
"field with space",
|
||||
|
|
@ -293,6 +292,8 @@ class TestQuery(IntegrationTestCase):
|
|||
"2 DESC",
|
||||
"name, email",
|
||||
"1 asc, 2 desc",
|
||||
"`tabUser`.`name`",
|
||||
"`tabUser`.`name` desc",
|
||||
]
|
||||
# ORDER BY should not allow aliases or functions, or invalid directions
|
||||
invalid_fields = [
|
||||
|
|
@ -305,10 +306,8 @@ class TestQuery(IntegrationTestCase):
|
|||
"name /* comment */",
|
||||
"`name`",
|
||||
"tabUser.name",
|
||||
"`tabUser`.`name`",
|
||||
"`name` DESC",
|
||||
"tabUser.name Asc",
|
||||
"`tabUser`.`name` desc",
|
||||
"`name` asc, `tabUser`.`email` DESC",
|
||||
"invalid-field-name",
|
||||
"table.invalid-field",
|
||||
|
|
@ -1629,18 +1628,25 @@ class TestQuery(IntegrationTestCase):
|
|||
frappe.qb.get_query("User", order_by=field).get_sql()
|
||||
|
||||
def test_backtick_rejection_group_order(self):
|
||||
"""Test that backticks are properly rejected in GROUP BY and ORDER BY."""
|
||||
"""Test that malformed backticks are properly rejected in GROUP BY and ORDER BY."""
|
||||
# Test single backtick (invalid notation - should be `tabTable`.`field`)
|
||||
with self.assertRaises(frappe.ValidationError) as cm:
|
||||
frappe.qb.get_query("User", group_by="`name`").get_sql()
|
||||
self.assertIn("cannot contain backticks", str(cm.exception))
|
||||
self.assertIn("invalid backtick notation", str(cm.exception))
|
||||
|
||||
# Test single backtick with direction (invalid notation)
|
||||
with self.assertRaises(frappe.ValidationError) as cm:
|
||||
frappe.qb.get_query("User", order_by="`name` ASC").get_sql()
|
||||
self.assertIn("cannot contain backticks", str(cm.exception))
|
||||
self.assertIn("invalid backtick notation", str(cm.exception))
|
||||
|
||||
# Test multiple single backticks (invalid notation)
|
||||
with self.assertRaises(frappe.ValidationError) as cm:
|
||||
frappe.qb.get_query("User", group_by="`name`, `email`").get_sql()
|
||||
self.assertIn("cannot contain backticks", str(cm.exception))
|
||||
self.assertIn("invalid backtick notation", str(cm.exception))
|
||||
|
||||
# Valid backtick notation should work
|
||||
frappe.qb.get_query("User", group_by="`tabUser`.`name`").get_sql()
|
||||
frappe.qb.get_query("User", order_by="`tabUser`.`name` ASC").get_sql()
|
||||
|
||||
def test_sql_functions_in_fields(self):
|
||||
"""Test SQL function support in fields with various syntaxes."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue