From 126e0e78ff99161c2e76e5c8eb0c3850a32eaf35 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Wed, 29 Apr 2020 14:11:16 -0400 Subject: [PATCH 01/18] allow loading checkpoints from urls --- pytorch_lightning/core/saving.py | 5 +++-- pytorch_lightning/trainer/training_io.py | 3 ++- pytorch_lightning/utilities/io.py | 11 +++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 pytorch_lightning/utilities/io.py diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 1e399ff2ecfc4..3fcf982457bd3 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -10,6 +10,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.utilities import rank_zero_warn, AttributeDict +from pytorch_lightning.utilities.io import load as pl_load PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace) @@ -131,9 +132,9 @@ def load_from_checkpoint( y_hat = pretrained_model(x) """ if map_location is not None: - checkpoint = torch.load(checkpoint_path, map_location=map_location) + checkpoint = pl_load(checkpoint_path, map_location=map_location) else: - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # add the hparams from csv file to checkpoint if tags_csv is not None: diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index bd3d312a86882..7c6d429a980b6 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -102,6 +102,7 @@ LightningDataParallel, ) from pytorch_lightning.utilities import rank_zero_warn, parsing +from pytorch_lightning.utilities.io import load as pl_load try: import torch_xla @@ -287,7 +288,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool): # checkpoint = torch.load(checkpoint_path) # else: # load on CPU first - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # load model state model = self.get_model() diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py new file mode 100644 index 0000000000000..f9d5112505cc4 --- /dev/null +++ b/pytorch_lightning/utilities/io.py @@ -0,0 +1,11 @@ +import torch + +from urllib.parse import urlparse + + +def load(path_or_url: str, map_location): + parsed = urlparse(path_or_url) + if parsed.scheme == '': + # local file + return torch.load(path_or_url, map_location=map_location) + return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) From 08fb93e5182ab98933f8be6ff32c119f5581982e Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:14:56 -0400 Subject: [PATCH 02/18] tmpdir_server fixture --- tests/conftest.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8eb3444ddaaba..935b206b1da6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ -from functools import wraps +from functools import wraps, partial +from http.server import SimpleHTTPRequestHandler, HTTPServer +import sys import pytest import torch.multiprocessing as mp @@ -17,3 +19,32 @@ def pytest_pyfunc_call(pyfuncitem): mp.spawn(wraps, (testfunction, testargs)) return True + + +def 
run_file_server(dir): + if sys.version_info >= (3, 7): + Handler = partial(SimpleHTTPRequestHandler, directory=dir) + else: + # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6 + # so we have to hack it like this + import os + + class Handler(SimpleHTTPRequestHandler): + def translate_path(self, path): + # get the path from cwd + path = super().translate_path(path) + # get the relative path + relpath = os.path.relpath(path, os.getcwd()) + # return the full path from dir + return os.path.join(dir, relpath) + + with HTTPServer(('', 8000), Handler) as httpd: + httpd.serve_forever() + + +@pytest.fixture +def tmpdir_server(tmpdir): + p = mp.Process(target=run_file_server, args=(str(tmpdir),)) + p.start() + yield + p.terminate() From 50cb6d9bf3aeea4eb9329946675d49d6d99b7433 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:18:44 -0400 Subject: [PATCH 03/18] test cases for loading checkpoints from url --- tests/trainer/test_trainer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bf4c8198e3db7..cedef1bfd68cd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -20,7 +20,7 @@ from tests.base import EvalModelTemplate -def test_no_val_module(tmpdir): +def test_no_val_module(tmpdir, tmpdir_server): """Tests use case where trainer saves the model, and user loads it from tags independently.""" model = EvalModelTemplate() @@ -55,8 +55,15 @@ def test_no_val_module(tmpdir): ) model_2.eval() + # load new model from url + model_3 = EvalModelTemplate.load_from_checkpoint( + checkpoint_path='http://localhost:8000/save_test.ckpt', + hparams_file=hparams_path + ) + model_3.eval() + -def test_no_val_end_module(tmpdir): +def test_no_val_end_module(tmpdir, tmpdir_server): """Tests use case where trainer saves the model, and user loads it from tags independently.""" model = EvalModelTemplate() @@ -88,6 +95,13 @@ def test_no_val_end_module(tmpdir): ) model_2.eval() + # load new model from url + model_3 = EvalModelTemplate.load_from_checkpoint( + checkpoint_path='http://localhost:8000/save_test.ckpt', + hparams_file=hparams_path + ) + model_3.eval() + def test_gradient_accumulation_scheduling(tmpdir): """ From 20166d95bcf3b296127d5ab7a3fc2e311ad0d642 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:20:43 -0400 Subject: [PATCH 04/18] dir => root_dir --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 935b206b1da6a..40cf42ce0b1c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,9 +21,9 @@ def pytest_pyfunc_call(pyfuncitem): return True -def run_file_server(dir): +def run_file_server(root_dir): if sys.version_info >= (3, 7): - Handler = partial(SimpleHTTPRequestHandler, directory=dir) + Handler = partial(SimpleHTTPRequestHandler, directory=root_dir) else: # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6 # so we have to hack it like this @@ -35,8 +35,8 @@ def translate_path(self, path): path = super().translate_path(path) # get the relative path relpath = os.path.relpath(path, os.getcwd()) - # return the full path from dir - return os.path.join(dir, relpath) + # return the full path from root_dir + return os.path.join(root_dir, relpath) with HTTPServer(('', 8000), Handler) as httpd: 
httpd.serve_forever() From db7126355978ff0bd6da569a6de719010d638952 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:31:22 -0400 Subject: [PATCH 05/18] default map_location to None --- pytorch_lightning/utilities/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py index f9d5112505cc4..b2f9ffa788541 100644 --- a/pytorch_lightning/utilities/io.py +++ b/pytorch_lightning/utilities/io.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse -def load(path_or_url: str, map_location): +def load(path_or_url: str, map_location=None): parsed = urlparse(path_or_url) if parsed.scheme == '': # local file From 03844be8ab5eba892a36256571bf3f7c25327e9c Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:31:52 -0400 Subject: [PATCH 06/18] test case for resume_from_checkpoint --- tests/trainer/test_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cedef1bfd68cd..9d31211438304 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -16,6 +16,7 @@ from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.trainer.logging import TrainerLoggingMixin +from pytorch_lightning.utilities.io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -334,7 +335,7 @@ def test_model_freeze_unfreeze(): model.unfreeze() -def test_resume_from_checkpoint_epoch_restored(tmpdir): +def test_resume_from_checkpoint_epoch_restored(tmpdir, tmpdir_server): """Verify resuming from checkpoint runs the right number of epochs""" hparams = EvalModelTemplate.get_default_hparams() @@ -387,10 +388,12 @@ def increment_on_load_checkpoint(self, _): # Other checkpoints can be uncommented if/when resuming mid-epoch is supported checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) + # add some url checkpoints + checkpoints += ['http://localhost:8000/' + os.path.basename(check) for check in checkpoints] for check in checkpoints: next_model = _new_model() - state = torch.load(check) + state = pl_load(check) # Resume training trainer_options['max_epochs'] = 2 From 99cf5803e1a77a49c26686ddaa58c019e9c58808 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:33:58 -0400 Subject: [PATCH 07/18] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80070aeb3dde6..2f8eb14a83f10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610)) - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115)) +- Added loading checkpoings from URLs ([#1532](https://github.com/PyTorchLightning/pytorch-lightning/issues/1532)) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) From e10befd4d15e22ee30d464ebfedd926c46a3f5ab Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sat, 30 May 2020 17:39:01 -0400 Subject: [PATCH 08/18] doc update --- pytorch_lightning/core/saving.py | 3 +-- pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 3fcf982457bd3..092424f870d90 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -56,7 +56,7 @@ def load_from_checkpoint( Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`. Args: - checkpoint_path: Path to checkpoint. + checkpoint_path: Path to checkpoint. This can also be a URL. args: Any positional args needed to init the model. map_location: If your checkpoint saved a GPU model and you now load on CPUs @@ -163,7 +163,6 @@ def load_from_checkpoint( @classmethod def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs): - # pass in the values we saved automatically if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint: # todo add some back compatibility diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1cf216eac345d..0fb76acd1a700 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -279,6 +279,7 @@ def __init__( truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here. + This can be a URL. 
profiler: To profile individual steps during training and assist in From 2e7d93397dbf48e394f4646902e4359e36a6009f Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 11:31:11 -0400 Subject: [PATCH 09/18] monkeypatch TORCH_HOME to avoid caching --- tests/trainer/test_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9d31211438304..78b6a543779b2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -21,8 +21,10 @@ from tests.base import EvalModelTemplate -def test_no_val_module(tmpdir, tmpdir_server): +def test_no_val_module(monkeypatch, tmpdir, tmpdir_server): """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv('TORCH_HOME', tmpdir) model = EvalModelTemplate() @@ -64,8 +66,10 @@ def test_no_val_module(tmpdir, tmpdir_server): model_3.eval() -def test_no_val_end_module(tmpdir, tmpdir_server): +def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server): """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv('TORCH_HOME', tmpdir) model = EvalModelTemplate() @@ -335,8 +339,10 @@ def test_model_freeze_unfreeze(): model.unfreeze() -def test_resume_from_checkpoint_epoch_restored(tmpdir, tmpdir_server): +def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server): """Verify resuming from checkpoint runs the right number of epochs""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv('TORCH_HOME', tmpdir) hparams = EvalModelTemplate.get_default_hparams() From 09fb74cbc4a25fcba8fd78aa6e216a85e75cb454 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 13:16:08 -0400 Subject: [PATCH 10/18] Use a threading server with random ports so that it is easier to clean up --- tests/conftest.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 40cf42ce0b1c4..627a7837e376e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ from functools import wraps, partial -from http.server import SimpleHTTPRequestHandler, HTTPServer +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer import sys import pytest +import threading import torch.multiprocessing as mp @@ -21,9 +22,10 @@ def pytest_pyfunc_call(pyfuncitem): return True -def run_file_server(root_dir): +@pytest.fixture +def tmpdir_server(tmpdir): if sys.version_info >= (3, 7): - Handler = partial(SimpleHTTPRequestHandler, directory=root_dir) + Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir)) else: # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6 # so we have to hack it like this @@ -36,15 +38,12 @@ def translate_path(self, path): # get the relative path relpath = os.path.relpath(path, os.getcwd()) # return the full path from root_dir - return os.path.join(root_dir, relpath) - - with HTTPServer(('', 8000), Handler) as httpd: - httpd.serve_forever() - - -@pytest.fixture -def tmpdir_server(tmpdir): - p = mp.Process(target=run_file_server, args=(str(tmpdir),)) - p.start() - yield - p.terminate() + return 
os.path.join(str(tmpdir), relpath) + + with ThreadingHTTPServer(('', 0), Handler) as server: + server_thread = threading.Thread(target=server.serve_forever) + # Exit the server thread when the main thread terminates + server_thread.daemon = True + server_thread.start() + yield server.server_address + server.shutdown() From d88340124a22870eaf3529fc8d968fcf072d300b Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 13:16:16 -0400 Subject: [PATCH 11/18] test fixes --- tests/trainer/test_trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 78b6a543779b2..d376823484f76 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -60,7 +60,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server): # load new model from url model_3 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path='http://localhost:8000/save_test.ckpt', + checkpoint_path=f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/save_test.ckpt', hparams_file=hparams_path ) model_3.eval() @@ -102,7 +102,7 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server): # load new model from url model_3 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path='http://localhost:8000/save_test.ckpt', + checkpoint_path=f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/save_test.ckpt', hparams_file=hparams_path ) model_3.eval() @@ -339,7 +339,8 @@ def test_model_freeze_unfreeze(): model.unfreeze() -def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server): +@pytest.mark.parametrize('url_ckpt', [True, False]) +def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv('TORCH_HOME', tmpdir) @@ -394,8 +395,10 @@ def increment_on_load_checkpoint(self, _): # Other checkpoints can be uncommented if/when resuming mid-epoch is supported checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) - # add some url checkpoints - checkpoints += ['http://localhost:8000/' + os.path.basename(check) for check in checkpoints] + if url_ckpt: + # transform local paths into url checkpoints + checkpoints = [f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/' + + os.path.basename(check) for check in checkpoints] for check in checkpoints: next_model = _new_model() From 18d7825cf2f99eec599ed7a5d2ac7fa36dbc3d06 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 13:20:09 -0400 Subject: [PATCH 12/18] pep8 fix --- tests/trainer/test_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d376823484f76..934eb5f8c59cf 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -397,8 +397,9 @@ def increment_on_load_checkpoint(self, _): checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) if url_ckpt: # transform local paths into url checkpoints - checkpoints = [f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/' + - os.path.basename(check) for check in checkpoints] + checkpoints = [ + f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/' + os.path.basename(check) for check in checkpoints + ] for check in checkpoints: next_model = 
_new_model() From b71f69abd52b0f10fd204330f2174b18abd91cee Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 13:26:01 -0400 Subject: [PATCH 13/18] ThreadingHTTPServer support in 3.6 --- tests/conftest.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 627a7837e376e..216bbe171d018 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from functools import wraps, partial -from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from http.server import SimpleHTTPRequestHandler import sys import pytest @@ -26,6 +26,7 @@ def pytest_pyfunc_call(pyfuncitem): def tmpdir_server(tmpdir): if sys.version_info >= (3, 7): Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir)) + from http.server import ThreadingHTTPServer else: # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6 # so we have to hack it like this @@ -40,6 +41,13 @@ def translate_path(self, path): # return the full path from root_dir return os.path.join(str(tmpdir), relpath) + # ThreadingHTTPServer was added in 3.7, so we need to define it ourselves + from socketserver import ThreadingMixIn + from http.server import HTTPServer + + class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): + daemon_threads = True + with ThreadingHTTPServer(('', 0), Handler) as server: server_thread = threading.Thread(target=server.serve_forever) # Exit the server thread when the main thread terminates From 3e4b99d900855cbd9e7f8795026fded3fd93e56d Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Sun, 31 May 2020 13:26:51 -0400 Subject: [PATCH 14/18] pep8 fix --- tests/trainer/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 934eb5f8c59cf..000679f0b4420 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -397,8 +397,9 @@ def increment_on_load_checkpoint(self, _): checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) if url_ckpt: # transform local paths into url checkpoints + ip, port = tmpdir_server checkpoints = [ - f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/' + os.path.basename(check) for check in checkpoints + f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints ] for check in checkpoints: From ddce3bb1315b1b0c13bd06ebe89737cdc355ed1e Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Mon, 8 Jun 2020 10:05:01 -0400 Subject: [PATCH 15/18] fix changelog --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f8eb14a83f10..b0b00fda7ff46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,8 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)) - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610)) - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115)) - -- Added loading checkpoings from URLs ([#1532](https://github.com/PyTorchLightning/pytorch-lightning/issues/1532)) +- Added loading checkpoings from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667)) ### Changed From f25ccdd0a562d343dea695584add01aca0d2a7cf Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Mon, 8 Jun 2020 10:13:07 -0400 Subject: [PATCH 16/18] separate tests for urls --- tests/trainer/test_trainer.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 000679f0b4420..c84d2de237a72 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -21,7 +21,8 @@ from tests.base import EvalModelTemplate -def test_no_val_module(monkeypatch, tmpdir, tmpdir_server): +@pytest.mark.parametrize('url_ckpt', [True, False]) +def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv('TORCH_HOME', tmpdir) @@ -52,21 +53,16 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server): # load new model hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) hparams_path = os.path.join(hparams_path, 'hparams.yaml') + ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=new_weights_path, + checkpoint_path=ckpt_path, hparams_file=hparams_path ) model_2.eval() - # load new model from url - model_3 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/save_test.ckpt', - hparams_file=hparams_path - ) - model_3.eval() - -def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server): +@pytest.mark.parametrize('url_ckpt', [True, False]) +def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv('TORCH_HOME', tmpdir) @@ -94,19 +90,13 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server): # load new model hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) hparams_path = os.path.join(hparams_path, 'hparams.yaml') + ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=new_weights_path, + checkpoint_path=ckpt_path, hparams_file=hparams_path ) model_2.eval() - # load new model from url - model_3 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/save_test.ckpt', - hparams_file=hparams_path - ) - 
model_3.eval() - def test_gradient_accumulation_scheduling(tmpdir): """ From bdd627eba2cf6387507837ba041fd10d2f1a52ba Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 10 Jun 2020 00:52:39 +0200 Subject: [PATCH 17/18] typo Co-authored-by: Peter Yu <2057325+yukw777@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0b00fda7ff46..801ebd6af74b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)) - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610)) - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115)) -- Added loading checkpoings from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667)) +- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667)) ### Changed From ae76d52741dd0546ef010eea02e6018002b7a3c1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 11 Jun 2020 21:52:23 +0200 Subject: [PATCH 18/18] Apply suggestions from code review --- pytorch_lightning/core/saving.py | 2 +- tests/trainer/test_trainer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 092424f870d90..ed299bb8a816e 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -53,7 +53,7 @@ def load_from_checkpoint( Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to `__init__` in the checkpoint under `module_arguments` - Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`. + Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`. Args: checkpoint_path: Path to checkpoint. This can also be a URL. diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c84d2de237a72..caf1632ec3474 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -388,9 +388,7 @@ def increment_on_load_checkpoint(self, _): if url_ckpt: # transform local paths into url checkpoints ip, port = tmpdir_server - checkpoints = [ - f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints - ] + checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints] for check in checkpoints: next_model = _new_model()
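
A minimal usage sketch of what this series enables, assuming all of the patches above are applied. The example URLs are placeholders, and the optional `hparams_file` argument is shown only because the tests above pass it; the calls themselves mirror the code and docstrings introduced in these patches.

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.io import load as pl_load
    from tests.base import EvalModelTemplate  # the test model used throughout this series

    # load_from_checkpoint now accepts a URL as well as a local path (PATCH 01).
    # Remote checkpoints are fetched with torch.hub.load_state_dict_from_url, so
    # $TORCH_HOME determines where they are cached; that is why the tests
    # monkeypatch TORCH_HOME to tmpdir (PATCH 09).
    model = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path='https://example.com/checkpoints/save_test.ckpt',  # placeholder URL
        # hparams_file='/path/to/hparams.yaml',  # optional, as in the tests above
    )
    model.eval()

    # resume_from_checkpoint takes the same kind of value (PATCH 08).
    trainer = Trainer(resume_from_checkpoint='https://example.com/checkpoints/last.ckpt')

    # The underlying helper dispatches on the URL scheme: an empty scheme means a
    # local file and falls through to torch.load; anything else goes through
    # torch.hub.load_state_dict_from_url (PATCH 01; map_location defaults to
    # None since PATCH 05).
    checkpoint = pl_load('https://example.com/checkpoints/last.ckpt', map_location='cpu')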