feat: enable db socket connection

This commit is contained in:
David 2024-03-29 10:21:27 +01:00
parent 6eb1607c0a
commit 39d4318a27
No known key found for this signature in database
GPG key ID: AB15A6AF1101390D
15 changed files with 76 additions and 24 deletions

View file

@ -343,6 +343,7 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us
assert local.conf.db_password, "site must be fully initialized, db_password missing"
local.db = get_db(
socket=local.conf.db_socket,
host=local.conf.db_host,
port=local.conf.db_port,
user=local.conf.db_user or db_name,
@ -368,6 +369,7 @@ def connect_replica() -> bool:
password = local.conf.replica_db_password
local.replica_db = get_db(
socket=None,
host=local.conf.replica_host,
port=port,
user=user,
@ -430,6 +432,7 @@ def get_site_config(sites_path: str | None = None, site_path: str | None = None)
os.environ.get("FRAPPE_REDIS_CACHE") or config.get("redis_cache") or "redis://127.0.0.1:13311"
)
config["db_type"] = os.environ.get("FRAPPE_DB_TYPE") or config.get("db_type") or "mariadb"
config["db_socket"] = os.environ.get("FRAPPE_DB_SOCKET") or config.get("db_socket")
config["db_host"] = os.environ.get("FRAPPE_DB_HOST") or config.get("db_host") or "127.0.0.1"
config["db_port"] = (
os.environ.get("FRAPPE_DB_PORT") or config.get("db_port") or db_default_ports(config["db_type"])

View file

@ -33,10 +33,10 @@ from frappe.utils import CallbackManager
)
@click.option("--db-root-password", "--mariadb-root-password", help="Root password for MariaDB or PostgreSQL")
@click.option(
"--no-mariadb-socket",
is_flag=True,
default=False,
help="Set MariaDB host to % and use TCP/IP Socket instead of using the UNIX Socket",
"--db-socket",
"--mariadb-db-socket",
envvar="MYSQL_UNIX_PORT",
help="Database socket for MariaDB or folder containing database socket for PostgreSQL",
)
@click.option(
"--no-mariadb-socket",
@ -78,6 +78,7 @@ def new_site(
db_name=None,
db_password=None,
db_type=None,
db_socket=None,
db_host=None,
db_port=None,
db_user=None,
@ -115,6 +116,7 @@ def new_site(
force=force,
db_password=db_password,
db_type=db_type,
db_socket=db_socket,
db_host=db_host,
db_port=db_port,
db_user=db_user,

View file

@ -526,6 +526,7 @@ def _enter_console(extra_args=None):
os.environ["PSQL_HISTORY"] = os.path.abspath(get_site_path("logs", "postgresql_console.log"))
bin, args, bin_name = get_command(
socket=frappe.conf.db_socket,
host=frappe.conf.db_host,
port=frappe.conf.db_port,
user=frappe.conf.db_user,

View file

@ -47,20 +47,26 @@ def drop_user_and_database(db_name, db_user):
return frappe.database.mariadb.setup_db.drop_user_and_database(db_name, db_user)
def get_db(host=None, user=None, password=None, port=None, cur_db_name=None):
def get_db(socket=None, host=None, user=None, password=None, port=None, cur_db_name=None):
import frappe
if frappe.conf.db_type == "postgres":
import frappe.database.postgres.database
return frappe.database.postgres.database.PostgresDatabase(host, user, password, port, cur_db_name)
return frappe.database.postgres.database.PostgresDatabase(
socket, host, user, password, port, cur_db_name
)
else:
import frappe.database.mariadb.database
return frappe.database.mariadb.database.MariaDBDatabase(host, user, password, port, cur_db_name)
return frappe.database.mariadb.database.MariaDBDatabase(
socket, host, user, password, port, cur_db_name
)
def get_command(host=None, port=None, user=None, password=None, db_name=None, extra=None, dump=False):
def get_command(
socket=None, host=None, port=None, user=None, password=None, db_name=None, extra=None, dump=False
):
import frappe
if frappe.conf.db_type == "postgres":
@ -69,7 +75,11 @@ def get_command(host=None, port=None, user=None, password=None, db_name=None, ex
else:
bin, bin_name = which("psql"), "psql"
if password:
if socket and password:
conn_string = f"postgresql://{user}:{password}@/{db_name}?host={socket}"
elif socket:
conn_string = f"postgresql://{user}@/{db_name}?host={socket}"
elif password:
conn_string = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
else:
conn_string = f"postgresql://{user}@{host}:{port}/{db_name}"
@ -85,11 +95,12 @@ def get_command(host=None, port=None, user=None, password=None, db_name=None, ex
else:
bin, bin_name = which("mariadb") or which("mysql"), "mariadb"
command = [
f"--user={user}",
f"--host={host}",
f"--port={port}",
]
command = [f"--user={user}"]
if socket:
command.append(f"--socket={socket}")
elif host and port:
command.append(f"--host={host}")
command.append(f"--port={port}")
if password:
command.append(f"--password={password}")

View file

@ -75,6 +75,7 @@ class Database:
def __init__(
self,
socket=None,
host=None,
user=None,
password=None,
@ -82,6 +83,7 @@ class Database:
cur_db_name=None,
):
self.setup_type_map()
self.socket = socket
self.host = host
self.port = port
self.user = user

View file

@ -70,6 +70,7 @@ class DbManager:
source = ["<", source]
bin, args, bin_name = get_command(
socket=frappe.conf.db_socket,
host=frappe.conf.db_host,
port=frappe.conf.db_port,
user=user,

View file

@ -116,9 +116,7 @@ class MariaDBConnectionUtil:
def get_connection_settings(self) -> dict:
conn_settings = {
"host": self.host,
"user": self.user,
"password": self.password,
"conv": self.CONVERSION_MAP,
"charset": "utf8mb4",
"use_unicode": True,
@ -127,8 +125,15 @@ class MariaDBConnectionUtil:
if self.cur_db_name:
conn_settings["database"] = self.cur_db_name
if self.port:
conn_settings["port"] = int(self.port)
if self.socket:
conn_settings["unix_socket"] = self.socket
else:
conn_settings["host"] = self.host
if self.port:
conn_settings["port"] = int(self.port)
if self.password:
conn_settings["password"] = self.password
if frappe.conf.local_infile:
conn_settings["local_infile"] = frappe.conf.local_infile

View file

@ -136,6 +136,7 @@ def get_root_connection():
frappe.flags.root_password = frappe.conf.get("root_password") or getpass("MySQL root password: ")
frappe.local.flags.root_connection = frappe.database.get_db(
socket=frappe.conf.db_socket,
host=frappe.conf.db_host,
port=frappe.conf.db_port,
user=frappe.flags.root_login,

View file

@ -168,10 +168,12 @@ class PostgresDatabase(PostgresExceptionUtil, Database):
conn_settings = {
"dbname": self.cur_db_name,
"user": self.user,
"host": self.host,
"password": self.password,
# libpg defaults to default socket if not specified
"host": self.host or self.socket,
}
if self.port:
if self.password:
conn_settings["password"] = self.password
if not self.socket and self.port:
conn_settings["port"] = self.port
conn = psycopg2.connect(**conn_settings)

View file

@ -73,6 +73,7 @@ def get_root_connection():
)
frappe.local.flags.root_connection = frappe.database.get_db(
socket=frappe.conf.db_socket,
host=frappe.conf.db_host,
port=frappe.conf.db_port,
user=frappe.flags.root_login,

View file

@ -48,6 +48,7 @@ def _new_site(
force=False,
db_password=None,
db_type=None,
db_socket=None,
db_host=None,
db_port=None,
db_user=None,
@ -88,6 +89,7 @@ def _new_site(
force=force,
db_password=db_password,
db_type=db_type,
db_socket=db_socket,
db_host=db_host,
db_port=db_port,
db_user=db_user,
@ -124,6 +126,7 @@ def install_db(
site_config=None,
db_password=None,
db_type=None,
db_socket=None,
db_host=None,
db_port=None,
db_user=None,
@ -146,6 +149,7 @@ def install_db(
site_config=site_config,
db_password=db_password,
db_type=db_type,
db_socket=db_socket,
db_host=db_host,
db_port=db_port,
db_user=db_user,
@ -537,6 +541,7 @@ def make_conf(
db_password=None,
site_config=None,
db_type=None,
db_socket=None,
db_host=None,
db_port=None,
db_user=None,
@ -547,6 +552,7 @@ def make_conf(
db_password,
site_config,
db_type=db_type,
db_socket=db_socket,
db_host=db_host,
db_port=db_port,
db_user=db_user,
@ -561,6 +567,7 @@ def make_site_config(
db_password=None,
site_config=None,
db_type=None,
db_socket=None,
db_host=None,
db_port=None,
db_user=None,
@ -575,6 +582,9 @@ def make_site_config(
if db_type:
site_config["db_type"] = db_type
if db_socket:
site_config["db_socket"] = db_socket
if db_host:
site_config["db_host"] = db_host

View file

@ -49,6 +49,7 @@ def get_latest_backup_file(with_files=False):
frappe.conf.db_name,
frappe.conf.db_user,
frappe.conf.db_password,
db_socket=frappe.conf.db_socket,
db_host=frappe.conf.db_host,
db_port=frappe.conf.db_port,
db_type=frappe.conf.db_type,
@ -107,6 +108,7 @@ def generate_files_backup():
frappe.conf.db_name,
frappe.conf.db_user,
frappe.conf.db_password,
db_socket=frappe.conf.db_socket,
db_host=frappe.conf.db_host,
db_port=frappe.conf.db_port,
db_type=frappe.conf.db_type,

View file

@ -684,6 +684,7 @@ class TestBackups(BaseTestCommands):
frappe.conf.db_name,
frappe.conf.db_name,
frappe.conf.db_password + "INCORRECT PASSWORD",
db_socket=frappe.conf.db_socket,
db_host=frappe.conf.db_host,
db_port=frappe.conf.db_port,
db_type=frappe.conf.db_type,

View file

@ -47,6 +47,7 @@ class BackupGenerator:
backup_path_db=None,
backup_path_files=None,
backup_path_private_files=None,
db_socket=None,
db_host=None,
db_port=None,
db_type=None,
@ -60,6 +61,7 @@ class BackupGenerator:
):
global _verbose
self.compress_files = compress_files or compress
self.db_socket = db_socket
self.db_host = db_host
self.db_port = db_port
self.db_name = db_name
@ -426,6 +428,7 @@ class BackupGenerator:
from frappe.database import get_command
bin, args, bin_name = get_command(
socket=self.db_socket,
host=self.db_host,
port=self.db_port,
user=self.user,
@ -501,6 +504,7 @@ def fetch_latest_backups(partial=False) -> dict:
frappe.conf.db_name,
frappe.conf.db_user,
frappe.conf.db_password,
db_socket=frappe.conf.db_socket,
db_host=frappe.conf.db_host,
db_port=frappe.conf.db_port,
db_type=frappe.conf.db_type,
@ -568,6 +572,7 @@ def new_backup(
frappe.conf.db_name,
frappe.conf.db_user,
frappe.conf.db_password,
db_socket=frappe.conf.db_socket,
db_host=frappe.conf.db_host,
db_port=frappe.conf.db_port,
db_type=frappe.conf.db_type,

View file

@ -7,10 +7,13 @@ from frappe.exceptions import UrlSchemeNotSupported
REDIS_KEYS = ("redis_cache", "redis_queue")
def is_open(scheme, hostname, port, timeout=10):
def is_open(scheme, hostname, port, path, timeout=10):
if scheme in ["redis", "postgres", "mariadb"]:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
conn = (hostname, int(port))
elif scheme == "unix":
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
conn = path
else:
raise UrlSchemeNotSupported(scheme)
@ -28,9 +31,11 @@ def is_open(scheme, hostname, port, timeout=10):
def check_database():
config = get_conf()
db_type = config.get("db_type", "mariadb")
if db_socket := config.get("db_socket"):
return {db_type: is_open("unix", None, None, db_socket)}
db_host = config.get("db_host", "127.0.0.1")
db_port = config.get("db_port", 3306 if db_type == "mariadb" else 5432)
return {db_type: is_open(db_type, db_host, db_port)}
return {db_type: is_open(db_type, db_host, db_port, None)}
def check_redis(redis_services=None):
@ -39,7 +44,7 @@ def check_redis(redis_services=None):
status = {}
for srv in services:
url = urlparse(config[srv])
status[srv] = is_open(url.scheme, url.hostname, url.port)
status[srv] = is_open(url.scheme, url.hostname, url.port, url.path)
return status