From a8ada6a4d560a342a202479deafc458a9f0d38c7 Mon Sep 17 00:00:00 2001
From: Brendan Fahy
Date: Fri, 12 Jun 2020 15:51:56 -0400
Subject: [PATCH] use gfile to support remote directories

The tests all use the `tmpdir` fixture, which provides a py.path.local
object that is incompatible with compat.gfile. The contract in many places
is str or Optional[str], which py.path.local is not. I hope that folks are
not passing in py.path.local objects; if so, this change will break them.
The type annotations say to use str, so this should be OK. The other option
is to explicitly convert to str so as not to break people using an
incorrect type (like the tests were doing).
---
 .../callbacks/model_checkpoint.py             | 25 +++++++++++++------
 pytorch_lightning/core/saving.py              | 21 ++++++++++------
 pytorch_lightning/loggers/tensorboard.py      | 12 +++++----
 pytorch_lightning/trainer/callback_config.py  |  8 ++++--
 pytorch_lightning/trainer/trainer.py          |  7 +++---
 pytorch_lightning/trainer/training_io.py      | 17 +++++++------
 pytorch_lightning/utilities/io.py             |  7 ++++++
 7 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 672670da72b55f..6bc64bdcbd8a85 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.io import gfile

 class ModelCheckpoint(Callback):
@@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                gfile.makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -156,12 +160,19 @@ def kth_best_model(self):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented, so if running this
+                # against an actual remote file system and the correct remote
+                # dependencies exist then this will work fine.
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)

     def _save_model(self, filepath):
         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            gfile.makedirs(os.path.dirname(filepath))

         # delegate the saving to the model
         if self.save_function is not None:
@@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module):
             filepath = self.format_checkpoint_name(epoch, metrics)
             version_cnt = 0
-            while os.path.isfile(filepath):
+            while gfile.exists(filepath):
                 filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
                 # this epoch called before
                 version_cnt += 1
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index ed299bb8a816e0..542bcb230aa9a5 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -12,6 +12,11 @@
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.io import load as pl_load

+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned down version which works for some file backends but
+# not all
+from tensorboard.compat import tf
+
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
 try:
@@ -269,11 +274,11 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
+    if not tf.io.gfile.exists(tags_csv):
         rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
         return {}

-    with open(tags_csv) as fp:
+    with tf.io.gfile.GFile(tags_csv, "r") as fp:
         csv_reader = csv.reader(fp, delimiter=',')
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

@@ -281,13 +286,13 @@
 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
+    if not tf.io.gfile.isdir(os.path.dirname(tags_csv)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with open(tags_csv, 'w') as fp:
+    with tf.io.gfile.GFile(tags_csv, 'w') as fp:
         fieldnames = ['key', 'value']
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
         writer.writerow({'key': 'key', 'value': 'value'})
@@ -306,24 +311,24 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
+    if not tf.io.gfile.exists(config_yaml):
         rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
         return {}

-    with open(config_yaml) as fp:
+    with tf.io.gfile.GFile(config_yaml, "r") as fp:
         tags = yaml.load(fp, Loader=yaml.SafeLoader)

     return tags

 def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(config_yaml)):
+    if not tf.io.gfile.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with open(config_yaml, 'w', newline='') as fp:
+    with tf.io.gfile.GFile(config_yaml, 'w') as fp:
         yaml.dump(hparams, fp)
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 12ce47398d3b37..a82c601176bc31 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -17,6 +17,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.io import gfile

 class TensorBoardLogger(LightningLoggerBase):
@@ -97,7 +98,8 @@ def experiment(self) -> SummaryWriter:
         if self._experiment is not None:
             return self._experiment

-        os.makedirs(self.root_dir, exist_ok=True)
+        if not gfile.exists(self.root_dir):
+            gfile.makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

@@ -145,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir

         # prepare the file path
@@ -171,13 +173,13 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)

-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))

         if len(existing_versions) == 0:
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 5e490a106826bd..a42043255d13f1 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -6,6 +6,7 @@
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.io import gfile

 class TrainerCallbackConfigMixin(ABC):
@@ -67,7 +68,8 @@ def configure_checkpoint_callback(self):
             monitor_key = 'loss' if train_step_only else 'val_loss'

             if self.checkpoint_callback is True:
-                os.makedirs(ckpt_path, exist_ok=True)
+                if not gfile.exists(ckpt_path):
+                    gfile.makedirs(ckpt_path)
                 self.checkpoint_callback = ModelCheckpoint(
                     filepath=ckpt_path,
                     monitor=monitor_key
@@ -77,7 +79,9 @@
                     and self.checkpoint_callback.dirpath is None:
                 self.checkpoint_callback.dirpath = ckpt_path
                 self.checkpoint_callback.filename = '{epoch}'
-                os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
+                if not gfile.exists(self.checkpoint_callback.dirpath):
+                    gfile.makedirs(self.checkpoint_callback.dirpath)
+
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b6cdbb0cf130df..0e1bec713c52a7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -371,10 +371,11 @@ def __init__(
                 ' val and test loop using a single batch')

         # set default save path if user didn't provide one
-        self.default_root_dir = default_root_dir
-
-        if self.default_root_dir is None:
+        if default_root_dir is None:
             self.default_root_dir = os.getcwd()
+        else:
+            # we have to do str() because the unit tests violate the type annotation and pass path objects
+            self.default_root_dir = str(default_root_dir)

         # training bookeeping
         self.total_batch_idx = 0
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 6f4e85d5b28e44..f85ba273bf6d1c 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, parsing
 from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.io import gfile

 try:
     import torch_xla
@@ -374,9 +375,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False

         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

             # if hpc weights exist restore model
@@ -451,15 +452,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            gfile.makedirs(folderpath)

         # save logger to make sure we get all the metrics
         logger.save()

         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            gfile.makedirs(folderpath)

         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

         # give model a chance to do something on hpc_save
@@ -509,7 +512,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')

     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py
index b2f9ffa788541f..c24ba2ae8edf46 100644
--- a/pytorch_lightning/utilities/io.py
+++ b/pytorch_lightning/utilities/io.py
@@ -2,6 +2,13 @@
 from urllib.parse import urlparse

+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned down version which works for some file backends but
+# not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+

 def load(path_or_url: str, map_location=None):
     parsed = urlparse(path_or_url)
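
Note for reviewers: the idiom this patch repeats in every file is (1) coerce
possible py.path.local inputs to str, (2) guard gfile.makedirs with
gfile.exists, since the compat shim's makedirs takes no exist_ok flag, and
(3) fall back to os.remove when the pruned-down shim does not implement
remove. A minimal standalone sketch of that idiom follows; the helper names
ensure_dir and remove_file are illustrative only and are not part of this
patch:

    import os

    # full tf.io.gfile if TensorFlow is installed, else tensorboard's pruned shim
    from tensorboard.compat import tf

    gfile = tf.io.gfile

    def ensure_dir(dirpath) -> None:
        # hypothetical helper: the check-then-create idiom used throughout the patch
        dirpath = str(dirpath)  # accept py.path.local or str
        if not gfile.exists(dirpath):
            gfile.makedirs(dirpath)

    def remove_file(filepath: str) -> None:
        # hypothetical helper: the guarded-remove idiom from _del_model
        if gfile.exists(filepath):
            try:
                # works on remote backends when their dependencies are installed
                gfile.remove(filepath)
            except AttributeError:
                # the compat shim may not implement remove; fall back for local paths
                os.remove(filepath)

With the right backend packages installed, the same calls work transparently
for local paths and for remote ones such as gs://bucket/checkpoints.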