diff --git a/frappe/query_builder/functions.py b/frappe/query_builder/functions.py index 98cd501be2..1776abb029 100644 --- a/frappe/query_builder/functions.py +++ b/frappe/query_builder/functions.py @@ -2,7 +2,7 @@ from datetime import time from enum import Enum from pypika.functions import * -from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function +from pypika.terms import Arithmetic, ArithmeticExpression, CustomFunction, Function, Term import frappe from frappe.query_builder.custom import ( @@ -118,6 +118,47 @@ UnixTimestamp = ImportMapper( ) +class _MariaDBJSONExtract(Function): + def __init__(self, field, path, **kwargs): + super().__init__("JSON_EXTRACT", field, path, **kwargs) + + +class _MariaDBJSONValue(Function): + def __init__(self, field, path, **kwargs): + super().__init__("JSON_UNQUOTE", _MariaDBJSONExtract(field, path), **kwargs) + + +class _MariaDBJSONContains(Function): + def __init__(self, target, candidate, **kwargs): + from pypika.terms import JSON + + if not isinstance(candidate, Term): + candidate = JSON(candidate) + super().__init__("JSON_CONTAINS", target, candidate, **kwargs) + + +JSONExtract = ImportMapper( + { + db_type_is.MARIADB: _MariaDBJSONExtract, + db_type_is.POSTGRES: lambda field, path, **kw: field.get_json_value(path), + } +) + +JSONValue = ImportMapper( + { + db_type_is.MARIADB: _MariaDBJSONValue, + db_type_is.POSTGRES: lambda field, path, **kw: field.get_text_value(path), + } +) + +JSONContains = ImportMapper( + { + db_type_is.MARIADB: _MariaDBJSONContains, + db_type_is.POSTGRES: lambda target, candidate, **kw: target.contains(candidate), + } +) + + class Cast_(Function): def __init__(self, value, as_type, alias=None): if frappe.db.db_type == "mariadb" and ( diff --git a/frappe/tests/test_query_builder.py b/frappe/tests/test_query_builder.py index f0df927238..0d7bcae136 100644 --- a/frappe/tests/test_query_builder.py +++ b/frappe/tests/test_query_builder.py @@ -14,6 +14,9 @@ from frappe.query_builder.functions import ( CombineDatetime, Date, GroupConcat, + JSONContains, + JSONExtract, + JSONValue, Match, Round, Truncate, @@ -176,6 +179,43 @@ class TestCustomFunctionsMariaDB(IntegrationTestCase): query = frappe.qb.from_(note).select(Truncate(note.price, 3)) self.assertEqual("select truncate(`price`,3) from `tabnote`", str(query).lower()) + def test_json_extract(self): + note = frappe.qb.DocType("Note") + # Simple get_sql + self.assertEqual("JSON_EXTRACT(content,'$.key')", JSONExtract(note.content, "$.key").get_sql()) + + # In a SELECT query + query = frappe.qb.from_(note).select(JSONExtract(note.content, "$.key")) + self.assertIn("json_extract(`content`,'$.key')", str(query).lower()) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONExtract(note.content, "$.key") == "value") + self.assertIn("json_extract(`content`,'$.key')='value'", str(query).lower()) + + def test_json_value(self): + note = frappe.qb.DocType("Note") + # Simple get_sql + self.assertEqual( + "JSON_UNQUOTE(JSON_EXTRACT(content,'$.key'))", JSONValue(note.content, "$.key").get_sql() + ) + + # In a SELECT query + query = frappe.qb.from_(note).select(JSONValue(note.content, "$.key")) + self.assertIn("json_unquote(json_extract(`content`,'$.key'))", str(query).lower()) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONValue(note.content, "$.key") == "value") + self.assertIn("json_unquote(json_extract(`content`,'$.key'))='value'", str(query).lower()) + + def test_json_contains(self): + note = frappe.qb.DocType("Note") + # With a plain string candidate (auto-wrapped as JSON) + self.assertEqual("JSON_CONTAINS(content,'\"value\"')", JSONContains(note.content, "value").get_sql()) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONContains(note.content, "admin")) + self.assertIn("json_contains(`content`,'\"admin\"')", str(query).lower()) + @run_only_if(db_type_is.POSTGRES) class TestCustomFunctionsPostgres(IntegrationTestCase): @@ -313,6 +353,41 @@ class TestCustomFunctionsPostgres(IntegrationTestCase): query = frappe.qb.from_(note).select(Truncate(note.price, 3)) self.assertEqual('select truncate("price",3) from "tabnote"', str(query).lower()) + def test_json_extract(self): + note = frappe.qb.DocType("Note") + # Simple get_sql + self.assertEqual("\"content\"->'$.key'", JSONExtract(note.content, "$.key").get_sql()) + + # In a SELECT query + query = frappe.qb.from_(note).select(JSONExtract(note.content, "$.key")) + self.assertIn("\"content\"->'$.key'", str(query)) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONExtract(note.content, "$.key") == "value") + self.assertIn("\"content\"->'$.key'='value'", str(query)) + + def test_json_value(self): + note = frappe.qb.DocType("Note") + # Simple get_sql + self.assertEqual("\"content\"->>'$.key'", JSONValue(note.content, "$.key").get_sql()) + + # In a SELECT query + query = frappe.qb.from_(note).select(JSONValue(note.content, "$.key")) + self.assertIn("\"content\"->>'$.key'", str(query)) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONValue(note.content, "$.key") == "value") + self.assertIn("\"content\"->>'$.key'='value'", str(query)) + + def test_json_contains(self): + note = frappe.qb.DocType("Note") + # With a plain string candidate + self.assertEqual("\"content\"@>'admin'", JSONContains(note.content, "admin").get_sql()) + + # In a WHERE clause + query = frappe.qb.from_(note).select(note.name).where(JSONContains(note.content, "admin")) + self.assertIn("\"content\"@>'admin'", str(query)) + class TestBuilderBase: def test_adding_tabs(self):