seitime-frappe/frappe/query_builder/functions.py
Aarol D'Souza 59b440cb28
fix(search): make QB DB-Aware when using Locate (#35796)
* fix: make QB DB-Aware when choosing Locate

* fix(test): adjust test to check smarter qb choice based on db
2026-01-12 12:06:28 +05:30

184 lines
4.7 KiB
Python

from datetime import time
from enum import Enum
from pypika.functions import *
from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function
import frappe
from frappe.query_builder.custom import (
GROUP_CONCAT,
MATCH,
STRING_AGG,
TO_TSVECTOR,
Month,
MonthName,
Quarter,
)
from frappe.query_builder.utils import ImportMapper, db_type_is
from .utils import PseudoColumn
class Concat_ws(Function):
    """Query-builder wrapper that renders a ``CONCAT_WS(...)`` call."""

    def __init__(self, *terms, **kwargs):
        """Forward every positional term, in order, as an argument of ``CONCAT_WS``.

        Keyword arguments are handed to the base ``Function`` unchanged.
        """
        sql_name = "CONCAT_WS"
        super().__init__(sql_name, *terms, **kwargs)
class Locate(Function):
    """``LOCATE(needle, haystack)`` — the substring to find is the first SQL argument."""

    def __init__(self, needle, haystack, **kwargs):
        """Emit ``LOCATE`` with ``needle`` first and ``haystack`` second."""
        ordered_args = (needle, haystack)
        super().__init__("LOCATE", *ordered_args, **kwargs)
class Strpos(Function):
    """``STRPOS(haystack, needle)`` exposed with a ``(needle, haystack)`` constructor."""

    def __init__(self, needle, haystack, **kwargs):
        """Accept ``(needle, haystack)`` but emit them swapped, as ``STRPOS`` is called here."""
        # The SQL call receives the string being searched first.
        ordered_args = (haystack, needle)
        super().__init__("STRPOS", *ordered_args, **kwargs)
class Instr(Function):
    """``INSTR(haystack, needle)`` exposed with a ``(needle, haystack)`` constructor."""

    def __init__(self, needle, haystack, **kwargs):
        """Accept ``(needle, haystack)`` but emit them swapped, as ``INSTR`` is called here."""
        # The SQL call receives the string being searched first.
        ordered_args = (haystack, needle)
        super().__init__("INSTR", *ordered_args, **kwargs)
# Rebind ``Locate`` to a per-database mapping. The dict is evaluated before the
# assignment, so the MARIADB entry still refers to the ``Locate`` class above.
# All three implementations share the ``(needle, haystack)`` call signature.
Locate = ImportMapper(
    {
        db_type_is.MARIADB: Locate,
        db_type_is.POSTGRES: Strpos,
        db_type_is.SQLITE: Instr,
    }
)

# for backward compatibility
Ifnull = IfNull
class Timestamp(Function):
    """``TIMESTAMP(term)`` or ``TIMESTAMP(term, time)`` depending on whether a time is given."""

    def __init__(self, term: str, time=None, alias=None):
        """Build the call, adding ``time`` as a second argument only when it is truthy.

        Note: the ``time`` parameter shadows the ``datetime.time`` import inside
        this method.
        """
        call_args = (term, time) if time else (term,)
        super().__init__("TIMESTAMP", *call_args, alias=alias)
class Round(Function):
    """``ROUND(term, decimal)``; ``decimal`` defaults to ``0`` and is always emitted."""

    def __init__(self, term, decimal=0, **kwargs):
        """Pass ``term`` and ``decimal`` to ``ROUND``; extra keywords go to ``Function``."""
        call_args = (term, decimal)
        super().__init__("ROUND", *call_args, **kwargs)
class Truncate(Function):
    """``TRUNCATE(term, decimal)``; unlike ``Round``, ``decimal`` has no default."""

    def __init__(self, term, decimal, **kwargs):
        """Pass ``term`` and ``decimal`` to ``TRUNCATE``; extra keywords go to ``Function``."""
        call_args = (term, decimal)
        super().__init__("TRUNCATE", *call_args, **kwargs)
# Per-database choice of the string-aggregation function.
# NOTE(review): no SQLITE entry here, unlike ``Locate`` — confirm how
# ImportMapper treats a missing database type.
GroupConcat = ImportMapper(
    {
        db_type_is.MARIADB: GROUP_CONCAT,
        db_type_is.POSTGRES: STRING_AGG,
    }
)

# Per-database choice of the full-text matching helper.
Match = ImportMapper(
    {
        db_type_is.MARIADB: MATCH,
        db_type_is.POSTGRES: TO_TSVECTOR,
    }
)
class _PostgresTimestamp(ArithmeticExpression):
    """Postgres counterpart of ``TIMESTAMP(date, time)``, built as ``datepart + timepart``."""

    def __init__(self, datepart, timepart, alias=None):
        """Postgres would need both datepart and timepart to be a string for concatenation.

        Args:
            datepart: date side of the sum; a string is wrapped in ``CAST(... AS date)``.
            timepart: time side of the sum; a string is wrapped in ``CAST(... AS time)``.
            alias: optional alias forwarded to ``ArithmeticExpression``.

        When either argument is a ``datetime.time``, plain Python values on both
        sides are converted with ``str()`` so they get cast below. Query-builder
        terms (objects exposing ``get_sql``) are left untouched.
        """
        if isinstance(timepart, time) or isinstance(datepart, time):
            # Only stringify raw values. Calling str() on a query-builder term
            # yields its rendered SQL text, which would then be cast as a
            # string literal instead of being referenced as a column/expression.
            if not hasattr(datepart, "get_sql"):
                datepart = str(datepart)
            if not hasattr(timepart, "get_sql"):
                timepart = str(timepart)

        if isinstance(datepart, str):
            datepart = Cast(datepart, "date")
        if isinstance(timepart, str):
            timepart = Cast(timepart, "time")

        super().__init__(operator=Arithmetic.add, left=datepart, right=timepart, alias=alias)
# Combine a date and a time into one datetime expression.
# MariaDB uses a two-parameter TIMESTAMP custom function; Postgres uses the
# ``_PostgresTimestamp`` arithmetic expression.
CombineDatetime = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("TIMESTAMP", ["date", "time"]),
        db_type_is.POSTGRES: _PostgresTimestamp,
    }
)

