diff --git a/frappe/database/database.py b/frappe/database/database.py index 58f48f0477..6fb7e427d8 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -315,6 +315,12 @@ class Database: if auto_commit: self.commit() + if self.db_type == "postgres" and getattr(self._cursor, "name", None): + """named cursors in Postgres are lazy and don't retrieve column names immediately, + so explicitly performed here to avoid early exit during `unbuffered_cursor` usage + """ + self._cursor.fetchmany(0) + if not self._cursor.description: return () diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index df0ec7dee3..68662aa0e5 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -1,4 +1,5 @@ import re +from contextlib import contextmanager import psycopg2 import psycopg2.extensions @@ -491,6 +492,23 @@ class PostgresDatabase(PostgresExceptionUtil, Database): count = self.sql("select reltuples from pg_class where relname = %s", table) return cint(count[0][0]) if count else 0 + @contextmanager + def unbuffered_cursor(self): + """Unbuffered cursor in Postgres can only call .execute() once, + usage: + with frappe.db.unbuffered_cursor(): + frappe.db.sql() + """ + try: + if not self._conn: + self.connect() + original_cursor = self._cursor + new_cursor = self._cursor = self._conn.cursor(name="ss_cursor") + yield + finally: + self._cursor = original_cursor + new_cursor.close() + def modify_query(query): """ "Modifies query according to the requirements of postgres""" diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 14ece934d9..9b95f6ec2a 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1212,6 +1212,42 @@ class TestSqlIterator(IntegrationTestCase): with frappe.db.unbuffered_cursor(): self.test_db_sql_iterator() + @run_only_if(db_type_is.POSTGRES) + def test_unbuffered_cursor_postgres(self): + test_queries = [ + "select * from `tabCountry` order by name", + "select code from `tabCountry` order by name", + "select code from `tabCountry` order by name limit 5", + ] + + for query in test_queries: + with frappe.db.unbuffered_cursor(): + iter_query_val = list(frappe.db.sql(query, as_dict=True, as_iterator=True)) + query_val = frappe.db.sql(query, as_dict=True) + self.assertEqual( + query_val, + iter_query_val, + msg=f"{query=} results not same as iterator", + ) + + with frappe.db.unbuffered_cursor(): + iter_query_val = list(frappe.db.sql(query, pluck=True, as_iterator=True)) + query_val = frappe.db.sql(query, pluck=True) + self.assertEqual( + query_val, + iter_query_val, + msg=f"{query=} results not same as iterator", + ) + + with frappe.db.unbuffered_cursor(): + iter_query_val = list(frappe.db.sql(query, as_list=True, as_iterator=True)) + query_val = frappe.db.sql(query, as_list=True) + self.assertEqual( + query_val, + iter_query_val, + msg=f"{query=} results not same as iterator", + ) + class ExtIntegrationTestCase(IntegrationTestCase): def assertSqlException(self):