Commit de0fd62: fix for minimal dependencies

f4hy committed Jun 24, 2020
1 parent 17a92ab
Showing 6 changed files with 37 additions and 18 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile
from pytorch_lightning.utilities.cloud_io import gfile, makedirs


class ModelCheckpoint(Callback):
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
if not gfile.exists(self.dirpath):
gfile.makedirs(self.dirpath)
makedirs(self.dirpath)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
Expand Down Expand Up @@ -172,7 +172,7 @@ def _del_model(self, filepath):
def _save_model(self, filepath):
# make paths
if not gfile.exists(os.path.dirname(filepath)):
gfile.makedirs(os.path.dirname(filepath))
makedirs(os.path.dirname(filepath))

# delegate the saving to the model
if self.save_function is not None:
Expand Down
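
The pattern in this file repeats across the rest of the commit: probe with gfile.exists, then create the directory through the new module-level makedirs wrapper instead of gfile.makedirs. A minimal sketch of the failure mode the wrapper avoids, assuming a pruned-down gfile shim that simply lacks a makedirs attribute (the shim behavior is an inference from the wrapper's hasattr check, not something this diff states):

    # Sketch only: why call sites switch to the wrapper.
    from pytorch_lightning.utilities.cloud_io import gfile, makedirs

    dirpath = "/tmp/lightning/checkpoints"  # illustrative path
    if not gfile.exists(dirpath):
        # With minimal dependencies, gfile.makedirs(dirpath) would raise
        # AttributeError because the shim has no makedirs; the wrapper
        # falls back to os.makedirs instead.
        makedirs(dirpath)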
10 changes: 5 additions & 5 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile
from pytorch_lightning.utilities.cloud_io import gfile, cloud_open

# we want this for tf.iogfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
Expand Down Expand Up @@ -296,7 +296,7 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
return {}

with gfile.GFile(tags_csv, "r") as fp:
with cloud_open(tags_csv, "r") as fp:
csv_reader = csv.reader(fp, delimiter=',')
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

Expand All @@ -310,7 +310,7 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
if isinstance(hparams, Namespace):
hparams = vars(hparams)

with gfile.GFile(tags_csv, 'w') as fp:
with cloud_open(tags_csv, 'w') as fp:
fieldnames = ['key', 'value']
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({'key': 'key', 'value': 'value'})
Expand All @@ -333,7 +333,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
return {}

with gfile.GFile(config_yaml, "r") as fp:
with cloud_open(config_yaml, "r") as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)

return tags
Expand All @@ -360,7 +360,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)
assert isinstance(hparams, dict)

with gfile.GFile(config_yaml, 'w') as fp:
with cloud_open(config_yaml, 'w') as fp:
yaml.dump(hparams, fp)


Expand Down
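
All four hparams helpers above now funnel file access through cloud_open, which returns a gfile.GFile handle when the backend supports the path and falls back to the builtin open otherwise. A short usage sketch; the file name and hyperparameter values are illustrative:

    from pytorch_lightning.core.saving import (
        load_hparams_from_yaml,
        save_hparams_to_yaml,
    )

    # Local YAML round-trip: works under a minimal install because
    # cloud_open degrades to the builtin open for local paths.
    save_hparams_to_yaml("hparams.yaml", {"lr": 1e-3, "batch_size": 32})
    hparams = load_hparams_from_yaml("hparams.yaml")
    assert hparams["lr"] == 1e-3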
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile
from pytorch_lightning.utilities.cloud_io import gfile, makedirs


class TensorBoardLogger(LightningLoggerBase):
Expand Down Expand Up @@ -98,7 +98,7 @@ def experiment(self) -> SummaryWriter:
return self._experiment

if not gfile.exists(self.root_dir):
gfile.makedirs(self.root_dir)
makedirs(self.root_dir)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import gfile
from pytorch_lightning.utilities.cloud_io import gfile, makedirs, pathlike


class TrainerCallbackConfigMixin(ABC):
Expand Down Expand Up @@ -69,7 +69,7 @@ def configure_checkpoint_callback(self):

if self.checkpoint_callback is True:
if not gfile.exists(ckpt_path):
gfile.makedirs(ckpt_path)
makedirs(ckpt_path)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
Expand All @@ -80,7 +80,7 @@ def configure_checkpoint_callback(self):
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
if not gfile.exists(self.checkpoint_callback.dirpath):
gfile.makedirs(self.checkpoint_callback.dirpath)
makedirs(self.checkpoint_callback.dirpath)

elif self.checkpoint_callback is False:
self.checkpoint_callback = None
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile
from pytorch_lightning.utilities.cloud_io import gfile, makedirs

try:
import torch_xla
Expand Down Expand Up @@ -462,15 +462,15 @@ def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
if not gfile.exists(folderpath):
gfile.makedirs(folderpath)
makedirs(folderpath)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

if not gfile.exists(folderpath):
gfile.makedirs(folderpath)
makedirs(folderpath)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
Expand Down
23 changes: 21 additions & 2 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import torch

from pathlib import Path
from urllib.parse import urlparse
from typing import Union

# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
Expand All @@ -10,10 +12,27 @@

gfile = tf.io.gfile

pathlike = Union[Path, str]

def load(path_or_url: str, map_location=None):

def load(path_or_url: pathlike, map_location=None):
parsed = urlparse(path_or_url)
if parsed.scheme == '' or Path(path_or_url).is_file():
if parsed.scheme == "" or Path(path_or_url).is_file():
# no scheme or local file
return torch.load(path_or_url, map_location=map_location)
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def cloud_open(path: pathlike, mode: str):
try:
return gfile.GFile(path, mode)
except NotImplementedError:
# minimal dependencies are installed and only local files will work
return open(path, mode)


def makedirs(path: pathlike):
if hasattr(gfile, "makedirs"):
return gfile.makedirs(str(path))
# otherwise minimal dependencies are installed and only local files will work
return os.makedirs(pathlike, exist_ok=True)
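
Together the two helpers give every caller one code path for both install flavors: cloud_open catches the NotImplementedError a stub GFile raises when it cannot serve a path, and makedirs probes with hasattr before delegating to gfile. A usage sketch; treating gs:// URLs as full-backend-only is an assumption based on the comments above, not something the diff demonstrates:

    from pytorch_lightning.utilities.cloud_io import cloud_open, makedirs

    # Local paths work even with minimal dependencies, via the fallbacks.
    makedirs("/tmp/run/checkpoints")
    with cloud_open("/tmp/run/checkpoints/meta.txt", "w") as fp:
        fp.write("epoch=3\n")

    # A remote path such as "gs://bucket/run" should only be expected to
    # work when the full tf.io.gfile backend is installed.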