# Format a date with a format string.
# MariaDB uses a two-parameter DATE_FORMAT custom function; Postgres uses ``ToChar``.
DateFormat = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("DATE_FORMAT", ["date", "format"]),
        db_type_is.POSTGRES: ToChar,
    }
)
class YearWeek(Function):
    """``YEARWEEK(term, 1)`` — the second argument is always the literal ``1``."""

    def __init__(self, term, **kwargs):
        """Build the ``YEARWEEK`` call for ``term``.

        Args:
            term: the date expression passed as the first argument.
            **kwargs: forwarded to ``Function`` (e.g. ``alias``), matching the
                other wrappers in this module. Omitting them behaves as before.
        """
        # The literal 1 is the function's second argument (the week "mode" in
        # MariaDB/MySQL's YEARWEEK).
        super().__init__("YEARWEEK", term, 1, **kwargs)
class _PostgresUnixTimestamp(Extract):
    """Special case of the ``Extract`` function with the ``"epoch"`` part hardcoded.

    Check the ``Extract`` definition to see how the SQL is rendered.
    """

    def __init__(self, field, alias=None):
        """Extract ``epoch`` from ``field``, with an optional ``alias``."""
        date_part = "epoch"
        super().__init__(date_part, field=field, alias=alias)
        # Kept from the original: store the field on the instance after the base init.
        self.field = field
