From abba28be3b1b2fdb2010dcdc07ecc6556d32a71b Mon Sep 17 00:00:00 2001 From: Philipp Gruener Date: Fri, 5 Jul 2024 12:02:04 +0200 Subject: [PATCH] feat: Added env db options for db, password and pg_schema --- frappe/__init__.py | 31 +++++++++------- frappe/tests/test_db.py | 81 +++++++++++++++++++++++++++++++---------- 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/frappe/__init__.py b/frappe/__init__.py index 72616ee4c9..e911d538e7 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -339,23 +339,17 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us "Instead, explicitly invoke frappe.init(site) with the right config prior to calling frappe.connect(), if necessary." ) - db_name = db_name or os.getenv("DB_NAME") or local.conf.db_name or local.conf.db_user - user = os.getenv("DB_USER") or local.conf.db_user or db_name - host = os.getenv("DB_HOST") or local.conf.db_host - port = os.getenv("DB_PORT") or local.conf.db_port - password = os.getenv("DB_PASSWORD") or local.conf.db_password - - assert user, "site must be fully initialized, db_user missing" - assert db_name, "site must be fully initialized, db_name missing" - assert password, "site must be fully initialized, db_password missing" + assert db_name or local.conf.db_user, "site must be fully initialized, db_user missing" + assert db_name or local.conf.db_name, "site must be fully initialized, db_name missing" + assert local.conf.db_password, "site must be fully initialized, db_password missing" local.db = get_db( socket=local.conf.db_socket, - host=host, - port=port, - user=user, - password=password, - cur_db_name=db_name, + host=local.conf.db_host, + port=local.conf.db_port, + user=local.conf.db_user or db_name, + password=local.conf.db_password, + cur_db_name=local.conf.db_name or db_name, ) if set_admin_as_user: set_user("Administrator") @@ -448,6 +442,15 @@ def get_site_config(sites_path: str | None = None, site_path: str | None = None) # Set the user as database name if not set in config config["db_user"] = os.environ.get("FRAPPE_DB_USER") or config.get("db_user") or config.get("db_name") + # vice versa for dbname if not defined + config["db_name"] = os.environ.get("FRAPPE_DB_NAME") or config.get("db_name") or config["db_user"] + + # read password + config["db_password"] = os.environ.get("FRAPPE_DB_PASSWORD") or config.get("db_password") + + if config["db_type"] == "postgres": + config["db_schema"] = os.environ.get("FRAPPE_DB_PG_SCHEMA") or config.get("db_schema") + # Allow externally extending the config with hooks if extra_config := config.get("extra_config"): if isinstance(extra_config, str): diff --git a/frappe/tests/test_db.py b/frappe/tests/test_db.py index 01198ecb38..6ba0bb5ce2 100644 --- a/frappe/tests/test_db.py +++ b/frappe/tests/test_db.py @@ -1083,38 +1083,81 @@ class TestSqlIterator(FrappeTestCase): class TestDbConnectWithEnvCredentials(FrappeTestCase): def test_connect_fails_with_wrong_credentials_by_env(self) -> None: + import contextlib import os - # with wrong db name - os.environ["DB_NAME"] = "dbiq" + @contextlib.contextmanager + def set_env_variable(key, value): + os.environ[key] = value + try: + yield + finally: + del os.environ[key] - frappe.connect() - self.assertRaises(frappe.db.OperationalError, frappe.db.connect) + current_site = frappe.local.site + + # with wrong db name + with set_env_variable("FRAPPE_DB_NAME", "dbiq"): + frappe.init(current_site, force=True) + frappe.connect() + + with self.assertRaises(Exception) as cm: + frappe.db.connect() + + self.assertTrue('database "dbiq"' in str(cm.exception)) # with wrong host - del os.environ["DB_NAME"] - os.environ["DB_HOST"] = "iqx.local" + with set_env_variable("FRAPPE_DB_HOST", "iqx.local"): + frappe.init(current_site, force=True) + frappe.connect() - frappe.connect() - self.assertRaises(frappe.db.OperationalError, frappe.db.connect) + with self.assertRaises(Exception) as cm: + frappe.db.connect() + + self.assertTrue('host name "iqx.local"' in str(cm.exception)) # with wrong user name - del os.environ["DB_HOST"] - os.environ["DB_USER"] = "uname" + with set_env_variable("FRAPPE_DB_USER", "uname"): + frappe.init(current_site, force=True) + frappe.connect() - frappe.connect() - self.assertRaises(frappe.db.OperationalError, frappe.db.connect) + with self.assertRaises(Exception) as cm: + frappe.db.connect() + + self.assertTrue('user "uname"' in str(cm.exception)) # with wrong password - del os.environ["DB_USER"] - os.environ["DB_PASSWORD"] = "pass" + with set_env_variable("FRAPPE_DB_PASSWORD", "pass"): + frappe.init(current_site, force=True) + frappe.connect() - frappe.connect() - self.assertRaises(frappe.db.OperationalError, frappe.db.connect) + with self.assertRaises(Exception) as cm: + frappe.db.connect() + + self.assertTrue("password authentication failed" in str(cm.exception)) + + # with wrong password + with set_env_variable("FRAPPE_DB_PORT", "1111"): + frappe.init(current_site, force=True) + frappe.connect() + + with self.assertRaises(Exception) as cm: + frappe.db.connect() + + self.assertTrue("port 1111" in str(cm.exception)) + + # with wrong postgres schema + with set_env_variable("FRAPPE_DB_PG_SCHEMA", "pg_schema"): + frappe.init(current_site, force=True) + frappe.connect() + + if frappe.conf.get("db_type") == db_type_is.POSTGRES.value: + self.assertEqual(frappe.conf.get("db_schema"), "pg_schema") + else: + # for mariadb this env should not have any effect + self.assertIsNone(frappe.conf.get("db_schema")) # now with configured settings without any influences from env - del os.environ["DB_PASSWORD"] - # finally connect should work without any error (when no wrong credentials are given via ENV) - frappe.connect() + frappe.init(current_site, force=True) frappe.db.connect()