diff --git a/CHANGELOG.md b/CHANGELOG.md
index ec6157d2bc1ed..c70b280b2e73f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
 - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
 - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
+- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))

 ### Changed

diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 1e399ff2ecfc4..ed299bb8a816e 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -10,6 +10,7 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
+from pytorch_lightning.utilities.io import load as pl_load

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
@@ -52,10 +53,10 @@ def load_from_checkpoint(
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to `__init__` in the checkpoint under `module_arguments`

-        Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`.
+        Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.

         Args:
-            checkpoint_path: Path to checkpoint.
+            checkpoint_path: Path to checkpoint. This can also be a URL.
             args: Any positional args needed to init the model.
             map_location: If your checkpoint saved a GPU model and you now load on CPUs
@@ -131,9 +132,9 @@ def load_from_checkpoint(
             y_hat = pretrained_model(x)
         """
         if map_location is not None:
-            checkpoint = torch.load(checkpoint_path, map_location=map_location)
+            checkpoint = pl_load(checkpoint_path, map_location=map_location)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

         # add the hparams from csv file to checkpoint
         if tags_csv is not None:
@@ -162,7 +163,6 @@ def load_from_checkpoint(

     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
-
         # pass in the values we saved automatically
         if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
             # todo add some back compatibility

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 1cf216eac345d..0fb76acd1a700 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -279,6 +279,7 @@ def __init__(
         truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

         resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
+            This can be a URL.

         profiler: To profile individual steps during training and assist in
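Note: taken together, the saving.py and trainer.py hunks above mean that `load_from_checkpoint` and `Trainer(resume_from_checkpoint=...)` now accept a URL as well as a filesystem path. A minimal sketch of the user-facing behaviour, assuming a reachable checkpoint URL (the URL below is a placeholder, and EvalModelTemplate is simply the template model the tests in this patch already use):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    # placeholder URL; any HTTP(S) location serving a .ckpt file behaves the same
    ckpt_url = 'http://127.0.0.1:8000/model.ckpt'

    # checkpoint_path may now be a URL; non-local paths are fetched via
    # torch.hub.load_state_dict_from_url inside the new pl_load helper
    model = EvalModelTemplate.load_from_checkpoint(checkpoint_path=ckpt_url)

    # resume_from_checkpoint may likewise be a URL
    trainer = Trainer(resume_from_checkpoint=ckpt_url)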
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index bd3d312a86882..7c6d429a980b6 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
     LightningDataParallel,
 )
 from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities.io import load as pl_load

 try:
     import torch_xla
@@ -287,7 +288,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         #     checkpoint = torch.load(checkpoint_path)
         # else:
         # load on CPU first
-        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

         # load model state
         model = self.get_model()

diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py
new file mode 100644
index 0000000000000..b2f9ffa788541
--- /dev/null
+++ b/pytorch_lightning/utilities/io.py
@@ -0,0 +1,11 @@
+import torch
+
+from urllib.parse import urlparse
+
+
+def load(path_or_url: str, map_location=None):
+    parsed = urlparse(path_or_url)
+    if parsed.scheme == '':
+        # local file
+        return torch.load(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)

diff --git a/tests/conftest.py b/tests/conftest.py
index 8eb3444ddaaba..216bbe171d018 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,9 @@
-from functools import wraps
+from functools import wraps, partial
+from http.server import SimpleHTTPRequestHandler
+import sys

 import pytest
+import threading
 import torch.multiprocessing as mp


@@ -17,3 +20,38 @@ def pytest_pyfunc_call(pyfuncitem):
         mp.spawn(wraps, (testfunction, testargs))
         return True
+
+
+@pytest.fixture
+def tmpdir_server(tmpdir):
+    if sys.version_info >= (3, 7):
+        Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir))
+        from http.server import ThreadingHTTPServer
+    else:
+        # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6
+        # so we have to hack it like this
+        import os
+
+        class Handler(SimpleHTTPRequestHandler):
+            def translate_path(self, path):
+                # get the path from cwd
+                path = super().translate_path(path)
+                # get the relative path
+                relpath = os.path.relpath(path, os.getcwd())
+                # return the full path from root_dir
+                return os.path.join(str(tmpdir), relpath)
+
+        # ThreadingHTTPServer was added in 3.7, so we need to define it ourselves
+        from socketserver import ThreadingMixIn
+        from http.server import HTTPServer
+
+        class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
+            daemon_threads = True
+
+    with ThreadingHTTPServer(('', 0), Handler) as server:
+        server_thread = threading.Thread(target=server.serve_forever)
+        # Exit the server thread when the main thread terminates
+        server_thread.daemon = True
+        server_thread.start()
+        yield server.server_address
+        server.shutdown()
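Note: the new `pytorch_lightning/utilities/io.py` helper above dispatches on the parsed scheme: an empty scheme is treated as a local file and handed to `torch.load`, while anything else goes through `torch.hub.load_state_dict_from_url`, which caches downloads under torch hub's cache directory (controlled by `$TORCH_HOME`). That caching is why the tests below monkeypatch `TORCH_HOME` to `tmpdir`. A rough usage sketch, with placeholder path and host:

    from pytorch_lightning.utilities.io import load as pl_load

    # local path: urlparse(...).scheme == '', so this is plain torch.load
    checkpoint = pl_load('weights.ckpt', map_location='cpu')

    # URL: fetched with torch.hub.load_state_dict_from_url and cached under $TORCH_HOME
    checkpoint = pl_load('http://127.0.0.1:8000/weights.ckpt', map_location='cpu')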
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index bf4c8198e3db7..caf1632ec3474 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -16,12 +16,16 @@
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
+from pytorch_lightning.utilities.io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate


-def test_no_val_module(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     model = EvalModelTemplate()

@@ -49,15 +53,19 @@ def test_no_val_module(tmpdir):
     # load new model
     hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
     hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
     model_2 = EvalModelTemplate.load_from_checkpoint(
-        checkpoint_path=new_weights_path,
+        checkpoint_path=ckpt_path,
         hparams_file=hparams_path
     )
     model_2.eval()


-def test_no_val_end_module(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     model = EvalModelTemplate()

@@ -82,8 +90,9 @@ def test_no_val_end_module(tmpdir):
     # load new model
     hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
     hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
     model_2 = EvalModelTemplate.load_from_checkpoint(
-        checkpoint_path=new_weights_path,
+        checkpoint_path=ckpt_path,
         hparams_file=hparams_path
     )
     model_2.eval()

@@ -320,8 +329,11 @@ def test_model_freeze_unfreeze():
     model.unfreeze()


-def test_resume_from_checkpoint_epoch_restored(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Verify resuming from checkpoint runs the right number of epochs"""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     hparams = EvalModelTemplate.get_default_hparams()

@@ -373,10 +385,14 @@ def increment_on_load_checkpoint(self, _):
     # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
     checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
+    if url_ckpt:
+        # transform local paths into url checkpoints
+        ip, port = tmpdir_server
+        checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

     for check in checkpoints:
         next_model = _new_model()
-        state = torch.load(check)
+        state = pl_load(check)

         # Resume training
         trainer_options['max_epochs'] = 2
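Note: the `tmpdir_server` fixture yields the server's `(host, port)` pair and serves the contents of `tmpdir` over HTTP, so any checkpoint a test writes into `tmpdir` can be fetched back by URL. A condensed, hypothetical sketch of the pattern the parametrized tests above follow (the test name, file name, and dummy checkpoint contents are illustrative only; the real tests round-trip genuine Trainer checkpoints):

    import os
    import torch

    from pytorch_lightning.utilities.io import load as pl_load


    def test_load_checkpoint_served_over_http(monkeypatch, tmpdir, tmpdir_server):
        # keep torch hub's download cache inside the test's tmpdir
        monkeypatch.setenv('TORCH_HOME', tmpdir)

        # write any torch-serializable object into the served directory
        local_path = os.path.join(str(tmpdir), 'weights.ckpt')
        torch.save({'state_dict': {}}, local_path)

        # build a URL pointing at the throwaway HTTP server
        ip, port = tmpdir_server
        ckpt_url = f'http://{ip}:{port}/{os.path.basename(local_path)}'

        # the URL round-trips through torch.hub.load_state_dict_from_url
        state = pl_load(ckpt_url, map_location='cpu')
        assert 'state_dict' in state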