feat(postgres): add unbuffered cursor in postgres (#35016)
* feat(postgres): add unbuffered cursor in postgres * test: add unbuffered_cursor test for Postgres
This commit is contained in:
parent
9db33f6f24
commit
48c8ee9a78
3 changed files with 60 additions and 0 deletions
|
|
@ -315,6 +315,12 @@ class Database:
|
|||
if auto_commit:
|
||||
self.commit()
|
||||
|
||||
if self.db_type == "postgres" and getattr(self._cursor, "name", None):
|
||||
"""named cursors in Postgres are lazy and don't retrieve column names immediately,
|
||||
so explicitly performed here to avoid early exit during `unbuffered_cursor` usage
|
||||
"""
|
||||
self._cursor.fetchmany(0)
|
||||
|
||||
if not self._cursor.description:
|
||||
return ()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from contextlib import contextmanager
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
|
|
@ -491,6 +492,23 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
|
|||
count = self.sql("select reltuples from pg_class where relname = %s", table)
|
||||
return cint(count[0][0]) if count else 0
|
||||
|
||||
@contextmanager
|
||||
def unbuffered_cursor(self):
|
||||
"""Unbuffered cursor in Postgres can only call .execute() once,
|
||||
usage:
|
||||
with frappe.db.unbuffered_cursor():
|
||||
frappe.db.sql()
|
||||
"""
|
||||
try:
|
||||
if not self._conn:
|
||||
self.connect()
|
||||
original_cursor = self._cursor
|
||||
new_cursor = self._cursor = self._conn.cursor(name="ss_cursor")
|
||||
yield
|
||||
finally:
|
||||
self._cursor = original_cursor
|
||||
new_cursor.close()
|
||||
|
||||
|
||||
def modify_query(query):
|
||||
""" "Modifies query according to the requirements of postgres"""
|
||||
|
|
|
|||
|
|
@ -1212,6 +1212,42 @@ class TestSqlIterator(IntegrationTestCase):
|
|||
with frappe.db.unbuffered_cursor():
|
||||
self.test_db_sql_iterator()
|
||||
|
||||
@run_only_if(db_type_is.POSTGRES)
|
||||
def test_unbuffered_cursor_postgres(self):
|
||||
test_queries = [
|
||||
"select * from `tabCountry` order by name",
|
||||
"select code from `tabCountry` order by name",
|
||||
"select code from `tabCountry` order by name limit 5",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
with frappe.db.unbuffered_cursor():
|
||||
iter_query_val = list(frappe.db.sql(query, as_dict=True, as_iterator=True))
|
||||
query_val = frappe.db.sql(query, as_dict=True)
|
||||
self.assertEqual(
|
||||
query_val,
|
||||
iter_query_val,
|
||||
msg=f"{query=} results not same as iterator",
|
||||
)
|
||||
|
||||
with frappe.db.unbuffered_cursor():
|
||||
iter_query_val = list(frappe.db.sql(query, pluck=True, as_iterator=True))
|
||||
query_val = frappe.db.sql(query, pluck=True)
|
||||
self.assertEqual(
|
||||
query_val,
|
||||
iter_query_val,
|
||||
msg=f"{query=} results not same as iterator",
|
||||
)
|
||||
|
||||
with frappe.db.unbuffered_cursor():
|
||||
iter_query_val = list(frappe.db.sql(query, as_list=True, as_iterator=True))
|
||||
query_val = frappe.db.sql(query, as_list=True)
|
||||
self.assertEqual(
|
||||
query_val,
|
||||
iter_query_val,
|
||||
msg=f"{query=} results not same as iterator",
|
||||
)
|
||||
|
||||
|
||||
class ExtIntegrationTestCase(IntegrationTestCase):
|
||||
def assertSqlException(self):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue