Allow loading checkpoints from urls (#1667)
* allow loading checkpoints from urls

* tmpdir_server fixture

* test cases for loading checkpoints from url

* dir => root_dir

* default map_location to None

* test case for resume_from_checkpoint

* changelog

* doc update

* monkeypatch TORCH_HOME to avoid caching

* Use a threading server with random ports so that it is easier to clean up

* test fixes

* pep8 fix

* ThreadingHTTPServer support in 3.6

* pep8 fix

* fix changelog

* separate tests for urls

* typo

Co-authored-by: Peter Yu <2057325+yukw777@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
yukw777 and Borda committed Jun 11, 2020
1 parent bd49b07 commit 06cd849
Showing 7 changed files with 81 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))

### Changed

10 changes: 5 additions & 5 deletions pytorch_lightning/core/saving.py
Expand Up @@ -10,6 +10,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.io import load as pl_load

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
@@ -52,10 +53,10 @@ def load_from_checkpoint(
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to `__init__` in the checkpoint under `module_arguments`
- Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`.
+ Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.
Args:
- checkpoint_path: Path to checkpoint.
+ checkpoint_path: Path to checkpoint. This can also be a URL.
args: Any positional args needed to init the model.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
@@ -131,9 +132,9 @@ def load_from_checkpoint(
y_hat = pretrained_model(x)
"""
if map_location is not None:
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
- checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
if tags_csv is not None:
@@ -162,7 +163,6 @@

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):

# pass in the values we saved automatically
if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
# todo add some back compatibility
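With this change, `load_from_checkpoint` accepts a URL wherever a local path was accepted before. A minimal usage sketch, not taken from this commit; `MyModel` and the URL are placeholders:

import torch

# assumption: MyModel is your LightningModule subclass and the URL hosts one of its checkpoints
model = MyModel.load_from_checkpoint(
    'https://example.com/my_model.ckpt',
    map_location=torch.device('cpu'),  # optional; if omitted, storages are loaded onto CPU
)
model.eval()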
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -279,6 +279,7 @@ def __init__(
truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
+ This can be a URL.
profiler: To profile individual steps during training and assist in
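The same applies to resuming: `resume_from_checkpoint` may now point at a remote checkpoint. A hedged sketch, where the URL is a placeholder and `model` is assumed to be an already-constructed LightningModule:

from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=10,
    resume_from_checkpoint='https://example.com/last.ckpt',  # placeholder URL
)
trainer.fit(model)  # `model` assumed to be defined elsewhere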
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, parsing
from pytorch_lightning.utilities.io import load as pl_load

try:
import torch_xla
@@ -287,7 +288,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
- checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# load model state
model = self.get_model()
11 changes: 11 additions & 0 deletions pytorch_lightning/utilities/io.py
@@ -0,0 +1,11 @@
import torch

from urllib.parse import urlparse


def load(path_or_url: str, map_location=None):
    parsed = urlparse(path_or_url)
    if parsed.scheme == '':
        # local file
        return torch.load(path_or_url, map_location=map_location)
    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
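The helper dispatches on the URL scheme: a path without a scheme goes through `torch.load`, anything else through `torch.hub.load_state_dict_from_url`, which downloads the file and caches it under torch hub's cache directory (controlled by `$TORCH_HOME`, hence the monkeypatching in the tests below). A small sketch with placeholder paths and a placeholder URL:

from pytorch_lightning.utilities.io import load as pl_load

# local path: behaves like torch.load
checkpoint = pl_load('/tmp/weights.ckpt', map_location='cpu')

# URL: downloaded on first use, then read from torch hub's cache
checkpoint = pl_load('https://example.com/weights.ckpt', map_location='cpu')
print(checkpoint.keys())  # e.g. dict_keys(['epoch', 'global_step', 'state_dict', ...])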
40 changes: 39 additions & 1 deletion tests/conftest.py
@@ -1,6 +1,9 @@
- from functools import wraps
+ from functools import wraps, partial
+ from http.server import SimpleHTTPRequestHandler

+ import sys
import pytest
+ import threading
import torch.multiprocessing as mp


@@ -17,3 +20,38 @@ def pytest_pyfunc_call(pyfuncitem):

mp.spawn(wraps, (testfunction, testargs))
return True


@pytest.fixture
def tmpdir_server(tmpdir):
    if sys.version_info >= (3, 7):
        Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir))
        from http.server import ThreadingHTTPServer
    else:
        # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6
        # so we have to hack it like this
        import os

        class Handler(SimpleHTTPRequestHandler):
            def translate_path(self, path):
                # get the path from cwd
                path = super().translate_path(path)
                # get the relative path
                relpath = os.path.relpath(path, os.getcwd())
                # return the full path from root_dir
                return os.path.join(str(tmpdir), relpath)

        # ThreadingHTTPServer was added in 3.7, so we need to define it ourselves
        from socketserver import ThreadingMixIn
        from http.server import HTTPServer

        class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
            daemon_threads = True

    with ThreadingHTTPServer(('', 0), Handler) as server:
        server_thread = threading.Thread(target=server.serve_forever)
        # Exit the server thread when the main thread terminates
        server_thread.daemon = True
        server_thread.start()
        yield server.server_address
        server.shutdown()
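The fixture serves the contents of `tmpdir` over HTTP on an OS-assigned port and yields the server's `(host, port)` pair. A hypothetical test, not part of this commit, could use it like this:

import os
import urllib.request

import torch


def test_tmpdir_server_serves_files(tmpdir, tmpdir_server):
    # write a checkpoint-like file into the served directory
    torch.save({'dummy': 42}, os.path.join(str(tmpdir), 'weights.ckpt'))

    host, port = tmpdir_server
    with urllib.request.urlopen(f'http://{host}:{port}/weights.ckpt') as response:
        assert response.status == 200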
28 changes: 22 additions & 6 deletions tests/trainer/test_trainer.py
@@ -16,12 +16,16 @@
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate


- def test_no_val_module(tmpdir):
+ @pytest.mark.parametrize('url_ckpt', [True, False])
+ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
+ # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+ monkeypatch.setenv('TORCH_HOME', tmpdir)

model = EvalModelTemplate()

@@ -49,15 +53,19 @@ def test_no_val_module(tmpdir):
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+ ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
model_2 = EvalModelTemplate.load_from_checkpoint(
- checkpoint_path=new_weights_path,
+ checkpoint_path=ckpt_path,
hparams_file=hparams_path
)
model_2.eval()


- def test_no_val_end_module(tmpdir):
+ @pytest.mark.parametrize('url_ckpt', [True, False])
+ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
+ # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+ monkeypatch.setenv('TORCH_HOME', tmpdir)

model = EvalModelTemplate()

@@ -82,8 +90,9 @@ def test_no_val_end_module(tmpdir):
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+ ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
model_2 = EvalModelTemplate.load_from_checkpoint(
- checkpoint_path=new_weights_path,
+ checkpoint_path=ckpt_path,
hparams_file=hparams_path
)
model_2.eval()
@@ -320,8 +329,11 @@ def test_model_freeze_unfreeze():
model.unfreeze()


- def test_resume_from_checkpoint_epoch_restored(tmpdir):
+ @pytest.mark.parametrize('url_ckpt', [True, False])
+ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Verify resuming from checkpoint runs the right number of epochs"""
+ # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+ monkeypatch.setenv('TORCH_HOME', tmpdir)

hparams = EvalModelTemplate.get_default_hparams()

@@ -373,10 +385,14 @@ def increment_on_load_checkpoint(self, _):

# Other checkpoints can be uncommented if/when resuming mid-epoch is supported
checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
+ if url_ckpt:
+     # transform local paths into url checkpoints
+     ip, port = tmpdir_server
+     checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

for check in checkpoints:
next_model = _new_model()
- state = torch.load(check)
+ state = pl_load(check)

# Resume training
trainer_options['max_epochs'] = 2