refactor: Enhance field and function parsing in query engine
- Introduce `SqlFunctionParser` for robust parsing of supported SQL functions (e.g., `COUNT(*)`, `SUM(amount) as total`, `AVG(price - cost)`), replacing get_function_object and has_function.
- Refactor `DynamicTableField.parse` for improved handling of:
- Aliases (case-insensitive `as`, quoted/unquoted).
- `tabDocType.fieldname` notation (distinguishing child vs. main doctype refs).
- Add validation and better error handling during parsing.
- Rewrite filter field validation (`_validate_and_prepare_filter_field`):
- Disallow backticks (`) in filter field names.
- Enforce specific patterns for dot notation (link/child fields only, reject `tabDoc.field`).
- Validate character sets for simple field names.
- Update standard field parsing (`parse_string_field`, `ALLOWED_FIELD_PATTERN`, `FIELD_PARSE_REGEX`):
- Support quoted table names potentially containing spaces (e.g., `tabTable Name`.`field`).
- Improve `parse_fields` and `_parse_single_field_item` logic:
- Handle direct pypika `Field`/`AggregateFunction` inputs.
- Reliably split comma-separated field strings.
```
This commit is contained in:
parent
ddca77429c
commit
87664ad604
1 changed files with 347 additions and 164 deletions
|
|
@ -1,12 +1,13 @@
|
|||
import operator
|
||||
import re
|
||||
from ast import literal_eval
|
||||
from functools import lru_cache
|
||||
from types import BuiltinFunctionType
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeAlias, Union
|
||||
|
||||
import sqlparse
|
||||
from pypika.queries import QueryBuilder, Table
|
||||
from pypika.terms import Term
|
||||
from pypika.terms import AggregateFunction, Term
|
||||
|
||||
import frappe
|
||||
from frappe import _
|
||||
|
|
@ -25,7 +26,6 @@ if TYPE_CHECKING:
|
|||
TAB_PATTERN = re.compile("^tab")
|
||||
WORDS_PATTERN = re.compile(r"\w+")
|
||||
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
|
||||
SQL_FUNCTIONS = tuple(f"{sql_function.value}(" for sql_function in SqlFunctions) # ) <- ignore this comment.
|
||||
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
|
||||
|
||||
# less restrictive version of frappe.core.doctype.doctype.doctype.START_WITH_LETTERS_PATTERN
|
||||
|
|
@ -33,8 +33,8 @@ COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
|
|||
TABLE_NAME_PATTERN = re.compile(r"^[\w -]*$", flags=re.ASCII)
|
||||
|
||||
# Pattern to validate field names in SELECT:
|
||||
# Allows: name, `name`, name as alias, `name` as alias, `table`.`name`, `table`.`name` as alias, table.name, table.name as alias
|
||||
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
|
||||
# Allows: name, `name`, name as alias, `name` as alias, `table name`.`name`, `table name`.`name` as alias, table.name, table.name as alias
|
||||
ALLOWED_FIELD_PATTERN = re.compile(r"^(?:`?[\w\s-]+`?\.)?(`?\w+`?|\w+)(?:\s+as\s+\w+)?$", flags=re.ASCII)
|
||||
|
||||
# Pattern to validate basic SQL function call syntax: word(...) [as alias]
|
||||
FUNCTION_CALL_PATTERN = re.compile(r"^\w+\(.*\)(?:\s+as\s+\w+)?$", flags=re.IGNORECASE | re.ASCII)
|
||||
|
|
@ -44,10 +44,30 @@ FUNCTION_CALL_PATTERN = re.compile(r"^\w+\(.*\)(?:\s+as\s+\w+)?$", flags=re.IGNO
|
|||
# Does NOT allow aliases ('as alias') or functions.
|
||||
ALLOWED_SQL_FIELD_PATTERN = re.compile(r"^(?:`?\w+`?\.)?(`?\w+`?|\w+)$", flags=re.ASCII)
|
||||
|
||||
# Pattern to validate characters allowed within function arguments that are not simple fields/literals.
|
||||
# Allows alphanumeric, underscore, whitespace, +, -, *, /, ., (, ), quotes, and the keyword 'distinct'.
|
||||
# Disallows characters like ; = < > etc. to prevent injection.
|
||||
ALLOWED_ARGUMENT_CHARS_PATTERN = re.compile(
|
||||
r"^(?:[\w\s\+\-\*\/\.\(\)\`\'\"]+|\bDISTINCT\b)+$", flags=re.IGNORECASE | re.ASCII
|
||||
)
|
||||
|
||||
# Regex to parse field names:
|
||||
# Group 1: Optional table name (e.g., `tabDocType` or tabDocType)
|
||||
# Group 2: Field name (e.g., `field` or field)
|
||||
FIELD_PARSE_REGEX = re.compile(r"^(?:[`\"]?(tab\w+)[`\"]?\.)?[`\"]?(\w+)[`\"]?$")
|
||||
# Group 1: Optional quote for table name
|
||||
# Group 2: Optional table name (e.g., `tabDocType` or tabDocType or `tabNote Seen By`)
|
||||
# Group 3: Optional quote for field name
|
||||
# Group 4: Field name (e.g., `field` or field)
|
||||
FIELD_PARSE_REGEX = re.compile(r"^(?:([`\"]?)(tab[\w\s-]+)\1\.)?([`\"]?)(\w+)\3$")
|
||||
|
||||
# Regex to capture: FunctionName(Arguments) [AS Alias]
|
||||
# Group 1: Function Name (e.g., COUNT, SUM)
|
||||
# Group 2: Arguments string (e.g., *, field1, 'literal', field2)
|
||||
# Group 3: Optional Alias (e.g., average_price or `average_price`) - allows backticks
|
||||
SQL_FUNCTION_PATTERN = re.compile(
|
||||
r"^([a-zA-Z_]\w*)\s*\((.*?)\)(?:\s+as\s+(`?[\w\s-]+`?|\w+))?$", flags=re.IGNORECASE | re.ASCII
|
||||
)
|
||||
|
||||
# Regex to split arguments, respecting potential quotes or nested parentheses
|
||||
ARGS_SPLIT_PATTERN = re.compile(r",\s*(?![^()]*\))")
|
||||
|
||||
|
||||
class Engine:
|
||||
|
|
@ -206,47 +226,6 @@ class Engine:
|
|||
|
||||
self._apply_filter(field, value, operator)
|
||||
|
||||
def _validate_and_prepare_filter_field(self, field: str | Field, doctype: str | None = None) -> Field:
|
||||
"""Validate field name for filters and return a pypika Field object. Handles dynamic fields."""
|
||||
_field = field
|
||||
is_fieldname_safe = False
|
||||
|
||||
if not isinstance(_field, str):
|
||||
# Assume it's a pypika Field or similar, return as is.
|
||||
return _field
|
||||
|
||||
# Always validate field name if it contains special characters to prevent injection
|
||||
if SPECIAL_CHAR_PATTERN.search(_field):
|
||||
# First, try to parse as a dynamic field (contains '.')
|
||||
dynamic_field = DynamicTableField.parse(_field, self.doctype)
|
||||
if dynamic_field:
|
||||
# Legitimate dynamic field (e.g., table.field), apply join
|
||||
self.query = dynamic_field.apply_join(self.query)
|
||||
_field = dynamic_field.field # _field is now a pypika Field object
|
||||
# If not a dynamic field and doesn't match the allowed pattern, reject it
|
||||
elif not ALLOWED_SQL_FIELD_PATTERN.match(_field):
|
||||
frappe.throw(
|
||||
_(
|
||||
"Invalid filter field format: {0}. Field names cannot contain special characters or disallowed patterns."
|
||||
).format(_field),
|
||||
frappe.PermissionError,
|
||||
)
|
||||
# If it matched the pattern (e.g., `fieldname` with backticks), mark as safe
|
||||
else:
|
||||
is_fieldname_safe = True
|
||||
# No special characters, treat as a standard field name, mark as safe
|
||||
else:
|
||||
is_fieldname_safe = True
|
||||
|
||||
# Convert string field name to pypika Field object if needed
|
||||
if is_fieldname_safe:
|
||||
# Note: We are converting the original `field` string here,
|
||||
# not the potentially modified `_field`
|
||||
# if it became a dynamic field object earlier.
|
||||
_field = frappe.qb.DocType(doctype or self.doctype)[field]
|
||||
|
||||
return _field
|
||||
|
||||
def _apply_filter(
|
||||
self,
|
||||
field: str | Field,
|
||||
|
|
@ -278,9 +257,18 @@ class Engine:
|
|||
docname = _value
|
||||
|
||||
# Use the original field name string for get_field if _field was converted
|
||||
# If _field is from a dynamic field, its name might be just the target fieldname.
|
||||
# We need the original string ('link.target') or the fieldname from the main doctype.
|
||||
original_field_name = field if isinstance(field, str) else _field.name
|
||||
_df = frappe.get_meta(self.doctype).get_field(original_field_name)
|
||||
ref_doctype = _df.options if _df else self.doctype
|
||||
# Check if the original field name exists in the *main* doctype meta
|
||||
main_meta = frappe.get_meta(self.doctype)
|
||||
if main_meta.has_field(original_field_name):
|
||||
_df = main_meta.get_field(original_field_name)
|
||||
ref_doctype = _df.options if _df else self.doctype
|
||||
else:
|
||||
# If not in main doctype, assume it's a standard field like 'name' or refers to the main doctype itself
|
||||
# This part might need refinement if nested set operators are used with dynamic fields.
|
||||
ref_doctype = self.doctype
|
||||
|
||||
nodes = get_nested_set_hierarchy_result(ref_doctype, docname, hierarchy)
|
||||
operator_fn = (
|
||||
|
|
@ -297,59 +285,54 @@ class Engine:
|
|||
else:
|
||||
self.query = self.query.where(operator_fn(_field, _value))
|
||||
|
||||
def get_function_object(self, field: str) -> "Function":
|
||||
"""Return PyPika Function object. Expect field to look like 'SUM(*)' or 'name' or something similar."""
|
||||
func = field.split("(", maxsplit=1)[0].capitalize()
|
||||
args_start, args_end = len(func) + 1, field.index(")")
|
||||
args = field[args_start:args_end].split(",")
|
||||
def _validate_and_prepare_filter_field(self, field: str | Field, doctype: str | None = None) -> Field:
|
||||
"""Validate field name for filters and return a pypika Field object. Handles dynamic fields."""
|
||||
|
||||
_, alias = field.split(" as ") if " as " in field else (None, None)
|
||||
if isinstance(field, Term):
|
||||
# return if field is already a pypika Term
|
||||
return field
|
||||
|
||||
to_cast = "*" not in args
|
||||
_args = []
|
||||
# Reject backticks
|
||||
if "`" in field:
|
||||
frappe.throw(
|
||||
_("Filter fields cannot contain backticks (`)."),
|
||||
frappe.ValidationError,
|
||||
title=_("Invalid Filter"),
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
initial_fields = literal_eval_(arg.strip())
|
||||
if to_cast:
|
||||
has_primitive_operator = False
|
||||
for _operator in OPERATOR_MAP.keys():
|
||||
if _operator in initial_fields:
|
||||
operator_mapping = OPERATOR_MAP[_operator]
|
||||
# Only perform this if operator is of primitive type.
|
||||
if isinstance(operator_mapping, BuiltinFunctionType):
|
||||
has_primitive_operator = True
|
||||
field = operator_mapping(
|
||||
*map(
|
||||
lambda field: Field(field.strip())
|
||||
if "`" not in field
|
||||
else PseudoColumnMapper(field.strip()),
|
||||
arg.split(_operator),
|
||||
),
|
||||
)
|
||||
|
||||
field = (
|
||||
(
|
||||
Field(initial_fields)
|
||||
if "`" not in initial_fields
|
||||
else PseudoColumnMapper(initial_fields)
|
||||
)
|
||||
if not has_primitive_operator
|
||||
else field
|
||||
)
|
||||
# Handle dot notation (link_field.target_field or child_table_field.target_field)
|
||||
if "." in field:
|
||||
# Disallow tabDoc.field notation in filters.
|
||||
dynamic_field = DynamicTableField.parse(field, self.doctype, allow_tab_notation=False)
|
||||
if dynamic_field:
|
||||
# Parsed successfully as link/child field access
|
||||
self.query = dynamic_field.apply_join(self.query)
|
||||
# Return the pypika Field object associated with the dynamic field
|
||||
return dynamic_field.field
|
||||
else:
|
||||
field = initial_fields
|
||||
|
||||
_args.append(field)
|
||||
|
||||
if alias and "`" in alias:
|
||||
alias = alias.replace("`", "")
|
||||
try:
|
||||
if func.casefold() == "now":
|
||||
return getattr(functions, func)()
|
||||
return getattr(functions, func)(*_args, alias=alias or None)
|
||||
except AttributeError:
|
||||
# Fall back for functions not present in `SqlFunctions``
|
||||
return Function(func, *_args, alias=alias or None)
|
||||
# Contains '.' but is not a valid link/child field access pattern
|
||||
# This rejects tabDoc.field and other invalid formats like a.b.c
|
||||
frappe.throw(
|
||||
_(
|
||||
"Invalid filter field format: {0}. Use 'fieldname' or 'link_fieldname.target_fieldname'."
|
||||
).format(field),
|
||||
frappe.ValidationError,
|
||||
title=_("Invalid Filter"),
|
||||
)
|
||||
else:
|
||||
# No '.' and no '`'. Check if it's a simple field name (alphanumeric + underscore).
|
||||
if not re.fullmatch(r"\w+", field):
|
||||
frappe.throw(
|
||||
_(
|
||||
"Invalid characters in fieldname: {0}. Only letters, numbers, and underscores are allowed."
|
||||
).format(field),
|
||||
frappe.ValidationError,
|
||||
title=_("Invalid Filter"),
|
||||
)
|
||||
# It's a simple, valid fieldname like 'name' or 'creation'
|
||||
# Convert string field name to pypika Field object for the specified/current doctype
|
||||
target_doctype = doctype or self.doctype
|
||||
return frappe.qb.DocType(target_doctype)[field]
|
||||
|
||||
def parse_string_field(self, field: str):
|
||||
"""
|
||||
|
|
@ -361,6 +344,7 @@ class Engine:
|
|||
- `quoted_field`
|
||||
- tabDocType.simple_field
|
||||
- `tabDocType`.`quoted_field`
|
||||
- `tabTable Name`.`quoted_field`
|
||||
- Aliases for all above formats (e.g., field as alias)
|
||||
"""
|
||||
if field == "*":
|
||||
|
|
@ -380,10 +364,16 @@ class Engine:
|
|||
if not match:
|
||||
frappe.throw(_("Could not parse field: {0}").format(field))
|
||||
|
||||
table_name, field_name = match.groups()
|
||||
# Groups: 1: table_quote, 2: table_name_with_tab, 3: field_quote, 4: field_name
|
||||
groups = match.groups()
|
||||
table_name = groups[1] # This will be None if no table part (e.g., just 'field')
|
||||
field_name = groups[3] # This will be the field name (e.g., 'field')
|
||||
|
||||
if table_name:
|
||||
# Table name specified (e.g., `tabX`.`y` or tabX.y)
|
||||
# Table name specified (e.g., `tabX`.`y` or tabX.y or `tabX Y`.`y`)
|
||||
# Ensure the extracted table name is valid before creating DocType object
|
||||
if not TABLE_NAME_PATTERN.match(table_name.lstrip("tab")):
|
||||
frappe.throw(_("Invalid characters in table name: {0}").format(table_name))
|
||||
table_obj = frappe.qb.DocType(table_name)
|
||||
pypika_field = table_obj[field_name]
|
||||
else:
|
||||
|
|
@ -395,11 +385,59 @@ class Engine:
|
|||
else:
|
||||
return pypika_field
|
||||
|
||||
def parse_fields(
|
||||
self, fields: str | list | tuple | Field | AggregateFunction | None
|
||||
) -> "list[Field | AggregateFunction | Criterion | DynamicTableField | ChildQuery]":
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
# Handle direct pypika Field or Function objects
|
||||
if isinstance(fields, Field | AggregateFunction):
|
||||
return [fields]
|
||||
|
||||
initial_field_list = []
|
||||
if isinstance(fields, str):
|
||||
# Split comma-separated fields passed as a single string
|
||||
initial_field_list.extend(f.strip() for f in COMMA_PATTERN.split(fields) if f.strip())
|
||||
elif isinstance(fields, list | tuple):
|
||||
for item in fields:
|
||||
if isinstance(item, str) and "," in item:
|
||||
# Split comma-separated strings within the list
|
||||
initial_field_list.extend(f.strip() for f in COMMA_PATTERN.split(item) if f.strip())
|
||||
else:
|
||||
# Add non-comma-separated items directly
|
||||
initial_field_list.append(item)
|
||||
|
||||
else:
|
||||
frappe.throw(_("Fields must be a string, list, tuple, pypika Field, or pypika Function"))
|
||||
|
||||
_fields = []
|
||||
# Iterate through the list where each item could be a single field, criterion, or a comma-separated string
|
||||
for item in initial_field_list:
|
||||
if isinstance(item, str):
|
||||
# Sanitize and split potentially comma-separated strings within the list
|
||||
sanitized_item = _sanitize_field(item.strip(), self.is_mariadb).strip()
|
||||
if sanitized_item:
|
||||
parsed = self._parse_single_field_item(sanitized_item)
|
||||
if isinstance(parsed, list): # Result from parsing a child query dict
|
||||
_fields.extend(parsed)
|
||||
elif parsed:
|
||||
_fields.append(parsed)
|
||||
else:
|
||||
# Handle non-string items (like dict for child query, or pre-parsed Field/Function)
|
||||
parsed = self._parse_single_field_item(item)
|
||||
if isinstance(parsed, list):
|
||||
_fields.extend(parsed)
|
||||
elif parsed:
|
||||
_fields.append(parsed)
|
||||
|
||||
return _fields
|
||||
|
||||
def _parse_single_field_item(
|
||||
self, field: str | Criterion | dict
|
||||
) -> list | Criterion | Field | "DynamicTableField" | "ChildQuery" | None:
|
||||
self, field: str | Criterion | dict | Field | Function
|
||||
) -> "list | Criterion | Field | Function | DynamicTableField | ChildQuery | None":
|
||||
"""Parses a single item from the fields list/tuple. Assumes comma-separated strings have already been split."""
|
||||
if isinstance(field, Criterion):
|
||||
if isinstance(field, Criterion | Field | Function):
|
||||
return field
|
||||
elif isinstance(field, dict):
|
||||
# Handle child queries defined as dicts {fieldname: [child_fields]}
|
||||
|
|
@ -418,9 +456,10 @@ class Engine:
|
|||
if not isinstance(field, str):
|
||||
frappe.throw(_("Invalid field type: {0}").format(type(field)))
|
||||
|
||||
# Check for functions or dynamic fields first
|
||||
if has_function(field):
|
||||
return self.get_function_object(field)
|
||||
# Try parsing as SQL function first
|
||||
if parsed_function := SqlFunctionParser.parse(field):
|
||||
return parsed_function
|
||||
# Then try parsing as dynamic field (link/child table access)
|
||||
elif parsed := DynamicTableField.parse(field, self.doctype):
|
||||
return parsed
|
||||
# Otherwise, parse as a standard field (simple, quoted, table-qualified, with/without alias)
|
||||
|
|
@ -428,38 +467,6 @@ class Engine:
|
|||
# Note: Comma handling is done in parse_fields before this method is called
|
||||
return self.parse_string_field(field)
|
||||
|
||||
def parse_fields(
|
||||
self, fields: str | list | tuple | None
|
||||
) -> list[Field | Criterion | "DynamicTableField" | "ChildQuery"]:
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
sanitized_field_list = []
|
||||
if isinstance(fields, str):
|
||||
# Split comma-separated fields passed as a single string *before* sanitizing
|
||||
sanitized_field_list.extend(
|
||||
_sanitize_field(f.strip(), self.is_mariadb) for f in COMMA_PATTERN.split(fields) if f.strip()
|
||||
)
|
||||
elif isinstance(fields, list | tuple):
|
||||
# Sanitize fields if input is already a list/tuple
|
||||
sanitized_field_list.extend(
|
||||
_sanitize_field(field, self.is_mariadb) if isinstance(field, str) else field
|
||||
for field in fields
|
||||
)
|
||||
else:
|
||||
frappe.throw(_("Fields must be a string, list, or tuple"))
|
||||
|
||||
_fields = []
|
||||
# Iterate through the list where each item is a single field definition or criterion
|
||||
for field_item in sanitized_field_list:
|
||||
parsed = self._parse_single_field_item(field_item)
|
||||
if isinstance(parsed, list): # Result from parsing a child query dict
|
||||
_fields.extend(parsed)
|
||||
elif parsed:
|
||||
_fields.append(parsed)
|
||||
|
||||
return _fields
|
||||
|
||||
def _validate_group_by(self, group_by: str):
|
||||
"""Validate the group_by string argument."""
|
||||
if not isinstance(group_by, str):
|
||||
|
|
@ -592,7 +599,7 @@ class Engine:
|
|||
elif hasattr(field, "alias") and field.alias and field.name in permitted_fields_set:
|
||||
allowed_fields.append(field)
|
||||
|
||||
elif isinstance(field, PseudoColumnMapper):
|
||||
elif isinstance(field, PseudoColumnMapper | Function):
|
||||
# Typically functions or complex terms
|
||||
allowed_fields.append(field)
|
||||
|
||||
|
|
@ -788,24 +795,84 @@ class DynamicTableField:
|
|||
return f"{table_name}.{fieldname} {alias}".strip()
|
||||
|
||||
@staticmethod
|
||||
def parse(field: str, doctype: str):
|
||||
def parse(field: str, doctype: str, allow_tab_notation: bool = True):
|
||||
if "." in field:
|
||||
alias = None
|
||||
if " as " in field:
|
||||
field, alias = field.split(" as ")
|
||||
if field.startswith("`tab") or field.startswith('"tab'):
|
||||
_, child_doctype, child_field = re.search(r'([`"])tab(.+?)\1.\1(.+)\1', field).groups()
|
||||
if child_doctype == doctype:
|
||||
return
|
||||
return ChildTableField(child_doctype, child_field, doctype, alias=alias)
|
||||
# Handle 'as' alias, case-insensitive, taking the last occurrence
|
||||
if " as " in field.lower():
|
||||
parts = re.split(r"\s+as\s+", field, flags=re.IGNORECASE)
|
||||
if len(parts) > 1:
|
||||
field_part = parts[0].strip()
|
||||
alias = parts[-1].strip().strip('`"') # Get last part as alias
|
||||
field = field_part # Use the part before alias for further parsing
|
||||
|
||||
child_match = None
|
||||
if allow_tab_notation:
|
||||
# Regex to match `tabDoc`.`field`, "tabDoc"."field", tabDoc.field
|
||||
# Group 1: Doctype name (without 'tab')
|
||||
# Group 2: Optional quote for fieldname
|
||||
# Group 3: Fieldname
|
||||
# Ensures quotes are consistent or absent on fieldname using backreference \2
|
||||
# Uses re.match to ensure the pattern matches the *entire* field string
|
||||
# Allow spaces in doctype name (Group 1) and field name (Group 3)
|
||||
child_match = re.match(r'[`"]?tab([\w\s]+)[`"]?\.([`"]?)([\w\s]+)\2$', field)
|
||||
|
||||
if child_match:
|
||||
child_doctype_name = child_match.group(1)
|
||||
child_field = child_match.group(3)
|
||||
|
||||
if child_doctype_name == doctype:
|
||||
# Referencing a field in the main doctype using `tabDoctype.field` notation.
|
||||
# This should be handled by the standard field parser, not as a DynamicTableField.
|
||||
return None
|
||||
# Found a child table reference like tabChildDoc.child_field
|
||||
# Note: parent_fieldname is None here as it's directly specified via tab notation
|
||||
return ChildTableField(child_doctype_name, child_field, doctype, alias=alias)
|
||||
else:
|
||||
linked_fieldname, fieldname = field.split(".")
|
||||
linked_field = frappe.get_meta(doctype).get_field(linked_fieldname)
|
||||
linked_doctype = linked_field.options
|
||||
if linked_field.fieldtype == "Link":
|
||||
return LinkTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias)
|
||||
elif linked_field.fieldtype in frappe.model.table_fields:
|
||||
return ChildTableField(linked_doctype, fieldname, doctype, linked_fieldname, alias=alias)
|
||||
# Try parsing as LinkTableField (link_field.target_field) or ChildTableField (child_field.target_field)
|
||||
# This handles patterns not starting with 'tab' prefix
|
||||
if "." not in field: # Should not happen due to outer check, but safety
|
||||
return None
|
||||
|
||||
parts = field.split(".", 1)
|
||||
if len(parts) != 2: # Ensure it splits into exactly two parts
|
||||
return None
|
||||
potential_parent_fieldname, target_fieldname = parts
|
||||
|
||||
# Basic validation for the parts to avoid unnecessary metadata lookups on invalid input
|
||||
# We expect simple identifiers here. Quoted/complex names are handled elsewhere or by child_match.
|
||||
if (
|
||||
not potential_parent_fieldname.replace("_", "").isalnum()
|
||||
or not target_fieldname.replace("_", "").isalnum()
|
||||
):
|
||||
return None
|
||||
|
||||
try:
|
||||
meta = frappe.get_meta(doctype) # Get meta of the *parent* doctype
|
||||
# Check if the first part is a valid fieldname in the parent doctype
|
||||
if not meta.has_field(potential_parent_fieldname):
|
||||
return None # Not a field in the parent, so not link/child access pattern
|
||||
|
||||
linked_field = meta.get_field(potential_parent_fieldname)
|
||||
except Exception:
|
||||
# Handle cases where doctype doesn't exist, etc.
|
||||
print(f"Error getting metadata for {doctype} while parsing field {field}")
|
||||
return None
|
||||
|
||||
if linked_field:
|
||||
linked_doctype = linked_field.options
|
||||
if linked_field.fieldtype == "Link":
|
||||
# It's a Link field access: parent_doctype.link_fieldname.target_fieldname
|
||||
return LinkTableField(
|
||||
linked_doctype, target_fieldname, doctype, potential_parent_fieldname, alias=alias
|
||||
)
|
||||
elif linked_field.fieldtype in frappe.model.table_fields:
|
||||
# It's a Child Table field access: parent_doctype.child_table_fieldname.target_fieldname
|
||||
return ChildTableField(
|
||||
linked_doctype, target_fieldname, doctype, potential_parent_fieldname, alias=alias
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def apply_select(self, query: QueryBuilder) -> QueryBuilder:
|
||||
raise NotImplementedError
|
||||
|
|
@ -899,6 +966,129 @@ class ChildQuery:
|
|||
)
|
||||
|
||||
|
||||
class SqlFunctionParser:
|
||||
_supported_functions: ClassVar[dict[str, BuiltinFunctionType]] = {
|
||||
f.value.lower(): getattr(functions, f.name) for f in SqlFunctions if hasattr(functions, f.name)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _parse_argument_expression(arg_str: str) -> Term | None:
|
||||
"""Attempts to parse simple arithmetic expressions between fields."""
|
||||
# Map symbols to pypika's expected operation methods if needed, or rely on overloading
|
||||
# For +, -, *, / pypika Field overloading works directly
|
||||
supported_operators = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}
|
||||
|
||||
for op_symbol, _op_func in supported_operators.items():
|
||||
# Split only on the first occurrence of the operator
|
||||
parts = arg_str.split(op_symbol, 1)
|
||||
if len(parts) == 2:
|
||||
left_str, right_str = parts[0].strip(), parts[1].strip()
|
||||
|
||||
# Validate both parts are valid field names (simple or quoted)
|
||||
if ALLOWED_SQL_FIELD_PATTERN.match(left_str.strip('`"')) and ALLOWED_SQL_FIELD_PATTERN.match(
|
||||
right_str.strip('`"')
|
||||
):
|
||||
# Create Field or PseudoColumnMapper objects
|
||||
left_field = (
|
||||
PseudoColumnMapper(left_str)
|
||||
if "`" in left_str or '"' in left_str
|
||||
else Field(left_str)
|
||||
)
|
||||
right_field = (
|
||||
PseudoColumnMapper(right_str)
|
||||
if "`" in right_str or '"' in right_str
|
||||
else Field(right_str)
|
||||
)
|
||||
|
||||
# Use pypika's operator overloading for Field objects
|
||||
if op_symbol == "+":
|
||||
return left_field + right_field
|
||||
elif op_symbol == "-":
|
||||
return left_field - right_field
|
||||
elif op_symbol == "*":
|
||||
return left_field * right_field
|
||||
elif op_symbol == "/":
|
||||
return left_field / right_field
|
||||
# If no simple binary arithmetic expression is found
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse(field_str: str) -> Function | None:
|
||||
"""
|
||||
Parses a string to see if it represents a *supported* SQL function call.
|
||||
Returns a pypika Function object if valid and supported, otherwise None.
|
||||
Handles simple arguments (field names, *), aliases, and simple expressions.
|
||||
"""
|
||||
match = SQL_FUNCTION_PATTERN.match(field_str.strip())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
func_name, args_str, alias = match.groups()
|
||||
func_name_lower = func_name.lower()
|
||||
|
||||
# Strip backticks from alias if present
|
||||
if alias:
|
||||
alias = alias.strip("`")
|
||||
|
||||
# Check if the function is in our supported list
|
||||
pypika_func = SqlFunctionParser._supported_functions.get(func_name_lower)
|
||||
if not pypika_func:
|
||||
# Function name not found in SqlFunctions enum values
|
||||
return None
|
||||
|
||||
# Handle NOW() specifically (often takes no arguments)
|
||||
if func_name_lower == "now" and not args_str.strip():
|
||||
return pypika_func(alias=alias or None)
|
||||
|
||||
# Parse arguments
|
||||
parsed_args = []
|
||||
if args_str.strip():
|
||||
raw_args = ARGS_SPLIT_PATTERN.split(args_str.strip())
|
||||
for arg in raw_args:
|
||||
arg = arg.strip()
|
||||
if not arg:
|
||||
continue
|
||||
|
||||
if arg == "*":
|
||||
# Only allow '*' for specific functions like COUNT
|
||||
if func_name_lower != "count":
|
||||
frappe.throw(_("Wildcard '*' argument is only supported for COUNT function."))
|
||||
parsed_args.append(Term.wrap_constant("*"))
|
||||
continue
|
||||
|
||||
evaluated_arg = literal_eval_(arg)
|
||||
if evaluated_arg != arg: # Successfully evaluated to a literal
|
||||
parsed_args.append(Term.wrap_constant(evaluated_arg))
|
||||
else:
|
||||
# Not '*' or a simple literal. Could be a field, quoted field, keyword, or expression.
|
||||
# Check if it's a simple or quoted field name first.
|
||||
if ALLOWED_SQL_FIELD_PATTERN.match(arg.strip('`"')):
|
||||
# Pass the original arg (with quotes if present) to the mapper/field
|
||||
if "`" in arg or '"' in arg:
|
||||
parsed_args.append(PseudoColumnMapper(arg))
|
||||
else:
|
||||
parsed_args.append(Field(arg))
|
||||
# Check if it's a valid expression/keyword based on allowed characters
|
||||
elif ALLOWED_ARGUMENT_CHARS_PATTERN.match(arg):
|
||||
# Attempt to parse as a simple arithmetic expression first
|
||||
parsed_expr = SqlFunctionParser._parse_argument_expression(arg)
|
||||
if parsed_expr:
|
||||
parsed_args.append(parsed_expr)
|
||||
else:
|
||||
# Fallback: Pass the raw string argument for non-expression cases like 'distinct name'
|
||||
parsed_args.append(arg)
|
||||
else:
|
||||
# Argument contains disallowed characters.
|
||||
frappe.throw(
|
||||
_("Invalid characters or format in function argument expression: {0}").format(
|
||||
arg
|
||||
),
|
||||
frappe.ValidationError,
|
||||
)
|
||||
|
||||
return pypika_func(*parsed_args, alias=alias or None)
|
||||
|
||||
|
||||
def literal_eval_(literal):
|
||||
try:
|
||||
return literal_eval(literal)
|
||||
|
|
@ -906,13 +1096,6 @@ def literal_eval_(literal):
|
|||
return literal
|
||||
|
||||
|
||||
def has_function(field: str):
|
||||
if "`" not in field:
|
||||
field = field.casefold()
|
||||
|
||||
return any(func in field for func in SQL_FUNCTIONS)
|
||||
|
||||
|
||||
def get_nested_set_hierarchy_result(doctype: str, name: str, hierarchy: str) -> list[str]:
|
||||
"""Get matching nodes based on operator."""
|
||||
table = frappe.qb.DocType(doctype)
|
||||
|
|
@ -954,7 +1137,7 @@ def _validate_select_field(field: str):
|
|||
if field.isdigit():
|
||||
return
|
||||
|
||||
if ALLOWED_FIELD_PATTERN.match(field) or FUNCTION_CALL_PATTERN.match(field):
|
||||
if ALLOWED_FIELD_PATTERN.match(field) or SqlFunctionParser.parse(field):
|
||||
return
|
||||
|
||||
frappe.throw(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue