diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 672670da72b55f..6bc64bdcbd8a85 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.io import gfile
 
 
 class ModelCheckpoint(Callback):
@@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                gfile.makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -156,12 +160,19 @@ def kth_best_model(self):
         return self.kth_best_model_path
 
     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented so if running this
+                # against an actual remote file system and the correct remote
+                # dependencies exist then this will work fine.
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)
 
     def _save_model(self, filepath):
         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            gfile.makedirs(os.path.dirname(filepath))
 
         # delegate the saving to the model
         if self.save_function is not None:
@@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module):
         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
-        while os.path.isfile(filepath):
+        while gfile.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
 
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index ed299bb8a816e0..542bcb230aa9a5 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -12,6 +12,11 @@
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.io import load as pl_load
 
+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned down version which works for some file backends but
+# not all
+from tensorboard.compat import tf
+
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
 
 try:
@@ -269,11 +274,11 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
+    if not tf.io.gfile.exists(tags_csv):
         rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
         return {}
 
-    with open(tags_csv) as fp:
+    with tf.io.gfile.GFile(tags_csv, "r") as fp:
         csv_reader = csv.reader(fp, delimiter=',')
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
 
@@ -281,13 +286,13 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     return tags
 
 
 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
+    if not tf.io.gfile.isdir(os.path.dirname(tags_csv)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
 
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
 
-    with open(tags_csv, 'w') as fp:
+    with tf.io.gfile.GFile(tags_csv, 'w') as fp:
         fieldnames = ['key', 'value']
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
         writer.writerow({'key': 'key', 'value': 'value'})
@@ -306,24 +311,24 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
+    if not tf.io.gfile.exists(config_yaml):
         rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
         return {}
 
-    with open(config_yaml) as fp:
+    with tf.io.gfile.GFile(config_yaml, "r") as fp:
         tags = yaml.load(fp, Loader=yaml.SafeLoader)
 
     return tags
 
 
 def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(config_yaml)):
+    if not tf.io.gfile.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
 
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
 
-    with open(config_yaml, 'w', newline='') as fp:
+    with tf.io.gfile.GFile(config_yaml, 'w') as fp:
         yaml.dump(hparams, fp)
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 12ce47398d3b37..a82c601176bc31 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -17,6 +17,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.io import gfile
 
 
 class TensorBoardLogger(LightningLoggerBase):
@@ -97,7 +98,8 @@ def experiment(self) -> SummaryWriter:
         if self._experiment is not None:
             return self._experiment
 
-        os.makedirs(self.root_dir, exist_ok=True)
+        if not gfile.exists(self.root_dir):
+            gfile.makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
@@ -145,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir
 
         # prepare the file path
@@ -171,13 +173,13 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)
 
-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0
 
         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))
 
         if len(existing_versions) == 0:
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 5e490a106826bd..a42043255d13f1 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -6,6 +6,7 @@
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.io import gfile
 
 
 class TrainerCallbackConfigMixin(ABC):
@@ -67,7 +68,8 @@ def configure_checkpoint_callback(self):
             monitor_key = 'loss' if train_step_only else 'val_loss'
 
             if self.checkpoint_callback is True:
-                os.makedirs(ckpt_path, exist_ok=True)
+                if not gfile.exists(ckpt_path):
+                    gfile.makedirs(ckpt_path)
                 self.checkpoint_callback = ModelCheckpoint(
                     filepath=ckpt_path,
                     monitor=monitor_key
@@ -77,7 +79,9 @@
                     and self.checkpoint_callback.dirpath is None:
                 self.checkpoint_callback.dirpath = ckpt_path
                 self.checkpoint_callback.filename = '{epoch}'
-                os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
+                if not gfile.exists(self.checkpoint_callback.dirpath):
+                    gfile.makedirs(self.checkpoint_callback.dirpath)
+
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b6cdbb0cf130df..0e1bec713c52a7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -371,10 +371,11 @@ def __init__(
                 ' val and test loop using a single batch')
 
         # set default save path if user didn't provide one
-        self.default_root_dir = default_root_dir
-
-        if self.default_root_dir is None:
+        if default_root_dir is None:
             self.default_root_dir = os.getcwd()
+        else:
+            # we have to do str() because the unit tests violate type annotation and pass path objects
+            self.default_root_dir = str(default_root_dir)
 
         # training bookeeping
         self.total_batch_idx = 0
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 6f4e85d5b28e44..f85ba273bf6d1c 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, parsing
 from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.io import gfile
 
 try:
     import torch_xla
@@ -374,9 +375,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False
 
         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
 
             # if hpc weights exist restore model
@@ -451,15 +452,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            gfile.makedirs(folderpath)
 
         # save logger to make sure we get all the metrics
         logger.save()
 
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            gfile.makedirs(folderpath)
 
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
 
         # give model a chance to do something on hpc_save
@@ -509,7 +512,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')
 
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py
index b2f9ffa788541f..c24ba2ae8edf46 100644
--- a/pytorch_lightning/utilities/io.py
+++ b/pytorch_lightning/utilities/io.py
@@ -2,6 +2,13 @@
 
 from urllib.parse import urlparse
 
+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned down version which works for some file backends but
+# not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+
 
 def load(path_or_url: str, map_location=None):
     parsed = urlparse(path_or_url)
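
Note: every hunk above applies the same small pattern, summarized in the sketch below. This is an illustrative recap only, not part of the patch; the helper names makedirs_if_missing and remove_checkpoint are made up for the example. The idea: resolve gfile once from tensorboard.compat (the full tf.io.gfile when TensorFlow is installed, TensorBoard's pruned fallback otherwise), coerce path-like test inputs to str, create directories behind an explicit exists() check because gfile.makedirs has no exist_ok flag, and fall back to os.remove when the compat backend does not implement remove.

import os

from tensorboard.compat import tf  # full tf.io.gfile if TF is installed, pruned fallback otherwise

gfile = tf.io.gfile


def makedirs_if_missing(dirpath) -> None:
    # gfile.makedirs has no exist_ok flag, so check first (mirrors the diff)
    dirpath = str(dirpath)  # tolerate py.path.local / pathlib.Path coming from tests
    if not gfile.exists(dirpath):
        gfile.makedirs(dirpath)


def remove_checkpoint(filepath: str) -> None:
    if gfile.exists(filepath):
        try:
            # the full tf.io.gfile implements remove(); the compat fallback may not
            gfile.remove(filepath)
        except AttributeError:
            # local-filesystem fallback when running without TensorFlow
            os.remove(filepath)

Remote destinations still require their optional dependencies to be installed; with TensorBoard's compat shim and no TensorFlow, only a subset of file-system backends is available.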