From 22d9464e56fbf4a0a65b29171873cf9cca7adb92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 15 Jun 2020 23:04:32 +0200
Subject: [PATCH] HenryJia: auto-move data decorator (#1905)

* First attempt at auto-moving data for inference
* Correct my copypaste errors
* Correct for if device is CPU
* Get rid of the WIP code I accidentally added
* Add tests
* Make tests more foolproof
* Make sure we stick with pep8 formatting
* Clarify docs a little
* Apply suggestions from code review
* Get everything working again hopefully
* refactor and added hook variant a variant b add test revert rename add changelog docs
* move changelog entry to top
* Move data transfer to utilities
* Add back in warnings for autotransfer
* Get rid of the test code I ended up accidentally commiting again
* Add docs any changelog
* Correct PR number in Changelog
* Correct changelog
* Update data.py
* Update test_cpu.py
* make a decorator
* type hint
* changelog
* changelog
* remove old function
* import
* test for decorator
* fix test
* remove old test
* doctest
* apply decorator directly
* convert doctest to code block
* prevent side effects in tests
* fix merge
* update forward docs
* update docs
* added docs in section "deployment / prediction"
* update changelog

Co-authored-by: Hengjian Jia
Co-authored-by: Jirka Borovec
Co-authored-by: William Falcon
---
 CHANGELOG.md                          |  1 +
 pytorch_lightning/core/decorators.py  | 52 +++++++++++++++++++++++++++
 pytorch_lightning/core/lightning.py   |  3 ++
 pytorch_lightning/trainer/__init__.py |  5 +++
 tests/core/test_decorators.py         | 33 +++++++++++++++++
 5 files changed, 94 insertions(+)
 create mode 100644 tests/core/test_decorators.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index efae799ae4262..f5fa80353cab8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
 - Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
 - Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
+- Added a decorator `auto_move_data` that moves data to the correct device when using the LightningModule for inference ([#1905](https://github.com/PyTorchLightning/pytorch-lightning/pull/1905))
 
 ### Changed
 
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index 3979a4fc6f7ee..8f2721201a124 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -1,3 +1,9 @@
+from functools import wraps
+from typing import Callable
+
+import torch
+
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
 
 
@@ -12,3 +18,49 @@ def data_loader(fn):
     def inner_fx(self):
         return fn(self)
     return inner_fx
+
+
+def auto_move_data(fn: Callable) -> Callable:
+    """
+    Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
+    input arguments should be moved automatically to the correct device.
+    It has no effect if applied to a method of an object that is not an instance of
+    :class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__``
+    or ``forward``.
+
+    Args:
+        fn: A LightningModule method for which the arguments should be moved to the device
+            the parameters are on.
+
+    Example:
+
+        .. code-block:: python
+
+            # directly in the source code
+            class LitModel(LightningModule):
+
+                @auto_move_data
+                def forward(self, x):
+                    return x
+
+            # or outside
+            LitModel.forward = auto_move_data(LitModel.forward)
+
+            model = LitModel()
+            model = model.to('cuda')
+            model(torch.zeros(1, 3))
+
+            # input gets moved to device
+            # tensor([[0., 0., 0.]], device='cuda:0')
+
+    """
+    @wraps(fn)
+    def auto_transfer_args(self, *args, **kwargs):
+        if not isinstance(self, LightningModule):
+            return fn(self, *args, **kwargs)
+
+        args = self.transfer_batch_to_device(args, self.device)
+        kwargs = self.transfer_batch_to_device(kwargs, self.device)
+        return fn(self, *args, **kwargs)
+
+    return auto_transfer_args
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index af7527f550d0e..75b08ddc7ac1b 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -112,6 +112,9 @@ def forward(self, *args, **kwargs):
         This makes it easy to write a complex system for training with the outputs
         you'd want in a prediction setting.
 
+        You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful
+        when using the module outside Lightning in a production setting.
+
         Args:
             *args: Whatever you decide to pass into the forward method.
             **kwargs: Keyword arguments are also possible.
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index ad59df49262b8..357613b0685c9 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -101,6 +101,11 @@ def forward(self, x):
     out = pretrained_model(x)
     api_write({'response': out})
 
+
+You may wish to run the model on a variety of devices. Instead of moving the data
+manually to the correct device, decorate the forward method (or any other method you use for inference)
+with :func:`~pytorch_lightning.core.decorators.auto_move_data` and Lightning will take care of the rest.
+
 ------------
 
 Reproducibility
diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py
new file mode 100644
index 0000000000000..0f35a1630e1d9
--- /dev/null
+++ b/tests/core/test_decorators.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+
+from tests.base import EvalModelTemplate
+from pytorch_lightning.core.decorators import auto_move_data
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.parametrize(['src_device', 'dest_device'], [
+    pytest.param(torch.device('cpu'), torch.device('cpu')),
+    pytest.param(torch.device('cpu', 0), torch.device('cuda', 0)),
+    pytest.param(torch.device('cuda', 0), torch.device('cpu')),
+    pytest.param(torch.device('cuda', 0), torch.device('cuda', 0)),
+])
+def test_auto_move_data(src_device, dest_device):
+    """ Test that the decorator moves the data to the device the model is on. """
+
+    class CurrentModel(EvalModelTemplate):
+        pass
+
+    # apply the decorator
+    CurrentModel.forward = auto_move_data(CurrentModel.forward)
+
+    model = CurrentModel()
+    model = model.to(dest_device)
+    model.prepare_data()
+    loader = model.train_dataloader()
+    x, y, = next(iter(loader))
+    x = x.flatten(1)
+
+    # test that data on source device gets moved to destination device
+    x = x.to(src_device)
+    assert model(x).device == dest_device, "Automoving data to same device as model failed"
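
Below is a minimal, self-contained sketch of how the decorator introduced by this patch can be used for inference outside the ``Trainer``. The ``LitClassifier`` module, its layer size, and the random input are illustrative only and are not part of this change; only ``auto_move_data`` and the ``LightningModule.device`` property come from the library.

.. code-block:: python

    import torch

    from pytorch_lightning import LightningModule
    from pytorch_lightning.core.decorators import auto_move_data


    class LitClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        @auto_move_data
        def forward(self, x):
            # x has already been moved to self.device by the decorator
            return self.layer(x)


    model = LitClassifier()
    if torch.cuda.is_available():
        model = model.to('cuda')

    # the input is created on the CPU; the decorator transfers it to model.device
    # before forward runs, so the call works on CPU-only and GPU machines alike
    out = model(torch.rand(2, 28 * 28))
    print(out.device)  # same device as the model parameters

Without the decorator, the call on a GPU machine would raise a device-mismatch error because the input tensor would still live on the CPU.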