
Allow loading checkpoints from urls #1667

Merged (18 commits, Jun 11, 2020)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))

### Changed

10 changes: 5 additions & 5 deletions pytorch_lightning/core/saving.py
@@ -10,6 +10,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.io import load as pl_load

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
@@ -52,10 +53,10 @@ def load_from_checkpoint(
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to `__init__` in the checkpoint under `module_arguments`

Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`.
Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.

Args:
checkpoint_path: Path to checkpoint.
checkpoint_path: Path to checkpoint. This can also be a URL.
args: Any positional args needed to init the model.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
@@ -131,9 +132,9 @@ def load_from_checkpoint(
y_hat = pretrained_model(x)
"""
if map_location is not None:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
if tags_csv is not None:
@@ -162,7 +163,6 @@ def load_from_checkpoint(

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):

# pass in the values we saved automatically
if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
# todo add some back compatibility
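To illustrate the user-facing effect of this change, `load_from_checkpoint` now accepts a URL anywhere it previously took a filesystem path. A minimal sketch (the model class and checkpoint locations below are placeholders, not part of this PR):

```python
# Hypothetical LightningModule subclass and checkpoint locations, for illustration only.
model = MyLightningModule.load_from_checkpoint('path/to/weights.ckpt')  # local path, as before
model = MyLightningModule.load_from_checkpoint(                         # URL, new in this PR
    'https://example.com/checkpoints/weights.ckpt',
    map_location='cpu',
)
```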
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -279,6 +279,7 @@ def __init__(
truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
This can be a URL.

profiler: To profile individual steps during training and assist in

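Similarly, `resume_from_checkpoint` can now be given a URL. A sketch with a placeholder URL:

```python
from pytorch_lightning import Trainer

# The URL below is a placeholder; any reachable checkpoint URL should work.
trainer = Trainer(resume_from_checkpoint='https://example.com/checkpoints/last.ckpt')
trainer.fit(model)  # `model` is any LightningModule instance
```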
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, parsing
from pytorch_lightning.utilities.io import load as pl_load

try:
import torch_xla
@@ -287,7 +288,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# load model state
model = self.get_model()
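As context for the `map_location` argument used in `restore`: passing `lambda storage, loc: storage` tells PyTorch to leave every tensor on CPU during deserialization, regardless of the device it was saved from; devices are restored afterwards. Assuming standard `torch.load` semantics, the following two calls are equivalent (the file name is a placeholder):

```python
import torch

checkpoint = torch.load('weights.ckpt', map_location=lambda storage, loc: storage)
checkpoint = torch.load('weights.ckpt', map_location='cpu')  # same effect, spelled differently
```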
11 changes: 11 additions & 0 deletions pytorch_lightning/utilities/io.py
@@ -0,0 +1,11 @@
import torch

from urllib.parse import urlparse


def load(path_or_url: str, map_location=None):
Review comment (Member):

What is the type of `map_location`?

Reply (Contributor, author):

We can't really specify this because it can be a lot of different things: #1505.

Any suggestions are welcome.

parsed = urlparse(path_or_url)
if parsed.scheme == '':
# local file
return torch.load(path_or_url, map_location=map_location)
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
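For reference, the new helper dispatches on the URL scheme: local paths keep going through `torch.load`, while URLs are downloaded (and cached under `$TORCH_HOME`) by `torch.hub.load_state_dict_from_url`. Regarding the review question above, `map_location` is forwarded untouched, so anything `torch.load` accepts (a string such as `'cpu'`, a `torch.device`, a dict, or a callable) should work. A small usage sketch with placeholder paths:

```python
from pytorch_lightning.utilities.io import load as pl_load

# Local file: handled by torch.load
ckpt = pl_load('checkpoints/epoch=4.ckpt', map_location='cpu')

# URL: fetched via torch.hub, then loaded with the same map_location handling
ckpt = pl_load('https://example.com/checkpoints/epoch=4.ckpt',
               map_location=lambda storage, loc: storage)
```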
40 changes: 39 additions & 1 deletion tests/conftest.py
@@ -1,6 +1,9 @@
from functools import wraps
from functools import wraps, partial
from http.server import SimpleHTTPRequestHandler

import sys
import pytest
import threading
import torch.multiprocessing as mp


@@ -17,3 +20,38 @@ def pytest_pyfunc_call(pyfuncitem):

mp.spawn(wraps, (testfunction, testargs))
return True


@pytest.fixture
def tmpdir_server(tmpdir):
if sys.version_info >= (3, 7):
Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir))
from http.server import ThreadingHTTPServer
else:
# unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6
# so we have to hack it like this
import os

class Handler(SimpleHTTPRequestHandler):
def translate_path(self, path):
# get the path from cwd
path = super().translate_path(path)
# get the relative path
relpath = os.path.relpath(path, os.getcwd())
# return the full path from root_dir
return os.path.join(str(tmpdir), relpath)

# ThreadingHTTPServer was added in 3.7, so we need to define it ourselves
from socketserver import ThreadingMixIn
from http.server import HTTPServer

class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
daemon_threads = True

with ThreadingHTTPServer(('', 0), Handler) as server:
server_thread = threading.Thread(target=server.serve_forever)
# Exit the server thread when the main thread terminates
server_thread.daemon = True
server_thread.start()
yield server.server_address
server.shutdown()
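To show how the fixture is meant to be consumed (a hypothetical test body, not part of this PR): `tmpdir_server` yields the `(host, port)` address of an HTTP server rooted at `tmpdir`, so a test can write a checkpoint into `tmpdir` and read it back over HTTP:

```python
# Hypothetical usage of the tmpdir_server fixture defined above.
def test_serves_checkpoint_over_http(tmpdir, tmpdir_server):
    ckpt_path = tmpdir / 'weights.ckpt'
    # ... save a checkpoint to ckpt_path ...
    ip, port = tmpdir_server
    url = f'http://{ip}:{port}/weights.ckpt'
    # `url` can now be passed anywhere a checkpoint path is accepted
```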
28 changes: 22 additions & 6 deletions tests/trainer/test_trainer.py
@@ -16,12 +16,16 @@
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate


def test_no_val_module(tmpdir):
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
# set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
monkeypatch.setenv('TORCH_HOME', tmpdir)

model = EvalModelTemplate()

@@ -49,15 +53,19 @@ def test_no_val_module(tmpdir):
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
checkpoint_path=ckpt_path,
hparams_file=hparams_path
)
model_2.eval()


def test_no_val_end_module(tmpdir):
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
# set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
monkeypatch.setenv('TORCH_HOME', tmpdir)

model = EvalModelTemplate()

@@ -82,8 +90,9 @@ def test_no_val_end_module(tmpdir):
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
checkpoint_path=ckpt_path,
hparams_file=hparams_path
)
model_2.eval()
@@ -320,8 +329,11 @@ def test_model_freeze_unfreeze():
model.unfreeze()


def test_resume_from_checkpoint_epoch_restored(tmpdir):
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Verify resuming from checkpoint runs the right number of epochs"""
# set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
monkeypatch.setenv('TORCH_HOME', tmpdir)

hparams = EvalModelTemplate.get_default_hparams()

@@ -373,10 +385,14 @@ def increment_on_load_checkpoint(self, _):

# Other checkpoints can be uncommented if/when resuming mid-epoch is supported
checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
if url_ckpt:
# transform local paths into url checkpoints
ip, port = tmpdir_server
checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

for check in checkpoints:
next_model = _new_model()
state = torch.load(check)
state = pl_load(check)

# Resume training
trainer_options['max_epochs'] = 2