diff --git a/frappe/__init__.py b/frappe/__init__.py index 81c552078a..77b840412e 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -343,6 +343,7 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us assert local.conf.db_password, "site must be fully initialized, db_password missing" local.db = get_db( + socket=local.conf.db_socket, host=local.conf.db_host, port=local.conf.db_port, user=local.conf.db_user or db_name, @@ -368,6 +369,7 @@ def connect_replica() -> bool: password = local.conf.replica_db_password local.replica_db = get_db( + socket=None, host=local.conf.replica_host, port=port, user=user, @@ -430,6 +432,7 @@ def get_site_config(sites_path: str | None = None, site_path: str | None = None) os.environ.get("FRAPPE_REDIS_CACHE") or config.get("redis_cache") or "redis://127.0.0.1:13311" ) config["db_type"] = os.environ.get("FRAPPE_DB_TYPE") or config.get("db_type") or "mariadb" + config["db_socket"] = os.environ.get("FRAPPE_DB_SOCKET") or config.get("db_socket") config["db_host"] = os.environ.get("FRAPPE_DB_HOST") or config.get("db_host") or "127.0.0.1" config["db_port"] = ( os.environ.get("FRAPPE_DB_PORT") or config.get("db_port") or db_default_ports(config["db_type"]) diff --git a/frappe/commands/site.py b/frappe/commands/site.py index 24995f7599..d24fe3e68f 100644 --- a/frappe/commands/site.py +++ b/frappe/commands/site.py @@ -33,10 +33,10 @@ from frappe.utils import CallbackManager ) @click.option("--db-root-password", "--mariadb-root-password", help="Root password for MariaDB or PostgreSQL") @click.option( - "--no-mariadb-socket", - is_flag=True, - default=False, - help="Set MariaDB host to % and use TCP/IP Socket instead of using the UNIX Socket", + "--db-socket", + "--mariadb-db-socket", + envvar="MYSQL_UNIX_PORT", + help="Database socket for MariaDB or folder containing database socket for PostgreSQL", ) @click.option( "--no-mariadb-socket", @@ -78,6 +78,7 @@ def new_site( db_name=None, db_password=None, db_type=None, + db_socket=None, db_host=None, db_port=None, db_user=None, @@ -115,6 +116,7 @@ def new_site( force=force, db_password=db_password, db_type=db_type, + db_socket=db_socket, db_host=db_host, db_port=db_port, db_user=db_user, diff --git a/frappe/commands/utils.py b/frappe/commands/utils.py index 5a61f00140..6f783bb7ef 100644 --- a/frappe/commands/utils.py +++ b/frappe/commands/utils.py @@ -526,6 +526,7 @@ def _enter_console(extra_args=None): os.environ["PSQL_HISTORY"] = os.path.abspath(get_site_path("logs", "postgresql_console.log")) bin, args, bin_name = get_command( + socket=frappe.conf.db_socket, host=frappe.conf.db_host, port=frappe.conf.db_port, user=frappe.conf.db_user, diff --git a/frappe/database/__init__.py b/frappe/database/__init__.py index d8435fd1de..fd59452c97 100644 --- a/frappe/database/__init__.py +++ b/frappe/database/__init__.py @@ -47,20 +47,26 @@ def drop_user_and_database(db_name, db_user): return frappe.database.mariadb.setup_db.drop_user_and_database(db_name, db_user) -def get_db(host=None, user=None, password=None, port=None, cur_db_name=None): +def get_db(socket=None, host=None, user=None, password=None, port=None, cur_db_name=None): import frappe if frappe.conf.db_type == "postgres": import frappe.database.postgres.database - return frappe.database.postgres.database.PostgresDatabase(host, user, password, port, cur_db_name) + return frappe.database.postgres.database.PostgresDatabase( + socket, host, user, password, port, cur_db_name + ) else: import frappe.database.mariadb.database - return frappe.database.mariadb.database.MariaDBDatabase(host, user, password, port, cur_db_name) + return frappe.database.mariadb.database.MariaDBDatabase( + socket, host, user, password, port, cur_db_name + ) -def get_command(host=None, port=None, user=None, password=None, db_name=None, extra=None, dump=False): +def get_command( + socket=None, host=None, port=None, user=None, password=None, db_name=None, extra=None, dump=False +): import frappe if frappe.conf.db_type == "postgres": @@ -69,7 +75,11 @@ def get_command(host=None, port=None, user=None, password=None, db_name=None, ex else: bin, bin_name = which("psql"), "psql" - if password: + if socket and password: + conn_string = f"postgresql://{user}:{password}@/{db_name}?host={socket}" + elif socket: + conn_string = f"postgresql://{user}@/{db_name}?host={socket}" + elif password: conn_string = f"postgresql://{user}:{password}@{host}:{port}/{db_name}" else: conn_string = f"postgresql://{user}@{host}:{port}/{db_name}" @@ -85,11 +95,12 @@ def get_command(host=None, port=None, user=None, password=None, db_name=None, ex else: bin, bin_name = which("mariadb") or which("mysql"), "mariadb" - command = [ - f"--user={user}", - f"--host={host}", - f"--port={port}", - ] + command = [f"--user={user}"] + if socket: + command.append(f"--socket={socket}") + elif host and port: + command.append(f"--host={host}") + command.append(f"--port={port}") if password: command.append(f"--password={password}") diff --git a/frappe/database/database.py b/frappe/database/database.py index 0c7fe76a39..8417be0554 100644 --- a/frappe/database/database.py +++ b/frappe/database/database.py @@ -75,6 +75,7 @@ class Database: def __init__( self, + socket=None, host=None, user=None, password=None, @@ -82,6 +83,7 @@ class Database: cur_db_name=None, ): self.setup_type_map() + self.socket = socket self.host = host self.port = port self.user = user diff --git a/frappe/database/db_manager.py b/frappe/database/db_manager.py index 952a30c0fd..558550d1ac 100644 --- a/frappe/database/db_manager.py +++ b/frappe/database/db_manager.py @@ -70,6 +70,7 @@ class DbManager: source = ["<", source] bin, args, bin_name = get_command( + socket=frappe.conf.db_socket, host=frappe.conf.db_host, port=frappe.conf.db_port, user=user, diff --git a/frappe/database/mariadb/database.py b/frappe/database/mariadb/database.py index 4ac1c092c9..f70a4292ba 100644 --- a/frappe/database/mariadb/database.py +++ b/frappe/database/mariadb/database.py @@ -116,9 +116,7 @@ class MariaDBConnectionUtil: def get_connection_settings(self) -> dict: conn_settings = { - "host": self.host, "user": self.user, - "password": self.password, "conv": self.CONVERSION_MAP, "charset": "utf8mb4", "use_unicode": True, @@ -127,8 +125,15 @@ class MariaDBConnectionUtil: if self.cur_db_name: conn_settings["database"] = self.cur_db_name - if self.port: - conn_settings["port"] = int(self.port) + if self.socket: + conn_settings["unix_socket"] = self.socket + else: + conn_settings["host"] = self.host + if self.port: + conn_settings["port"] = int(self.port) + + if self.password: + conn_settings["password"] = self.password if frappe.conf.local_infile: conn_settings["local_infile"] = frappe.conf.local_infile diff --git a/frappe/database/mariadb/setup_db.py b/frappe/database/mariadb/setup_db.py index 8cb1b2d5d7..58419dee0d 100644 --- a/frappe/database/mariadb/setup_db.py +++ b/frappe/database/mariadb/setup_db.py @@ -136,6 +136,7 @@ def get_root_connection(): frappe.flags.root_password = frappe.conf.get("root_password") or getpass("MySQL root password: ") frappe.local.flags.root_connection = frappe.database.get_db( + socket=frappe.conf.db_socket, host=frappe.conf.db_host, port=frappe.conf.db_port, user=frappe.flags.root_login, diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index 78763fe2f7..aa64b74cb7 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -168,10 +168,12 @@ class PostgresDatabase(PostgresExceptionUtil, Database): conn_settings = { "dbname": self.cur_db_name, "user": self.user, - "host": self.host, - "password": self.password, + # libpg defaults to default socket if not specified + "host": self.host or self.socket, } - if self.port: + if self.password: + conn_settings["password"] = self.password + if not self.socket and self.port: conn_settings["port"] = self.port conn = psycopg2.connect(**conn_settings) diff --git a/frappe/database/postgres/setup_db.py b/frappe/database/postgres/setup_db.py index fe46cd3998..2c5b76169e 100644 --- a/frappe/database/postgres/setup_db.py +++ b/frappe/database/postgres/setup_db.py @@ -73,6 +73,7 @@ def get_root_connection(): ) frappe.local.flags.root_connection = frappe.database.get_db( + socket=frappe.conf.db_socket, host=frappe.conf.db_host, port=frappe.conf.db_port, user=frappe.flags.root_login, diff --git a/frappe/installer.py b/frappe/installer.py index 37d30fb3ea..f6214da580 100644 --- a/frappe/installer.py +++ b/frappe/installer.py @@ -48,6 +48,7 @@ def _new_site( force=False, db_password=None, db_type=None, + db_socket=None, db_host=None, db_port=None, db_user=None, @@ -88,6 +89,7 @@ def _new_site( force=force, db_password=db_password, db_type=db_type, + db_socket=db_socket, db_host=db_host, db_port=db_port, db_user=db_user, @@ -124,6 +126,7 @@ def install_db( site_config=None, db_password=None, db_type=None, + db_socket=None, db_host=None, db_port=None, db_user=None, @@ -146,6 +149,7 @@ def install_db( site_config=site_config, db_password=db_password, db_type=db_type, + db_socket=db_socket, db_host=db_host, db_port=db_port, db_user=db_user, @@ -537,6 +541,7 @@ def make_conf( db_password=None, site_config=None, db_type=None, + db_socket=None, db_host=None, db_port=None, db_user=None, @@ -547,6 +552,7 @@ def make_conf( db_password, site_config, db_type=db_type, + db_socket=db_socket, db_host=db_host, db_port=db_port, db_user=db_user, @@ -561,6 +567,7 @@ def make_site_config( db_password=None, site_config=None, db_type=None, + db_socket=None, db_host=None, db_port=None, db_user=None, @@ -575,6 +582,9 @@ def make_site_config( if db_type: site_config["db_type"] = db_type + if db_socket: + site_config["db_socket"] = db_socket + if db_host: site_config["db_host"] = db_host diff --git a/frappe/integrations/offsite_backup_utils.py b/frappe/integrations/offsite_backup_utils.py index 8cd29d5ecc..f16eabe748 100644 --- a/frappe/integrations/offsite_backup_utils.py +++ b/frappe/integrations/offsite_backup_utils.py @@ -49,6 +49,7 @@ def get_latest_backup_file(with_files=False): frappe.conf.db_name, frappe.conf.db_user, frappe.conf.db_password, + db_socket=frappe.conf.db_socket, db_host=frappe.conf.db_host, db_port=frappe.conf.db_port, db_type=frappe.conf.db_type, @@ -107,6 +108,7 @@ def generate_files_backup(): frappe.conf.db_name, frappe.conf.db_user, frappe.conf.db_password, + db_socket=frappe.conf.db_socket, db_host=frappe.conf.db_host, db_port=frappe.conf.db_port, db_type=frappe.conf.db_type, diff --git a/frappe/tests/test_commands.py b/frappe/tests/test_commands.py index dac8c24377..89790c600a 100644 --- a/frappe/tests/test_commands.py +++ b/frappe/tests/test_commands.py @@ -684,6 +684,7 @@ class TestBackups(BaseTestCommands): frappe.conf.db_name, frappe.conf.db_name, frappe.conf.db_password + "INCORRECT PASSWORD", + db_socket=frappe.conf.db_socket, db_host=frappe.conf.db_host, db_port=frappe.conf.db_port, db_type=frappe.conf.db_type, diff --git a/frappe/utils/backups.py b/frappe/utils/backups.py index a0238afcc1..f14a989e81 100644 --- a/frappe/utils/backups.py +++ b/frappe/utils/backups.py @@ -47,6 +47,7 @@ class BackupGenerator: backup_path_db=None, backup_path_files=None, backup_path_private_files=None, + db_socket=None, db_host=None, db_port=None, db_type=None, @@ -60,6 +61,7 @@ class BackupGenerator: ): global _verbose self.compress_files = compress_files or compress + self.db_socket = db_socket self.db_host = db_host self.db_port = db_port self.db_name = db_name @@ -426,6 +428,7 @@ class BackupGenerator: from frappe.database import get_command bin, args, bin_name = get_command( + socket=self.db_socket, host=self.db_host, port=self.db_port, user=self.user, @@ -501,6 +504,7 @@ def fetch_latest_backups(partial=False) -> dict: frappe.conf.db_name, frappe.conf.db_user, frappe.conf.db_password, + db_socket=frappe.conf.db_socket, db_host=frappe.conf.db_host, db_port=frappe.conf.db_port, db_type=frappe.conf.db_type, @@ -568,6 +572,7 @@ def new_backup( frappe.conf.db_name, frappe.conf.db_user, frappe.conf.db_password, + db_socket=frappe.conf.db_socket, db_host=frappe.conf.db_host, db_port=frappe.conf.db_port, db_type=frappe.conf.db_type, diff --git a/frappe/utils/connections.py b/frappe/utils/connections.py index 711c4d71f3..e398baa5f6 100644 --- a/frappe/utils/connections.py +++ b/frappe/utils/connections.py @@ -7,10 +7,13 @@ from frappe.exceptions import UrlSchemeNotSupported REDIS_KEYS = ("redis_cache", "redis_queue") -def is_open(scheme, hostname, port, timeout=10): +def is_open(scheme, hostname, port, path, timeout=10): if scheme in ["redis", "postgres", "mariadb"]: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) conn = (hostname, int(port)) + elif scheme == "unix": + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + conn = path else: raise UrlSchemeNotSupported(scheme) @@ -28,9 +31,11 @@ def is_open(scheme, hostname, port, timeout=10): def check_database(): config = get_conf() db_type = config.get("db_type", "mariadb") + if db_socket := config.get("db_socket"): + return {db_type: is_open("unix", None, None, db_socket)} db_host = config.get("db_host", "127.0.0.1") db_port = config.get("db_port", 3306 if db_type == "mariadb" else 5432) - return {db_type: is_open(db_type, db_host, db_port)} + return {db_type: is_open(db_type, db_host, db_port, None)} def check_redis(redis_services=None): @@ -39,7 +44,7 @@ def check_redis(redis_services=None): status = {} for srv in services: url = urlparse(config[srv]) - status[srv] = is_open(url.scheme, url.hostname, url.port) + status[srv] = is_open(url.scheme, url.hostname, url.port, url.path) return status