diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e8076782292dea..8db11bd1dd5fac 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,7 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 
 class ModelCheckpoint(Callback):
@@ -118,7 +118,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         filepath = os.path.realpath(filepath)
         self.dirpath, self.filename = os.path.split(filepath)
         if not gfile.exists(self.dirpath):
-            gfile.makedirs(self.dirpath)
+            makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -172,7 +172,7 @@ def _del_model(self, filepath):
     def _save_model(self, filepath):
         # make paths
         if not gfile.exists(os.path.dirname(filepath)):
-            gfile.makedirs(os.path.dirname(filepath))
+            makedirs(os.path.dirname(filepath))
 
         # delegate the saving to the model
         if self.save_function is not None:
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 6477359fa62b86..ed80756ee81b98 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -11,7 +11,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
 
 # we want this for tf.io.gfile, which if tf is installed gives full tf,
 # otherwise gives a pruned down version which works for some file backends but
@@ -296,7 +296,7 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
         rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
         return {}
 
-    with gfile.GFile(tags_csv, "r") as fp:
+    with cloud_open(tags_csv, "r") as fp:
         csv_reader = csv.reader(fp, delimiter=',')
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
 
@@ -310,7 +310,7 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
 
-    with gfile.GFile(tags_csv, 'w') as fp:
+    with cloud_open(tags_csv, 'w') as fp:
         fieldnames = ['key', 'value']
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
         writer.writerow({'key': 'key', 'value': 'value'})
@@ -333,7 +333,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
         rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
         return {}
 
-    with gfile.GFile(config_yaml, "r") as fp:
+    with cloud_open(config_yaml, "r") as fp:
         tags = yaml.load(fp, Loader=yaml.SafeLoader)
 
     return tags
@@ -360,7 +360,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
     assert isinstance(hparams, dict)
 
-    with gfile.GFile(config_yaml, 'w') as fp:
+    with cloud_open(config_yaml, 'w') as fp:
         yaml.dump(hparams, fp)
 
 
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index a1c4bdc538fd8e..43fff60c322f2e 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -16,7 +16,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 
 class TensorBoardLogger(LightningLoggerBase):
@@ -98,7 +98,7 @@ def experiment(self) -> SummaryWriter:
             return self._experiment
 
         if not gfile.exists(self.root_dir):
-            gfile.makedirs(self.root_dir)
+            makedirs(self.root_dir)
 
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 887c222a2aed87..7d3320de3aa295 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -6,7 +6,7 @@
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.cloud_io import gfile
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs, pathlike
 
 
 class TrainerCallbackConfigMixin(ABC):
@@ -69,7 +69,7 @@ def configure_checkpoint_callback(self):
 
         if self.checkpoint_callback is True:
             if not gfile.exists(ckpt_path):
-                gfile.makedirs(ckpt_path)
+                makedirs(ckpt_path)
             self.checkpoint_callback = ModelCheckpoint(
                 filepath=ckpt_path,
                 monitor=monitor_key
@@ -80,7 +80,7 @@ def configure_checkpoint_callback(self):
                 self.checkpoint_callback.dirpath = ckpt_path
                 self.checkpoint_callback.filename = '{epoch}'
             if not gfile.exists(self.checkpoint_callback.dirpath):
-                gfile.makedirs(self.checkpoint_callback.dirpath)
+                makedirs(self.checkpoint_callback.dirpath)
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
 
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 726255dcd3269c..1bdee178553baa 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -102,7 +102,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 try:
     import torch_xla
@@ -462,7 +462,7 @@ def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
         if not gfile.exists(folderpath):
-            gfile.makedirs(folderpath)
+            makedirs(folderpath)
 
         # save logger to make sure we get all the metrics
         logger.save()
@@ -470,7 +470,7 @@ def hpc_save(self, folderpath: str, logger):
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
         if not gfile.exists(folderpath):
-            gfile.makedirs(folderpath)
+            makedirs(folderpath)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
 
         # give model a chance to do something on hpc_save
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index 655e63284405b6..c2b1b604da7c74 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -1,7 +1,9 @@
+import os
 import torch
 
 from pathlib import Path
 from urllib.parse import urlparse
+from typing import Union
 
 # we want this for tf.io.gfile, which if tf is installed gives full tf,
 # otherwise gives a pruned down version which works for some file backends but
@@ -10,10 +12,27 @@
 gfile = tf.io.gfile
 
+pathlike = Union[Path, str]
 
-def load(path_or_url: str, map_location=None):
+
+def load(path_or_url: pathlike, map_location=None):
     parsed = urlparse(path_or_url)
-    if parsed.scheme == '' or Path(path_or_url).is_file():
+    if parsed.scheme == "" or Path(path_or_url).is_file():
         # no scheme or local file
         return torch.load(path_or_url, map_location=map_location)
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def cloud_open(path: pathlike, mode: str):
+    try:
+        return gfile.GFile(path, mode)
+    except NotImplementedError:
+        # minimal dependencies are installed and only local files will work
+        return open(path, mode)
+
+
+def makedirs(path: pathlike):
+    if hasattr(gfile, "makedirs"):
+        return gfile.makedirs(str(path))
+    # otherwise minimal dependencies are installed and only local files will work
+    return os.makedirs(path, exist_ok=True)
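For reviewers, a minimal sketch of how the new `makedirs` / `cloud_open` helpers are expected to behave from calling code, assuming the patch above is applied. The local path and hparams values are illustrative only, not part of the change.

```python
from pytorch_lightning.utilities.cloud_io import cloud_open, makedirs

# Illustrative local path; with full TensorFlow installed, gfile-backed
# remote paths (e.g. "gs://bucket/run_0") should go through the same calls.
run_dir = "/tmp/lightning_demo/run_0"

# Uses gfile.makedirs when the installed gfile provides it, otherwise
# falls back to os.makedirs (local filesystem only).
makedirs(run_dir)

# Returns a gfile.GFile when the backend supports the path, and falls
# back to the builtin open() if gfile raises NotImplementedError.
with cloud_open(run_dir + "/hparams.yaml", "w") as fp:
    fp.write("learning_rate: 0.001\n")

with cloud_open(run_dir + "/hparams.yaml", "r") as fp:
    print(fp.read())
```

Either way the callers above (checkpointing, hparams saving, TensorBoard logging) no longer assume the full `tf.io.gfile` API is present under minimal dependencies.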