feat: Added support for aliasing in function objects

This commit is contained in:
Aradhya 2022-07-09 22:27:29 +05:30
parent 2fac4f8074
commit 613065fa2e
2 changed files with 47 additions and 6 deletions

View file

@ -186,7 +186,7 @@ OPERATOR_MAP: dict[str, Callable] = {
class Engine:
tables: dict = {}
tables: dict[str, str] = {}
@cached_property
def OPERATOR_MAP(self):
@ -384,6 +384,7 @@ class Engine:
args_start, args_end = len(func) + 1, field.index(")")
args = field[args_start:args_end].split(",")
_, alias = field.split(" as ") if " as " in field else (None, None)
to_cast = "*" not in args
_args = []
@ -406,14 +407,15 @@ class Engine:
field = initial_fields
_args.append(field)
return getattr(functions, func)(*_args)
return getattr(functions, func)(*_args, alias=alias or None)
def function_objects_from_string(self, fields):
functions = ""
for func in SQL_FUNCTIONS:
if f"{func}(" in fields:
_, alias = fields.split(" as ") if " as " in fields else ("", "")
functions = str(func) + str(BRACKETS_PATTERN.search(fields).group())
functions += " as " + alias
return [self.get_function_object(functions)]
if not functions:
return []
@ -432,14 +434,25 @@ class Engine:
"""Remove string functions from fields which have already been converted to function objects"""
for function in function_objects:
if isinstance(fields, str):
has_alias = False
if function.alias:
has_alias = True
fields = BRACKETS_PATTERN.sub("", fields.replace(function.name.casefold(), ""))
if has_alias:
fields = fields.replace(" as " + function.alias.casefold(), "")
else:
updated_fields = []
for field in fields:
has_alias = False
if function.alias:
has_alias = True
if isinstance(field, str):
updated_fields.append(
_field = (
BRACKETS_PATTERN.sub("", field).strip().casefold().replace(function.name.casefold(), "")
)
if has_alias:
_field = _field.replace(" as " + function.alias.casefold(), "")
updated_fields.append(_field)
else:
updated_fields.append(field)

View file

@ -2,6 +2,7 @@ import unittest
import frappe
from frappe.query_builder import Field
from frappe.query_builder.functions import Abs, Count, Max, Timestamp
from frappe.tests.test_query_builder import db_type_is, run_only_if
@ -42,8 +43,6 @@ class TestQuery(unittest.TestCase):
)
def test_functions_fields(self):
from frappe.query_builder.functions import Abs, Count, Max, Timestamp
self.assertEqual(
frappe.qb.engine.get_query("User", fields="Count(name)", filters={}).get_sql(),
frappe.qb.from_("User").select(Count(Field("name"))).get_sql(),
@ -88,3 +87,32 @@ class TestQuery(unittest.TestCase):
frappe.qb.engine.get_query(user_doctype, fields=user_doctype.email, filters={}).get_sql(),
frappe.qb.from_(user_doctype).select(user_doctype.email).get_sql(),
)
def test_aliasing(self):
user_doctype = frappe.qb.DocType("User")
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields=["name as owner", "email as id"], filters={}
).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.name.as_("owner"), user_doctype.email.as_("id"))
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields="name as owner, email as id", filters={}
).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.name.as_("owner"), user_doctype.email.as_("id"))
.get_sql(),
)
self.assertEqual(
frappe.qb.engine.get_query(
user_doctype, fields=["Count(name) as c", "email as id"], filters={}
).get_sql(),
frappe.qb.from_(user_doctype)
.select(user_doctype.email.as_("id"), Count(Field("name")).as_("c"))
.get_sql(),
)