Skip to content

Commit

Permalink
make the fs variable private
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Sep 2, 2020
1 parent a86fe4c commit f45913a
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if filepath:
self.fs = get_filesystem(filepath)
self._fs = get_filesystem(filepath)
else:
self.fs = get_filesystem("") # will give local filesystem
if save_top_k > 0 and filepath is not None and self.fs.isdir(filepath) and len(self.fs.ls(filepath)) > 0:
self._fs = get_filesystem("") # will give local filesystem
if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
Expand All @@ -135,13 +135,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if self.fs.isdir(filepath):
if self._fs.isdir(filepath):
self.dirpath, self.filename = filepath, "{epoch}"
else:
if self.fs.protocol == "file": # don't normalize remote paths
if self._fs.protocol == "file": # don't normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
self.fs.makedirs(self.dirpath, exist_ok=True)
self._fs.makedirs(self.dirpath, exist_ok=True)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
Expand Down Expand Up @@ -184,16 +184,16 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
if self.fs.exists(filepath):
self.fs.rm(filepath)
if self._fs.exists(filepath):
self._fs.rm(filepath)

def _save_model(self, filepath, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
self.fs.makedirs(os.path.dirname(filepath), exist_ok=True)
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the model
if self.save_function is not None:
Expand Down Expand Up @@ -299,7 +299,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
self.dirpath = ckpt_path

assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
self.fs.makedirs(self.dirpath, exist_ok=True)
self._fs.makedirs(self.dirpath, exist_ok=True)

def __warn_deprecated_monitor_key(self):
using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
Expand Down Expand Up @@ -348,7 +348,7 @@ def on_validation_end(self, trainer, pl_module):
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
while self.fs.exists(filepath):
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
Expand Down

0 comments on commit f45913a

Please sign in to comment.