feat: added PseudoColumnMapper for postgres support

This commit is contained in:
Aradhya 2022-11-03 20:43:24 +05:30
parent dfe62f2614
commit 8b73108270
2 changed files with 32 additions and 16 deletions

View file

@ -15,8 +15,7 @@ from frappe.database.utils import NestedSetHierarchy, is_pypika_function_object
from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumn
from frappe.utils import cstr
from frappe.query_builder.utils import PseudoColumnMapper
from frappe.utils.data import MARIADB_SPECIFIC_COMMENT
if TYPE_CHECKING:
@ -27,7 +26,6 @@ WORDS_PATTERN = re.compile(r"\w+")
BRACKETS_PATTERN = re.compile(r"\(.*?\)|$")
SQL_FUNCTIONS = [sql_function.value for sql_function in SqlFunctions]
COMMA_PATTERN = re.compile(r",\s*(?![^()]*\))")
TABLE_PATTERN = re.compile(r"`\btab\w+")
def like(key: Field, value: str) -> frappe.qb:
@ -166,8 +164,11 @@ def has_function(field):
def table_from_string(table: str) -> "DocType":
table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:]
return frappe.qb.DocType(table_name=table_name.replace("`", ""))
if frappe.db.db_type == "postgres":
table_name = table.split('"', maxsplit=1)[1].split(".")[0][3:].replace('"', "")
else:
table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:].replace("`", "")
return frappe.qb.DocType(table_name=table_name)
def get_nested_set_hierarchy_result(hierarchy: str, field: str, table: str):
@ -467,13 +468,15 @@ class Engine:
has_primitive_operator = True
field = operator_mapping(
*map(
lambda field: Field(field.strip()) if "`" not in field else PseudoColumn(field.strip()),
lambda field: Field(field.strip())
if "`" not in field
else PseudoColumnMapper(field.strip()),
arg.split(_operator),
),
)
field = (
(Field(initial_fields) if "`" not in initial_fields else PseudoColumn(initial_fields))
(Field(initial_fields) if "`" not in initial_fields else PseudoColumnMapper(initial_fields))
if not has_primitive_operator
else field
)
@ -590,11 +593,11 @@ class Engine:
if " as " in field:
field, reference = field.split(" as ")
if "`" in field:
updated_fields.append(PseudoColumn(f"{field} {reference}"))
updated_fields.append(PseudoColumnMapper(f"{field} {reference}"))
else:
updated_fields.append(Field(field.strip()).as_(reference))
elif "`" in str(field):
updated_fields.append(PseudoColumn(field.strip()))
updated_fields.append(PseudoColumnMapper(field.strip()))
else:
updated_fields.append(Field(field))
return updated_fields
@ -603,11 +606,11 @@ class Engine:
if fields == "*":
return fields
if "`" in fields:
fields = PseudoColumn(fields)
fields = PseudoColumnMapper(fields)
if " as " in str(fields):
fields, reference = str(fields).split(" as ")
if "`" in str(fields):
fields = PseudoColumn(f"{fields} as {reference}")
fields = PseudoColumnMapper(f"{fields} {reference}")
else:
fields = Field(fields).as_(reference)
return fields
@ -667,12 +670,15 @@ class Engine:
def join(self, criterion, fields, table, join_type):
"""Handles all join operations on criterion objects"""
has_join = False
table_pattern = (
re.compile(r"`\btab\w+") if frappe.db.db_type == "mariadb" else re.compile(r'"\btab\w+')
)
def _update_pypika_fields(field):
if not is_pypika_function_object(field):
field = field if isinstance(field, (str, PseudoColumn)) else field.get_sql()
if not TABLE_PATTERN.search(str(field)):
if isinstance(field, PseudoColumn):
field = field if isinstance(field, (str, PseudoColumnMapper)) else field.get_sql()
if not table_pattern.search(str(field)):
if isinstance(field, PseudoColumnMapper):
field = field.get_sql()
return getattr(frappe.qb.DocType(table), field)
else:
@ -686,8 +692,8 @@ class Engine:
# Only perform this bit if foreign doctype in fields
if (
not is_pypika_function_object(field)
and str(field).startswith("`tab")
and (f"`tab{table}`" not in str(field))
and (str(field).startswith('"tab') or str(field).startswith("`tab"))
and (f"`tab{table}`" not in str(field) and f'tab{table}"' not in str(field))
):
has_join = True
child_table = table_from_string(str(field))

View file

@ -12,6 +12,16 @@ from frappe.query_builder.terms import NamedParameterWrapper
from .builder import MariaDB, Postgres
class PseudoColumnMapper(PseudoColumn):
def __init__(self, name: str) -> None:
super().__init__(name)
def get_sql(self, **kwargs):
if frappe.db.db_type == "postgres":
self.name = self.name.replace("`", '"')
return self.name
class db_type_is(Enum):
MARIADB = "mariadb"
POSTGRES = "postgres"