# Convert a date expression to a unix timestamp.
# MariaDB uses a one-parameter ``unix_timestamp`` custom function; Postgres
# uses ``_PostgresUnixTimestamp``.
UnixTimestamp = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("unix_timestamp", ["date"]),
        db_type_is.POSTGRES: _PostgresUnixTimestamp,
    }
)
class Cast_(Function):
    """``CAST(value AS type)``, except a varchar cast on MariaDB becomes ``CONCAT(value, '')``."""

    def __init__(self, value, as_type, alias=None):
        """Choose between a real ``CAST`` and the MariaDB varchar workaround.

        Args:
            value: the expression being cast.
            as_type: target type; either an object with ``get_sql()`` or anything
                whose ``str()`` names the type.
            alias: optional alias forwarded to ``Function``.
        """
        emulate_varchar = False
        # Inspect the type only on MariaDB, so get_sql() is not called otherwise.
        if frappe.db.db_type == "mariadb":
            named_by_sql = hasattr(as_type, "get_sql") and as_type.get_sql().lower() == "varchar"
            emulate_varchar = named_by_sql or str(as_type).lower() == "varchar"

        if emulate_varchar:
            # mimics varchar cast in mariadb
            # as mariadb doesn't have varchar data cast
            # https://mariadb.com/kb/en/cast/#description
            # ref: https://stackoverflow.com/a/32542095
            super().__init__("CONCAT", value, "", alias=alias)
            return

        # from source: https://pypika.readthedocs.io/en/latest/_modules/pypika/functions.html#Cast
        super().__init__("CAST", value, alias=alias)
        self.as_type = as_type

    def get_special_params_sql(self, **kwargs):
        """Return the ``AS <type>`` suffix for a CAST call, or ``None`` for the CONCAT form."""
        if self.name.lower() != "cast":
            return None

        if hasattr(self.as_type, "get_sql"):
            type_sql = self.as_type.get_sql(**kwargs)
        else:
            type_sql = str(self.as_type).upper()
        return f"AS {type_sql}"
def _aggregate(function, dt, fieldname, filters, **kwargs):
    """Run a single-aggregate query and return its value.

    Args:
        function: callable applied to ``PseudoColumn(fieldname)`` to build the selected field.
        dt: first argument passed to ``frappe.qb.get_query``.
        fieldname: name wrapped in ``PseudoColumn``.
        filters: filters forwarded to ``frappe.qb.get_query``.
        **kwargs: forwarded to the query's ``run``.

    Returns:
        The first column of the first row, or ``0`` when that value is falsy.
    """
    aggregate_field = function(PseudoColumn(fieldname))
    query = frappe.qb.get_query(dt, filters=filters, fields=[aggregate_field])
    rows = query.run(**kwargs)
    first_value = rows[0][0]
    return first_value or 0
class SqlFunctions(Enum):
    """Names of SQL functions, each mapped to its lowercase string value."""

    # Date / time related
    DayOfYear = "dayofyear"
    Extract = "extract"

    # String search
    Locate = "locate"

    # Aggregates
    Count = "count"
    Sum = "sum"
    Avg = "avg"
    Max = "max"
    Min = "min"

    # Miscellaneous
    Abs = "abs"
    Timestamp = "timestamp"
    IfNull = "ifnull"
def _max(dt, fieldname, filters=None, **kwargs):
    """Apply ``Max`` to ``fieldname`` of ``dt`` through ``_aggregate``.

    ``filters`` and any extra keyword arguments are forwarded unchanged.
    """
    result = _aggregate(Max, dt, fieldname, filters, **kwargs)
    return result
def _min(dt, fieldname, filters=None, **kwargs):
    """Apply ``Min`` to ``fieldname`` of ``dt`` through ``_aggregate``.

    ``filters`` and any extra keyword arguments are forwarded unchanged.
    """
    result = _aggregate(Min, dt, fieldname, filters, **kwargs)
    return result
def _avg(dt, fieldname, filters=None, **kwargs):
    """Apply ``Avg`` to ``fieldname`` of ``dt`` through ``_aggregate``.

    ``filters`` and any extra keyword arguments are forwarded unchanged.
    """
    result = _aggregate(Avg, dt, fieldname, filters, **kwargs)
    return result
def _sum(dt, fieldname, filters=None, **kwargs):
    """Apply ``Sum`` to ``fieldname`` of ``dt`` through ``_aggregate``.

    ``filters`` and any extra keyword arguments are forwarded unchanged.
    """
    result = _aggregate(Sum, dt, fieldname, filters, **kwargs)
    return result