Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Aug 8, 2020
1 parent 6d22bde commit 896d94a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 45 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}

with cloud_open(tags_csv, "r") as fp:
csv_reader = csv.reader(fp.read(), delimiter=",")
with cloud_open(tags_csv, "r", newline="") as fp:
csv_reader = csv.reader(fp, delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags
Expand Down
6 changes: 2 additions & 4 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
# only support remote cloud paths if newer
modern_gfile = version.parse(tensorboard.version.VERSION) >= version.parse('2.0')

import torch


def load(path_or_url: str, map_location=None):
if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter
Expand All @@ -32,7 +30,7 @@ def load(path_or_url: str, map_location=None):


def cloud_open(path: pathlike, mode: str, newline:str = None):
if not modern_gfile or sys.platform == "win32":
if not modern_gfile:
log.debug(
"tenosrboard.compat gfile does not work on older versions "
"of tensorboard normal local file open."
Expand All @@ -48,7 +46,7 @@ def cloud_open(path: pathlike, mode: str, newline:str = None):
return gfile.GFile(path, mode)
except NotImplementedError as e:
# minimal dependencies are installed and only local files will work
return open(path, mode)
return open(path, mode, newline=newline)


def makedirs(path: pathlike):
Expand Down
39 changes: 0 additions & 39 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import pickle
import sys
import types
import boto3
import botocore
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch
Expand All @@ -17,7 +15,6 @@
import torch
from packaging import version
from omegaconf import OmegaConf
from moto import mock_s3

import tests.base.develop_utils as tutils
from pytorch_lightning import Callback, LightningModule, Trainer
Expand Down Expand Up @@ -1081,42 +1078,6 @@ def test_trainer_pickle(tmpdir):
cloudpickle.dumps(trainer)


@pytest.fixture(scope="function")
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ["AWS_ACCESS_KEY_ID"] = "testing" # nosec
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" # nosec
os.environ["AWS_SECURITY_TOKEN"] = "testing" # nosec
os.environ["AWS_SESSION_TOKEN"] = "testing" # nosec


@pytest.mark.skipif(platform.system() == "Windows", reason="Saving to remote paths is not supported on Windows")
@pytest.mark.skipif(
version.parse(tensorboard.version.VERSION) < version.parse("2.0"), reason="remote paths require tensorboard>=2.0"
)
def test_trainer_s3_path(monkeypatch,aws_credentials):
"""Test that we can save to remote directories."""

# put everything on a remote s3 path
monkeypatch.setenv("TORCH_HOME", "s3://test_bucket/")

# set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
model = EvalModelTemplate()
with mock_s3():
conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="test_bucket")
tb_logger = TensorBoardLogger('s3://test_bucket/logs/') # use an s3 path
trainer = Trainer(max_epochs=1, default_root_dir="s3://test_bucket/outputpath/", logger=tb_logger)
result = trainer.fit(model)

# Explicitly fail if not using mock s3 to indicate it tries to write to
# a remote location and fails if s3 is not setup
with pytest.raises(botocore.exceptions.ClientError):
tb_logger = TensorBoardLogger('s3://test_bucket/logs/')
trainer = Trainer(max_epochs=1, default_root_dir="s3://test_bucket/outputpath_fail/", logger=tb_logger)
result = trainer.fit(model)


def test_trainer_setup_call(tmpdir):
"""Test setup call with fit and test call."""

Expand Down

0 comments on commit 896d94a

Please sign in to comment.