feat(postgres): add unbuffered cursor in postgres (#35016)

* feat(postgres): add unbuffered cursor in postgres

* test: add unbuffered_cursor test for Postgres
This commit is contained in:
Aarol D'Souza 2025-12-23 17:21:28 +05:30 committed by GitHub
parent 9db33f6f24
commit 48c8ee9a78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 60 additions and 0 deletions

View file

@ -315,6 +315,12 @@ class Database:
if auto_commit:
self.commit()
if self.db_type == "postgres" and getattr(self._cursor, "name", None):
"""named cursors in Postgres are lazy and don't retrieve column names immediately,
so explicitly performed here to avoid early exit during `unbuffered_cursor` usage
"""
self._cursor.fetchmany(0)
if not self._cursor.description:
return ()

View file

@ -1,4 +1,5 @@
import re
from contextlib import contextmanager
import psycopg2
import psycopg2.extensions
@ -491,6 +492,23 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
count = self.sql("select reltuples from pg_class where relname = %s", table)
return cint(count[0][0]) if count else 0
@contextmanager
def unbuffered_cursor(self):
"""Unbuffered cursor in Postgres can only call .execute() once,
usage:
with frappe.db.unbuffered_cursor():
frappe.db.sql()
"""
try:
if not self._conn:
self.connect()
original_cursor = self._cursor
new_cursor = self._cursor = self._conn.cursor(name="ss_cursor")
yield
finally:
self._cursor = original_cursor
new_cursor.close()
def modify_query(query):
""" "Modifies query according to the requirements of postgres"""

View file

@ -1212,6 +1212,42 @@ class TestSqlIterator(IntegrationTestCase):
with frappe.db.unbuffered_cursor():
self.test_db_sql_iterator()
@run_only_if(db_type_is.POSTGRES)
def test_unbuffered_cursor_postgres(self):
test_queries = [
"select * from `tabCountry` order by name",
"select code from `tabCountry` order by name",
"select code from `tabCountry` order by name limit 5",
]
for query in test_queries:
with frappe.db.unbuffered_cursor():
iter_query_val = list(frappe.db.sql(query, as_dict=True, as_iterator=True))
query_val = frappe.db.sql(query, as_dict=True)
self.assertEqual(
query_val,
iter_query_val,
msg=f"{query=} results not same as iterator",
)
with frappe.db.unbuffered_cursor():
iter_query_val = list(frappe.db.sql(query, pluck=True, as_iterator=True))
query_val = frappe.db.sql(query, pluck=True)
self.assertEqual(
query_val,
iter_query_val,
msg=f"{query=} results not same as iterator",
)
with frappe.db.unbuffered_cursor():
iter_query_val = list(frappe.db.sql(query, as_list=True, as_iterator=True))
query_val = frappe.db.sql(query, as_list=True)
self.assertEqual(
query_val,
iter_query_val,
msg=f"{query=} results not same as iterator",
)
class ExtIntegrationTestCase(IntegrationTestCase):
def assertSqlException(self):