From b7c0ba1beaa2cadc5f6bca52ec4e0943c47681d0 Mon Sep 17 00:00:00 2001 From: Faris Ansari Date: Sat, 31 Dec 2022 22:55:00 +0530 Subject: [PATCH] fix: allow dynamic fields in filters e.g., `filters={'link.field': 'value'}` `filters={'child.field': 'value'}` --- frappe/database/query.py | 53 +++++++++++++++++++++++++++++++------- frappe/tests/test_query.py | 43 +++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/frappe/database/query.py b/frappe/database/query.py index b9e60043a0..3593de7c74 100644 --- a/frappe/database/query.py +++ b/frappe/database/query.py @@ -243,7 +243,7 @@ class Engine: for field in self.fields: if isinstance(field, DynamicTableField): - self.query = field.apply(self.query) + self.query = field.apply_select(self.query) else: self.query = self.query.select(field) @@ -318,18 +318,25 @@ class Engine: _value = value _operator = operator - if has_function(field): + if dynamic_field := DynamicTableField.parse(field, self.doctype): + # apply implicit join if link field's field is referenced + self.query = dynamic_field.apply_join(self.query) + _field = dynamic_field.field + elif has_function(field): _field = self.get_function_object(field) elif not doctype or doctype == self.doctype: _field = self.table[field] elif doctype: _field = self.get_table(doctype)[field] - # keep track of implicit join if child table is referenced + # apply implicit join if child table is referenced if doctype and doctype != self.doctype: meta = frappe.get_meta(doctype) - if meta.istable: - self.implicit_joins.add((doctype, "child")) + table = self.get_table(doctype) + if meta.istable and not self.query.is_joined(table): + self.query = self.query.left_join(table).on( + (table.parent == self.table.name) & (table.parenttype == self.doctype) + ) if isinstance(_value, (str, int)): _value = str(_value) @@ -586,19 +593,38 @@ class DynamicTableField: elif linked_field.fieldtype in frappe.model.table_fields: return ChildTableField(linked_doctype, fieldname, doctype, alias=alias) - def apply(self, query: QueryBuilder) -> QueryBuilder: + def apply_select(self, query: QueryBuilder) -> QueryBuilder: raise NotImplementedError class ChildTableField(DynamicTableField): - def apply(self, query: QueryBuilder) -> QueryBuilder: + def __init__( + self, + doctype: str, + fieldname: str, + parent_doctype: str, + alias: str | None = None, + ) -> None: + self.doctype = doctype + self.fieldname = fieldname + self.alias = alias + self.parent_doctype = parent_doctype + self.table = frappe.qb.DocType(self.doctype) + self.field = self.table[self.fieldname] + + def apply_select(self, query: QueryBuilder) -> QueryBuilder: + table = frappe.qb.DocType(self.doctype) + query = self.apply_join(query) + return query.select(getattr(table, self.fieldname).as_(self.alias or None)) + + def apply_join(self, query: QueryBuilder) -> QueryBuilder: table = frappe.qb.DocType(self.doctype) main_table = frappe.qb.DocType(self.parent_doctype) if not query.is_joined(table): query = query.left_join(table).on( (table.parent == main_table.name) & (table.parenttype == self.parent_doctype) ) - return query.select(getattr(table, self.fieldname).as_(self.alias or None)) + return query class LinkTableField(DynamicTableField): @@ -612,10 +638,17 @@ class LinkTableField(DynamicTableField): ) -> None: super().__init__(doctype, fieldname, parent_doctype, alias=alias) self.link_fieldname = link_fieldname + self.table = frappe.qb.DocType(self.doctype) + self.field = self.table[self.fieldname] - def apply(self, query: QueryBuilder) -> QueryBuilder: + def apply_select(self, query: QueryBuilder) -> QueryBuilder: + table = frappe.qb.DocType(self.doctype) + query = self.apply_join(query) + return query.select(getattr(table, self.fieldname).as_(self.alias or None)) + + def apply_join(self, query: QueryBuilder) -> QueryBuilder: table = frappe.qb.DocType(self.doctype) main_table = frappe.qb.DocType(self.parent_doctype) if not query.is_joined(table): query = query.left_join(table).on(table.name == getattr(main_table, self.link_fieldname)) - return query.select(getattr(table, self.fieldname).as_(self.alias or None)) + return query diff --git a/frappe/tests/test_query.py b/frappe/tests/test_query.py index 486bf9fe49..12cb6446d2 100644 --- a/frappe/tests/test_query.py +++ b/frappe/tests/test_query.py @@ -225,6 +225,39 @@ class TestQuery(FrappeTestCase): frappe.qb.from_("User").select(Max(Field("name"))).where(Ifnull("name", "") < Now()).run(), ) + self.assertEqual( + frappe.qb.get_query( + "DocType", + fields=["name"], + filters={"module.app_name": "frappe"}, + ).get_sql(), + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name`='frappe'".replace( + "`", '"' if frappe.db.db_type == "postgres" else "`" + ), + ) + + self.assertEqual( + frappe.qb.get_query( + "DocType", + fields=["name"], + filters={"module.app_name": ("like", "frap%")}, + ).get_sql(), + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module` WHERE `tabModule Def`.`app_name` LIKE 'frap%'".replace( + "`", '"' if frappe.db.db_type == "postgres" else "`" + ), + ) + + self.assertEqual( + frappe.qb.get_query( + "DocType", + fields=["name"], + filters={"permissions.role": "System Manager"}, + ).get_sql(), + "SELECT `tabDocType`.`name` FROM `tabDocType` LEFT JOIN `tabDocPerm` ON `tabDocPerm`.`parent`=`tabDocType`.`name` AND `tabDocPerm`.`parenttype`='DocType' WHERE `tabDocPerm`.`role`='System Manager'".replace( + "`", '"' if frappe.db.db_type == "postgres" else "`" + ), + ) + def test_implicit_join_query(self): self.maxDiff = None @@ -261,6 +294,16 @@ class TestQuery(FrappeTestCase): ), ) + self.assertEqual( + frappe.qb.get_query( + "DocType", + fields=["name", "module.app_name as app_name"], + ).get_sql(), + "SELECT `tabDocType`.`name`,`tabModule Def`.`app_name` `app_name` FROM `tabDocType` LEFT JOIN `tabModule Def` ON `tabModule Def`.`name`=`tabDocType`.`module`".replace( + "`", '"' if frappe.db.db_type == "postgres" else "`" + ), + ) + @run_only_if(db_type_is.MARIADB) def test_comment_stripping(self): self.assertNotIn(