diff --git a/client/python/cryoet_data_portal/src/cryoet_data_portal/_file_tools.py b/client/python/cryoet_data_portal/src/cryoet_data_portal/_file_tools.py index ea0174d1c..de9fcf66d 100644 --- a/client/python/cryoet_data_portal/src/cryoet_data_portal/_file_tools.py +++ b/client/python/cryoet_data_portal/src/cryoet_data_portal/_file_tools.py @@ -57,13 +57,24 @@ def get_destination_path( dest_path: Optional[str], recursive_from_prefix: Optional[str] = None, ) -> str: + """ + Get the destination path for a file download. + + Args: + url (str): The URL to download + dest_path (str): The directory the files will be downloaded to; defaults to the current working directory + recursive_from_prefix (str): The part of the URL after this prefix is recreated under dest_path. + E.g. if the URL is https://example.com/a/b/file.txt and recursive_from_prefix is + https://example.com/, then the returned path is dest_path/a/b/file.txt. + + Returns: + str: The destination path for the file download + + """ if not dest_path: dest_path = os.getcwd() dest_path = os.path.abspath(dest_path) - if not os.path.isdir(dest_path) and recursive_from_prefix: - raise ValueError("Recursive downloads require a base directory") - # If we're downloading recursively, we need to add the dest URL # (minus the prefix) to the dest path. 
if not recursive_from_prefix: @@ -71,7 +82,11 @@ def get_destination_path( path_suffix = url[len(recursive_from_prefix) :] dest_path = os.path.join(dest_path, os.path.dirname(path_suffix)) if not os.path.isdir(dest_path): - os.makedirs(dest_path, exist_ok=True) + try: + os.makedirs(dest_path, exist_ok=True) + except Exception as e: + raise ValueError(f"Unable to create the path {dest_path}") from e + dest_path = os.path.join(dest_path, os.path.basename(path_suffix)) return dest_path diff --git a/client/python/cryoet_data_portal/tests/test_file_tools.py b/client/python/cryoet_data_portal/tests/test_file_tools.py new file mode 100644 index 000000000..ceb27db29 --- /dev/null +++ b/client/python/cryoet_data_portal/tests/test_file_tools.py @@ -0,0 +1,64 @@ +import os +from unittest.mock import patch + +import pytest + +from cryoet_data_portal._file_tools import get_destination_path + + +class TestGetDestinationPath: + def test_url(self, tmp_path) -> None: + with patch("cryoet_data_portal._file_tools.os.getcwd", return_value=tmp_path): + url = "https://example.com/file.txt" + expected = os.path.join(tmp_path, "file.txt") + assert get_destination_path(url, None) == expected + + def test_dest_path_exists(self, tmp_path) -> None: + url = "https://example.com/file.txt" + dest_path = os.path.join(tmp_path, "my_dest") + os.makedirs(dest_path) + expected = os.path.join(dest_path, "file.txt") + assert get_destination_path(url, dest_path) == expected + + def test_dest_path_does_not_exist(self, tmp_path) -> None: + """Test that the destination path is created if it does not exist""" + url = "https://example.com/file.txt" + dest_path = os.path.join(tmp_path, "my_dest") + expected = os.path.join(dest_path, "file.txt") + assert get_destination_path(url, dest_path) == expected + assert os.path.isdir(dest_path) + + def test_recursive_from_prefix_where_dest_path_exist(self, tmp_path) -> None: + """Test when a recursive_from_prefix is provided and the dest_path exists""" + url = 
"https://example.com/a/file.txt" + dest_path = os.path.join(tmp_path, "my_dest") + os.makedirs(dest_path) + recursive_from_prefix = "https://example.com/" + expected = os.path.join(dest_path, "a", "file.txt") + assert get_destination_path(url, dest_path, recursive_from_prefix) == expected + + def test_recursive_from_prefix_where_dest_path_does_not_exist( + self, + tmp_path, + ) -> None: + """Test when a recursive_from_prefix is provided and the dest_path does not exist. The dest_path should be created.""" + url = "https://example.com/a/b/file.txt" + dest_path = os.path.join(tmp_path, "my_dest") + recursive_from_prefix = "https://example.com/" + expected_path = os.path.join(dest_path, "a", "b") + expected = os.path.join(dest_path, "a", "b", "file.txt") + assert get_destination_path(url, dest_path, recursive_from_prefix) == expected + assert os.path.isdir(expected_path) + + def test_invalid_dest_path( + self, + tmp_path, + ) -> None: + url = "https://example.com/file.txt" + dest_path = os.path.join(tmp_path, "\000") + recursive_from_prefix = "https://example.com/" + expected = os.path.join(dest_path, "file.txt") + with pytest.raises(ValueError): + assert ( + get_destination_path(url, dest_path, recursive_from_prefix) == expected + )