diff --git a/frappe/commands/site.py b/frappe/commands/site.py index bc65aa178c..4a631be3ac 100755 --- a/frappe/commands/site.py +++ b/frappe/commands/site.py @@ -100,13 +100,11 @@ def restore(context, sql_file_path, mariadb_root_username=None, mariadb_root_pas # Extract public and/or private files to the restored site, if user has given the path if with_public_files: - with_public_files = os.path.join(base_path, with_public_files) - public = extract_files(site, with_public_files, 'public') + public = extract_files(site, with_public_files) os.remove(public) if with_private_files: - with_private_files = os.path.join(base_path, with_private_files) - private = extract_files(site, with_private_files, 'private') + private = extract_files(site, with_private_files) os.remove(private) # Removing temporarily created file diff --git a/frappe/installer.py b/frappe/installer.py index 1245a08cb7..a11c8dfbfa 100755 --- a/frappe/installer.py +++ b/frappe/installer.py @@ -440,20 +440,11 @@ def extract_sql_from_archive(sql_file_path): Returns: str: Path of the decompressed SQL file """ + from frappe.utils import get_bench_relative_path + sql_file_path = get_bench_relative_path(sql_file_path) # Extract the gzip file if user has passed *.sql.gz file instead of *.sql file - if not os.path.exists(sql_file_path): - base_path = '..' - sql_file_path = os.path.join(base_path, sql_file_path) - if not os.path.exists(sql_file_path): - print('Invalid path {0}'.format(sql_file_path[3:])) - sys.exit(1) - elif sql_file_path.startswith(os.sep): - base_path = os.sep - else: - base_path = '.' - if sql_file_path.endswith('sql.gz'): - decompressed_file_name = extract_sql_gzip(os.path.abspath(sql_file_path)) + decompressed_file_name = extract_sql_gzip(sql_file_path) else: decompressed_file_name = sql_file_path @@ -475,9 +466,12 @@ def extract_sql_gzip(sql_gz_path): return decompressed_file -def extract_files(site_name, file_path, folder_name): +def extract_files(site_name, file_path): import shutil import subprocess + from frappe.utils import get_bench_relative_path + + file_path = get_bench_relative_path(file_path) # Need to do frappe.init to maintain the site locals frappe.init(site=site_name) diff --git a/frappe/tests/test_commands.py b/frappe/tests/test_commands.py index 8c76ce2f48..0786a0e14f 100644 --- a/frappe/tests/test_commands.py +++ b/frappe/tests/test_commands.py @@ -14,7 +14,7 @@ import glob import frappe import frappe.recorder from frappe.installer import add_to_installed_apps -from frappe.utils import add_to_date, now +from frappe.utils import add_to_date, get_bench_relative_path, now from frappe.utils.backups import fetch_latest_backups @@ -364,3 +364,21 @@ class TestCommands(BaseTestCommands): else: installed_apps = set(frappe.get_installed_apps()) self.assertSetEqual(list_apps, installed_apps) + + def test_get_bench_relative_path(self): + bench_path = frappe.utils.get_bench_path() + test1_path = os.path.join(bench_path, "test1.txt") + test2_path = os.path.join(bench_path, "sites", "test2.txt") + + with open(test1_path, "w+") as test1: + test1.write("asdf") + with open(test2_path, "w+") as test2: + test2.write("asdf") + + self.assertTrue("test1.txt" in get_bench_relative_path("test1.txt")) + self.assertTrue("sites/test2.txt" in get_bench_relative_path("test2.txt")) + with self.assertRaises(SystemExit): + get_bench_relative_path("test3.txt") + + os.remove(test1_path) + os.remove(test2_path) diff --git a/frappe/utils/__init__.py b/frappe/utils/__init__.py index 3ae89936e7..5ac4de618d 100644 --- a/frappe/utils/__init__.py +++ b/frappe/utils/__init__.py @@ -734,3 +734,27 @@ def get_build_version(): # .build can sometimes not exist # this is not a major problem so send fallback return frappe.utils.random_string(8) + +def get_bench_relative_path(file_path): + """Fixes paths relative to the bench root directory if exists and returns the absolute path + + Args: + file_path (str, Path): Path of a file that exists on the file system + + Returns: + str: Absolute path of the file_path + """ + if not os.path.exists(file_path): + base_path = '..' + elif file_path.startswith(os.sep): + base_path = os.sep + else: + base_path = '.' + + file_path = os.path.join(base_path, file_path) + + if not os.path.exists(file_path): + print('Invalid path {0}'.format(file_path[3:])) + sys.exit(1) + + return os.path.abspath(file_path)