Skip to content

Commit

Permalink
fix missing return statement. Do not normalize remote paths
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Aug 9, 2020
1 parent 6c18fd9 commit c87b092
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def on_train_start(self, trainer, pl_module):
self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
os.makedirs(self.dirpath, exist_ok=True)
makedirs(self.dirpath)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,9 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if "://" in str(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)

@property
Expand All @@ -885,6 +888,9 @@ def weights_save_path(self) -> str:
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if "://" in str(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)

# -----------------------------
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def modern_gfile():
file operations
"""
tb_version = version.parse(tensorboard.version.VERSION)
modern_gfile = tb_version >= version.parse('2.0')
modern_gfile = tb_version >= version.parse("2.0")
return modern_gfile


def cloud_open(path: pathlike, mode: str, newline:str = None):
def cloud_open(path: pathlike, mode: str, newline: str = None):
if sys.platform == "win32":
log.debug(
"gfile does not handle newlines correctly on windows so remote files are not"
Expand Down

0 comments on commit c87b092

Please sign in to comment.