Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpointing to remote file paths #2925

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path


class ModelCheckpoint(Callback):
Expand Down Expand Up @@ -122,10 +122,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
filepath = os.path.realpath(filepath)
if not is_remote_path(filepath): # dont normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)
makedirs(self.dirpath) # calls with exist_ok
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
Expand Down Expand Up @@ -174,7 +174,12 @@ def _del_model(self, filepath):
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
os.remove(filepath)
if is_remote_path(filepath):
log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
" Please install tensorflow to run gfile in full mode"
" if writing checkpoints to remote locations")
else:
os.remove(filepath)

def _save_model(self, filepath, trainer, pl_module):

Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def train_fx(trial_hparams, cluster_manager, _):

"""

import io
import os
import re
from abc import ABC, abstractmethod
Expand All @@ -146,6 +147,7 @@ def train_fx(trial_hparams, cluster_manager, _):
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.cloud_io import cloud_open


try:
Expand Down Expand Up @@ -435,10 +437,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
bytesbuffer = io.BytesIO()
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False)
torch.save(model.state_dict(), bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(model.state_dict(), last_path)
torch.save(model.state_dict(), bytesbuffer)
with cloud_open(last_path, 'wb') as f:
f.write(bytesbuffer.getvalue())
mp_queue.put(last_path)

def save_spawn_weights(self, model):
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path

# warnings to ignore in trainer
warnings.filterwarnings(
Expand Down Expand Up @@ -880,7 +881,7 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if "://" in str(self._default_root_dir):
if is_remote_path(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)
Expand All @@ -891,7 +892,7 @@ def weights_save_path(self) -> str:
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if "://" in str(self._weights_save_path):
if is_remote_path(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)
Expand Down
12 changes: 7 additions & 5 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@

"""

import io
import os
import re
import signal
Expand All @@ -104,7 +105,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

try:
import torch_xla
Expand Down Expand Up @@ -269,15 +270,16 @@ def _atomic_save(self, checkpoint, filepath: str):
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)
torch.save(checkpoint, bytesbuffer)
with cloud_open(filepath, 'wb') as f:
f.write(bytesbuffer.getvalue())
Comment on lines +273 to +282
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we make this as a function too? when we have one-liner is_remote_path


def save_checkpoint(self, filepath, weights_only: bool = False):
checkpoint = self.dump_checkpoint(weights_only)
Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def load(path_or_url: str, map_location=None):
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def is_remote_path(path: pathlike):
"""Determine if a path is a local path or a remote path like s3://bucket/path

This should catch paths like s3:// hdfs:// and gcs://
"""
return "://" in str(path)


def modern_gfile():
"""Check the version number of tensorboard.

Expand Down Expand Up @@ -61,6 +69,7 @@ def cloud_open(path: pathlike, mode: str, newline: str = None):

def makedirs(path: pathlike):
if hasattr(gfile, "makedirs") and modern_gfile():
return gfile.makedirs(str(path))
if not gfile.exists(str(path)):
return gfile.makedirs(str(path))
# otherwise minimal dependencies are installed and only local files will work
return os.makedirs(path, exist_ok=True)