Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fix: datamodule can't load files with square brackets in names (#1501)
Browse files Browse the repository at this point in the history
* fix: datamodule can't load files with square brackets in names
* add escaping for URLs, and unfix the fsspec version
* different version format
* fix windows filepath checks
* fix escape_file_path

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
yurijmikhalevich and pre-commit-ci[bot] authored Jan 20, 2023
1 parent 1f773a6 commit e685a92
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
28 changes: 27 additions & 1 deletion flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import re
from functools import partial
from os import PathLike
from typing import Union
from urllib.parse import parse_qs, quote, urlencode, urlparse

import fsspec
import numpy as np
Expand Down Expand Up @@ -139,9 +144,30 @@ def _get_loader(file_path: str, loaders):
)


WINDOWS_FILE_PATH_RE = re.compile("^[a-zA-Z]:(\\\\[^\\\\]|/[^/]).*")


def is_local_path(file_path: str) -> bool:
if WINDOWS_FILE_PATH_RE.fullmatch(file_path):
return True
return urlparse(file_path).scheme in ["", "file"]


def escape_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}{quote(parsed.path)}?{urlencode(parse_qs(parsed.query), doseq=True)}"


def escape_file_path(file_path: Union[str, PathLike]) -> str:
file_path_str = str(file_path)
return glob.escape(file_path_str) if is_local_path(file_path_str) else escape_url(file_path_str)


def load(file_path: str, loaders):
loader = _get_loader(file_path, loaders)
with fsspec.open(file_path) as file:
# escaping file_path to avoid fsspec treating the path as a glob pattern
# fsspec ignores `expand=False` in read mode
with fsspec.open(escape_file_path(file_path)) as file:
return loader(file)


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pandas>=1.1.0
jsonargparse[signatures]>=3.17.0, <=4.9.0
click>=7.1.2
protobuf<=3.20.1
fsspec
fsspec[http]>=2021.6.1,<=2022.7.1
lightning-utilities>=0.3.0
11 changes: 10 additions & 1 deletion tests/core/data/utilities/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def write_tsv(file_path):
@pytest.mark.parametrize(
"extension,write",
[(extension, write_image) for extension in IMG_EXTENSIONS]
+ [(extension, write_numpy) for extension in NP_EXTENSIONS],
+ [(extension, write_numpy) for extension in NP_EXTENSIONS]
# it shouldn't try to expand glob patterns in filenames
+ [(filename, write_image) for filename in ("image [test].jpeg",)],
)
def test_load_image(tmpdir, extension, write):
file_path = os.path.join(tmpdir, f"test{extension}")
Expand Down Expand Up @@ -149,6 +151,13 @@ def test_load_data_frame(tmpdir, extension, write):
Image.Image,
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."),
),
# it shouldn't try to expand glob patterns in URLs
pytest.param(
"https://pl-flash-data.s3.amazonaws.com/images/ant_1 [test].jpg",
load_image,
Image.Image,
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."),
),
pytest.param(
"https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg",
load_spectrogram,
Expand Down

0 comments on commit e685a92

Please sign in to comment.