From 8b73108270e6d5ea5e68891c34ffb33fa6bc44ff Mon Sep 17 00:00:00 2001 From: Aradhya Date: Thu, 3 Nov 2022 20:43:24 +0530 Subject: [PATCH] feat: added PseudoColumnMapper for postgres support --- frappe/database/query.py | 38 ++++++++++++++++++++--------------- frappe/query_builder/utils.py | 10 +++++++++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index 1d0447e512..a9dab02744 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -15,8 +15,7 @@ from frappe.database.utils import NestedSetHierarchy, is_pypika_function_object from frappe.model.db_query import get_timespan_date_range from frappe.query_builder import Criterion, Field, Order, Table, functions from frappe.query_builder.functions import Function, SqlFunctions -from frappe.query_builder.utils import PseudoColumn -from frappe.utils import cstr +from frappe.query_builder.utils import PseudoColumnMapper from frappe.utils.data import MARIADB_SPECIFIC_COMMENT if TYPE_CHECKING: @@ -27,7 +26,6 @@ WORDS_PATTERN = re.compile(r"\w+") BRACKETS_PATTERN = re.compile(r"\(.*?\)|$") SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions] COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))") -TABLE_PATTERN = re.compile(r"`\btab\w+") def like(key: Field, value: str) -> frappe.qb: @@ -166,8 +164,11 @@ def has_function(field): def table_from_string(table: str) -> "DocType": - table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:] - return frappe.qb.DocType(table_name=table_name.replace("`", "")) + if frappe.db.db_type == "postgres": + table_name = table.split('"', maxsplit=1)[1].split(".")[0][3:].replace('"', "") + else: + table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:].replace("`", "") + return frappe.qb.DocType(table_name=table_name) def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str): @@ -467,13 +468,15 @@ class Engine: has_primitive_operator = True field = operator_mapping( *map( - lambda field: Field(field.strip()) if "`" not in field else PseudoColumn(field.strip()), + lambda field: Field(field.strip()) + if "`" not in field + else PseudoColumnMapper(field.strip()), arg.split(_operator), ), ) field = ( - (Field(initial_fields) if "`" not in initial_fields else PseudoColumn(initial_fields)) + (Field(initial_fields) if "`" not in initial_fields else PseudoColumnMapper(initial_fields)) if not has_primitive_operator else field ) @@ -590,11 +593,11 @@ class Engine: if " as " in field: field, reference = field.split(" as ") if "`" in field: - updated_fields.append(PseudoColumn(f"{field} {reference}")) + updated_fields.append(PseudoColumnMapper(f"{field} {reference}")) else: updated_fields.append(Field(field.strip()).as_(reference)) elif "`" in str(field): - updated_fields.append(PseudoColumn(field.strip())) + updated_fields.append(PseudoColumnMapper(field.strip())) else: updated_fields.append(Field(field)) return updated_fields @@ -603,11 +606,11 @@ class Engine: if fields == "*": return fields if "`" in fields: - fields = PseudoColumn(fields) + fields = PseudoColumnMapper(fields) if " as " in str(fields): fields, reference = str(fields).split(" as ") if "`" in str(fields): - fields = PseudoColumn(f"{fields} as {reference}") + fields = PseudoColumnMapper(f"{fields} {reference}") else: fields = Field(fields).as_(reference) return fields @@ -667,12 +670,15 @@ class Engine: def join(self, criterion, fields, table, join_type): """Handles all join operations on criterion objects""" has_join = False + table_pattern = ( + re.compile(r"`\btab\w+") if frappe.db.db_type == "mariadb" else re.compile(r'"\btab\w+') + ) def _update_pypika_fields(field): if not is_pypika_function_object(field): - field = field if isinstance(field, (str, PseudoColumn)) else field.get_sql() - if not TABLE_PATTERN.search(str(field)): - if isinstance(field, PseudoColumn): + field = field if isinstance(field, (str, PseudoColumnMapper)) else field.get_sql() + if not table_pattern.search(str(field)): + if isinstance(field, PseudoColumnMapper): field = field.get_sql() return getattr(frappe.qb.DocType(table), field) else: @@ -686,8 +692,8 @@ class Engine: # Only perform this bit if foreign doctype in fields if ( not is_pypika_function_object(field) - and str(field).startswith("`tab") - and (f"`tab{table}`" not in str(field)) + and (str(field).startswith('"tab') or str(field).startswith("`tab")) + and (f"`tab{table}`" not in str(field) and f'tab{table}"' not in str(field)) ): has_join = True child_table = table_from_string(str(field)) diff --git a/frappe/query_builder/utils.py b/frappe/query_builder/utils.py index f0130ca813..8780841e03 100644 --- a/frappe/query_builder/utils.py +++ b/frappe/query_builder/utils.py @@ -12,6 +12,16 @@ from frappe.query_builder.terms import NamedParameterWrapper from .builder import MariaDB, Postgres +class PseudoColumnMapper(PseudoColumn): + def __init__(self, name: str) -> None: + super().__init__(name) + + def get_sql(self, **kwargs): + if frappe.db.db_type == "postgres": + self.name = self.name.replace("`", '"') + return self.name + + class db_type_is(Enum): MARIADB = "mariadb" POSTGRES = "postgres"