Skip to content

Commit

Permalink
fix missing return statement. Do not normalize remote paths
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Aug 9, 2020
1 parent 6c18fd9 commit 43f07e6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def on_train_start(self, trainer, pl_module):
self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
os.makedirs(self.dirpath, exist_ok=True)
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,9 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if "://" in str(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)

@property
Expand All @@ -885,6 +888,9 @@ def weights_save_path(self) -> str:
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if "://" in str(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)

# -----------------------------
Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import platform
import os
from typing import Union
from pathlib import Path
Expand Down Expand Up @@ -35,11 +35,12 @@ def modern_gfile():
file operations
"""
tb_version = version.parse(tensorboard.version.VERSION)
modern_gfile = tb_version >= version.parse('2.0')
modern_gfile = tb_version >= version.parse("2.0")
return modern_gfile


def cloud_open(path: pathlike, mode: str, newline:str = None):
if sys.platform == "win32":
def cloud_open(path: pathlike, mode: str, newline: str = None):
if platform.system() == "Windows":
log.debug(
"gfile does not handle newlines correctly on windows so remote files are not"
"supported falling back to normal local file open."
Expand Down

0 comments on commit 43f07e6

Please sign in to comment.