feat: joining on tables mentioned in fields

This commit is contained in:
Aradhya 2022-08-10 01:18:17 +05:30
parent 2ed0e1e648
commit f4eaa4a481
2 changed files with 30 additions and 3 deletions

View file

@ -1,8 +1,6 @@
import operator
import re
import sys
from ast import literal_eval
from fileinput import filename
from functools import cached_property
from types import BuiltinFunctionType
from typing import Any, Callable
@ -13,6 +11,8 @@ from frappe.model.db_query import get_timespan_date_range
from frappe.query_builder import Criterion, Field, Order, Table, functions
from frappe.query_builder.functions import Function, SqlFunctions
from frappe.query_builder.utils import PseudoColumn
from frappe.database.utils import table_from_string
from pypika import functions as f
TAB_PATTERN = re.compile("^tab")
WORDS_PATTERN = re.compile(r"\w+")
@ -503,7 +503,7 @@ class Engine:
return fields
def set_fields(self, table, fields, **kwargs):
def set_fields(self, table, fields, **kwargs) -> list:
fields = kwargs.get("pluck") if kwargs.get("pluck") else fields or "name"
if isinstance(fields, list) and None in fields and Field not in fields:
return None
@ -586,6 +586,23 @@ class Engine:
table, kwargs.get("field_objects") or fields, **kwargs
)
criterion = self.build_conditions(table, filters, **kwargs)
joined = False
for field in fields:
if "tab" in str(field):
join_on = table_from_string(str(field))
criterion = criterion.left_join(join_on).on(join_on.parent == getattr(frappe.qb.DocType(table), "name"))
joined = True
if joined:
# Converting all fields to avoid ambiguity.
for field in fields:
if not getattr(field, '__module__', None) == f.__name__:
field = field if isinstance(field, str) else field.get_sql()
fields[fields.index(field)] = getattr(frappe.qb.DocType(table), field)
else:
field.args = [getattr(frappe.qb.DocType(table), arg.get_sql()) for arg in field.args]
field.args[0] = getattr(frappe.qb.DocType(table), field.args[0].get_sql())
fields[fields.index(field)] = field
if self.linked_doctype and self.fieldname:
for field in fields:
if "tab" not in str(field):

View file

@ -3,10 +3,14 @@
from functools import cached_property
from types import NoneType
import typing
import frappe
from frappe.query_builder.builder import MariaDB, Postgres
if typing.TYPE_CHECKING:
from frappe.query_builder import DocType
Query = str | MariaDB | Postgres
QueryValues = tuple | list | dict | NoneType
@ -17,6 +21,12 @@ FallBackDateTimeStr = "0001-01-01 00:00:00.000000"
def is_query_type(query: str, query_type: str | tuple[str]) -> bool:
return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type)
def table_from_string(table: str) -> "DocType":
table_name = table.split("`", maxsplit=1)[1].split(".")[0][3:]
if "`" in table_name:
return frappe.qb.DocType(table_name=table_name.replace("`", ""))
else:
return frappe.qb.DocType(table_name=table_name)
class LazyString:
def _setup(self) -> None: