* fix: make QB DB-Aware when choosing Locate * fix(test): adjust test to check smarter qb choice based on db
184 lines
4.7 KiB
Python
184 lines
4.7 KiB
Python
from datetime import time
|
|
from enum import Enum
|
|
|
|
from pypika.functions import *
|
|
from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function
|
|
|
|
import frappe
|
|
from frappe.query_builder.custom import (
|
|
GROUP_CONCAT,
|
|
MATCH,
|
|
STRING_AGG,
|
|
TO_TSVECTOR,
|
|
Month,
|
|
MonthName,
|
|
Quarter,
|
|
)
|
|
from frappe.query_builder.utils import ImportMapper, db_type_is
|
|
|
|
from .utils import PseudoColumn
|
|
|
|
|
|
class Concat_ws(Function):
    """Query-builder term that renders a ``CONCAT_WS(...)`` SQL call."""

    def __init__(self, *terms, **kwargs):
        # Every positional term is forwarded unchanged as a call argument;
        # keyword options (e.g. ``alias``) go straight to ``Function``.
        sql_name = "CONCAT_WS"
        super().__init__(sql_name, *terms, **kwargs)
|
|
|
|
|
|
class Locate(Function):
    """``LOCATE(needle, haystack)`` call; arguments are emitted needle-first."""

    def __init__(self, needle, haystack, **kwargs):
        # LOCATE takes the search string first, so the caller order is kept.
        call_args = (needle, haystack)
        super().__init__("LOCATE", *call_args, **kwargs)
|
|
|
|
|
|
class Strpos(Function):
    """``STRPOS(haystack, needle)`` call accepting arguments in (needle, haystack) order."""

    def __init__(self, needle, haystack, **kwargs):
        # The signature mirrors ``Locate``; the SQL call wants the haystack
        # first, so the two arguments are swapped when emitted.
        call_args = (haystack, needle)
        super().__init__("STRPOS", *call_args, **kwargs)
|
|
|
|
|
|
class Instr(Function):
    """``INSTR(haystack, needle)`` call accepting arguments in (needle, haystack) order."""

    def __init__(self, needle, haystack, **kwargs):
        # The signature mirrors ``Locate``; the SQL call wants the haystack
        # first, so the two arguments are swapped when emitted.
        call_args = (haystack, needle)
        super().__init__("INSTR", *call_args, **kwargs)
|
|
|
|
|
|
# Rebind ``Locate`` to a mapping keyed by database type, so the public name
# covers the ``Locate`` / ``Strpos`` / ``Instr`` wrappers, all of which take
# (needle, haystack).
Locate = ImportMapper(
    {
        db_type_is.MARIADB: Locate,
        db_type_is.POSTGRES: Strpos,
        db_type_is.SQLITE: Instr,
    }
)
|
|
|
|
|
|
# for backward compatibility: keeps the older ``Ifnull`` spelling available
# as an alias of ``IfNull``
Ifnull = IfNull
|
|
|
|
|
|
class Timestamp(Function):
    """``TIMESTAMP(term)`` or ``TIMESTAMP(term, time)`` call."""

    def __init__(self, term: str, time=None, alias=None):
        # A falsy ``time`` (the default ``None`` included) yields the
        # single-argument form; otherwise ``time`` becomes the second argument.
        call_args = (term, time) if time else (term,)
        super().__init__("TIMESTAMP", *call_args, alias=alias)
|
|
|
|
|
|
class Round(Function):
    """``ROUND(term, decimal)`` call; ``decimal`` defaults to ``0``."""

    def __init__(self, term, decimal=0, **kwargs):
        # ``decimal`` is always emitted, even when left at its default.
        call_args = (term, decimal)
        super().__init__("ROUND", *call_args, **kwargs)
|
|
|
|
|
|
class Truncate(Function):
    """``TRUNCATE(term, decimal)`` call; both arguments are required."""

    def __init__(self, term, decimal, **kwargs):
        call_args = (term, decimal)
        super().__init__("TRUNCATE", *call_args, **kwargs)
|
|
|
|
|
|
# Grouped string aggregation, keyed by database type.
GroupConcat = ImportMapper(
    {
        db_type_is.MARIADB: GROUP_CONCAT,
        db_type_is.POSTGRES: STRING_AGG,
    }
)
|
|
|
|
# Text-matching helper, keyed by database type.
Match = ImportMapper(
    {
        db_type_is.MARIADB: MATCH,
        db_type_is.POSTGRES: TO_TSVECTOR,
    }
)
|
|
|
|
|
|
class _PostgresTimestamp(ArithmeticExpression):
    """Arithmetic expression ``datepart + timepart`` used as the Postgres datetime combiner."""

    def __init__(self, datepart, timepart, alias=None):
        """Build the addition, casting string operands to ``date`` / ``time``.

        If either operand is a ``datetime.time`` instance, both operands are
        converted with ``str()`` first. Any operand that is a string at that
        point is wrapped in ``Cast(..., "date")`` or ``Cast(..., "time")``;
        other operands are passed through unchanged.
        """
        if any(isinstance(part, time) for part in (timepart, datepart)):
            # NOTE(review): both operands are stringified even when only one
            # is a ``time`` instance, so a non-string date operand would be
            # turned into its ``str()`` form here — confirm this is intended.
            timepart = str(timepart)
            datepart = str(datepart)

        left_operand = Cast(datepart, "date") if isinstance(datepart, str) else datepart
        right_operand = Cast(timepart, "time") if isinstance(timepart, str) else timepart

        super().__init__(
            operator=Arithmetic.add,
            left=left_operand,
            right=right_operand,
            alias=alias,
        )
|
|
|
|
|
|
# Combines a date part and a time part into one datetime expression.
# Keyed by database type: MariaDB uses a two-argument ``TIMESTAMP(date, time)``
# custom function, Postgres uses the ``_PostgresTimestamp`` addition expression.
CombineDatetime = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("TIMESTAMP", ["date", "time"]),
        db_type_is.POSTGRES: _PostgresTimestamp,
    }
)
|
|
|
|
# Date formatting, keyed by database type: MariaDB uses a two-argument
# ``DATE_FORMAT(date, format)`` custom function, Postgres uses ``ToChar``.
# NOTE(review): the format-string syntax is presumably dialect-specific, so
# callers may need to pass a format matching the active database — confirm.
DateFormat = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("DATE_FORMAT", ["date", "format"]),
        db_type_is.POSTGRES: ToChar,
    }
)
|
|
|
|
|
|
class YearWeek(Function):
    """``YEARWEEK(term, 1)`` call.

    The second (mode) argument is fixed to ``1``; see the database's
    ``YEARWEEK`` documentation for what each mode means.

    Args:
        term: The date expression passed as the first argument.
        **kwargs: Optional keyword options forwarded to ``Function``
            (for example ``alias``), matching the sibling wrappers such as
            ``Round`` and ``Truncate``. Omitting them keeps the previous
            behaviour.
    """

    def __init__(self, term, **kwargs):
        super().__init__("YEARWEEK", term, 1, **kwargs)
|
|
|
|
|
|
class _PostgresUnixTimestamp(Extract):
    """``Extract`` specialised to the ``epoch`` date part."""

    # Note: this is just a special case of the "Extract" function with "epoch"
    # hardcoded. Check the super definition to see how it works.
    def __init__(self, field, alias=None):
        date_part = "epoch"
        super().__init__(date_part, field=field, alias=alias)
        # Stored again on the instance after the parent initialiser has run.
        self.field = field
|
|
|
|
|
|
# Unix timestamp, keyed by database type: MariaDB uses a one-argument
# ``unix_timestamp(date)`` custom function, Postgres uses
# ``_PostgresUnixTimestamp`` (an ``Extract`` with "epoch" hardcoded).
UnixTimestamp = ImportMapper(
    {
        db_type_is.MARIADB: CustomFunction("unix_timestamp", ["date"]),
        db_type_is.POSTGRES: _PostgresUnixTimestamp,
    }
)
|
|
|
|
|
|
class Cast_(Function):
    """``CAST(value AS type)`` call, with a ``CONCAT`` stand-in for varchar casts on MariaDB."""

    def __init__(self, value, as_type, alias=None):
        """Create the call.

        Args:
            value: Expression to cast.
            as_type: Target type; either an object exposing ``get_sql()`` or
                anything whose ``str()`` names the type.
            alias: Optional alias forwarded to ``Function``.
        """
        # The ``and`` short-circuit keeps the type inspection from running
        # unless the active database is MariaDB.
        emulate_varchar = frappe.db.db_type == "mariadb" and (
            (hasattr(as_type, "get_sql") and as_type.get_sql().lower() == "varchar")
            or str(as_type).lower() == "varchar"
        )

        if not emulate_varchar:
            # from source: https://pypika.readthedocs.io/en/latest/_modules/pypika/functions.html#Cast
            super().__init__("CAST", value, alias=alias)
            self.as_type = as_type
            return

        # mimics varchar cast in mariadb
        # as mariadb doesn't have varchar data cast
        # https://mariadb.com/kb/en/cast/#description

        # ref: https://stackoverflow.com/a/32542095
        super().__init__("CONCAT", value, "", alias=alias)

    def get_special_params_sql(self, **kwargs):
        """Return the ``AS <type>`` suffix for CAST calls, or ``None`` for the CONCAT stand-in."""
        if self.name.lower() != "cast":
            return None

        if hasattr(self.as_type, "get_sql"):
            type_sql = self.as_type.get_sql(**kwargs)
        else:
            # Plain type names are upper-cased for the emitted SQL.
            type_sql = str(self.as_type).upper()
        return f"AS {type_sql}"
|
|
|
|
|
|
def _aggregate(function, dt, fieldname, filters, **kwargs):
    """Run a single-aggregate query and return its scalar result.

    ``function`` is applied to ``PseudoColumn(fieldname)`` and used as the
    only selected field of a ``frappe.qb.get_query`` query on ``dt`` with
    ``filters``. ``kwargs`` are forwarded to the query's ``run``.

    Returns:
        The first column of the first result row, or ``0`` when that value
        is falsy.
    """
    aggregate_term = function(PseudoColumn(fieldname))
    query = frappe.qb.get_query(dt, filters=filters, fields=[aggregate_term])
    result_rows = query.run(**kwargs)
    first_value = result_rows[0][0]
    return first_value or 0
|
|
|
|
|
|
class SqlFunctions(Enum):
    """Enumeration of SQL function names; each member's value is the lower-case name.

    NOTE(review): presumably used to recognise or whitelist function names
    elsewhere in the query builder — confirm against callers.
    """

    DayOfYear = "dayofyear"
    Extract = "extract"
    Locate = "locate"
    Count = "count"
    Sum = "sum"
    Avg = "avg"
    Max = "max"
    Min = "min"
    Abs = "abs"
    Timestamp = "timestamp"
    IfNull = "ifnull"
|
|
|
|
|
|
def _max(dt, fieldname, filters=None, **kwargs):
    """Return the ``Max`` aggregate of ``fieldname`` on ``dt`` via ``_aggregate``."""
    return _aggregate(
        function=Max,
        dt=dt,
        fieldname=fieldname,
        filters=filters,
        **kwargs,
    )
|
|
|
|
|
|
def _min(dt, fieldname, filters=None, **kwargs):
    """Return the ``Min`` aggregate of ``fieldname`` on ``dt`` via ``_aggregate``."""
    return _aggregate(
        function=Min,
        dt=dt,
        fieldname=fieldname,
        filters=filters,
        **kwargs,
    )
|
|
|
|
|
|
def _avg(dt, fieldname, filters=None, **kwargs):
    """Return the ``Avg`` aggregate of ``fieldname`` on ``dt`` via ``_aggregate``."""
    return _aggregate(
        function=Avg,
        dt=dt,
        fieldname=fieldname,
        filters=filters,
        **kwargs,
    )
|
|
|
|
|
|
def _sum(dt, fieldname, filters=None, **kwargs):
    """Return the ``Sum`` aggregate of ``fieldname`` on ``dt`` via ``_aggregate``."""
    return _aggregate(
        function=Sum,
        dt=dt,
        fieldname=fieldname,
        filters=filters,
        **kwargs,
    )
|