Skip to content

Commit

Permalink
fix tests on minimal versions
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Jul 6, 2020
1 parent 1f550aa commit 817dd3f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
14 changes: 12 additions & 2 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import torch
import shutil

import tensorboard
from pathlib import Path
from urllib.parse import urlparse
from typing import Union
from packaging import version

# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
Expand All @@ -14,6 +17,10 @@

pathlike = Union[Path, str]

#older version of tensorboard had buggy gfile compatibility layers
#only support remote cloud paths if newer
modern_gfile = version.parse(tensorboard.__version__) >= version.parse('2.0')


def load(path_or_url: str, map_location=None):
if urlparse(path_or_url).scheme == "" or Path(path_or_url).drive: # no scheme or with a drive letter
Expand All @@ -23,15 +30,18 @@ def load(path_or_url: str, map_location=None):


def cloud_open(path: pathlike, mode: str):
if not modern_gfile:
return open(path, mode)
try:
return gfile.GFile(path, mode)
except NotImplementedError:
except NotImplementedError as e:
# minimal dependencies are installed and only local files will work
print("not imp", e)
return open(path, mode)


def makedirs(path: pathlike):
if hasattr(gfile, "makedirs"):
if modern_gfile and hasattr(gfile, "makedirs"):
return gfile.makedirs(str(path))
# otherwise minimal dependencies are installed and only local files will work
return os.makedirs(path, exist_ok=True)
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ pre-commit>=1.0

cloudpickle>=1.2

boto3
moto>=1.3.14
3 changes: 3 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from argparse import Namespace
from pathlib import Path

import tensorboard
import cloudpickle
import pytest
import torch
from packaging import version
from omegaconf import OmegaConf
from moto import mock_s3

Expand Down Expand Up @@ -966,6 +968,7 @@ def aws_credentials():
os.environ['AWS_SESSION_TOKEN'] = 'testing'


@pytest.mark.skipif(version.parse(tensorboard.__version__) < version.parse('2.0'), reason="remote paths require tensorboard>=2.0")
def test_trainer_s3_path(aws_credentials):
model = EvalModelTemplate()
with mock_s3():
Expand Down

0 comments on commit 817dd3f

Please sign in to comment.