HenryJia: auto-move data decorator #1905

Merged
39 commits merged on Jun 15, 2020
Commits (39)
29f3a2a
refactor and added hook
awaelchli May 14, 2020
6727067
move changelog entry to top
awaelchli May 20, 2020
31d60de
First attempt at auto-moving data for inference
HenryJia Apr 19, 2020
678e580
Correct my copypaste errors
HenryJia Apr 19, 2020
e5e60a4
Correct for if device is CPU
HenryJia Apr 19, 2020
e4dfbaf
Get rid of the WIP code I accidentally added
HenryJia Apr 19, 2020
0ba04ba
Add tests
HenryJia Apr 19, 2020
f50d1ea
Make tests more foolproof
HenryJia Apr 19, 2020
d7e64f4
Make sure we stick with pep8 formatting
HenryJia Apr 19, 2020
ba0ddae
Clarify docs a little
HenryJia Apr 19, 2020
cbfd614
Apply suggestions from code review
Borda Apr 19, 2020
d2ebd27
Get everything working again hopefully
HenryJia Apr 19, 2020
18267b8
Move data transfer to utilities
HenryJia Apr 20, 2020
a539cc7
Add back in warnings for autotransfer
HenryJia Apr 20, 2020
a6500db
Get rid of the test code I ended up accidentally committing again
HenryJia Apr 20, 2020
531124a
Add docs and changelog
HenryJia Apr 20, 2020
8996c4c
Correct PR number in Changelog
HenryJia Apr 20, 2020
2c50252
Correct changelog
HenryJia Apr 22, 2020
8159ed8
Update data.py
williamFalcon Apr 30, 2020
b485cd8
Update test_cpu.py
williamFalcon May 17, 2020
f8218e0
make a decorator
awaelchli May 20, 2020
d01934c
type hint
awaelchli May 20, 2020
16d5460
changelog
awaelchli May 20, 2020
f19e6e3
changelog
awaelchli May 20, 2020
a035785
remove old function
awaelchli May 20, 2020
2d65d80
import
awaelchli May 21, 2020
6e6da53
test for decorator
awaelchli May 21, 2020
454d7e2
fix test
awaelchli May 21, 2020
d8e1bd7
remove old test
awaelchli May 21, 2020
ecc1d6e
doctest
awaelchli Jun 6, 2020
0dd8985
apply decorator directly
awaelchli Jun 7, 2020
a1ddb86
convert doctest to code block
awaelchli Jun 8, 2020
cc4cb6b
prevent side effects in tests
awaelchli Jun 8, 2020
42fc1b8
fix merge
awaelchli Jun 8, 2020
587a2b2
update forward docs
awaelchli Jun 8, 2020
649c6c9
update docs
awaelchli Jun 8, 2020
8099b0a
added docs in section "deployment / prediction"
awaelchli Jun 15, 2020
6b06b4d
Merge branch 'master' into feat-auto-move-data
awaelchli Jun 15, 2020
f5f9e75
update changelog
awaelchli Jun 15, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
- Added a decorator `auto_move_data` that moves data to the correct device when using the LightningModule for inference ([#1905](https://github.com/PyTorchLightning/pytorch-lightning/pull/1905))

### Changed

52 changes: 52 additions & 0 deletions pytorch_lightning/core/decorators.py
@@ -1,3 +1,9 @@
from functools import wraps
from typing import Callable

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


@@ -12,3 +18,49 @@ def data_loader(fn):
def inner_fx(self):
return fn(self)
return inner_fx


def auto_move_data(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
input arguments should be moved automatically to the correct device.
It has no effect if applied to a method of an object that is not an instance of
:class:`~pytorch_lightning.core.lightning.LightningModule`. It is typically applied to ``__call__``
or ``forward``.

Args:
fn: A LightningModule method for which the arguments should be moved to the device
the parameters are on.

Example:

.. code-block:: python

# directly in the source code
class LitModel(LightningModule):

@auto_move_data
def forward(self, x):
return x

# or outside
LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))

# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')

"""
@wraps(fn)
def auto_transfer_args(self, *args, **kwargs):
if not isinstance(self, LightningModule):
return fn(self, *args, **kwargs)

args = self.transfer_batch_to_device(args, self.device)
kwargs = self.transfer_batch_to_device(kwargs, self.device)
return fn(self, *args, **kwargs)

return auto_transfer_args
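
A note on the implementation above: because `auto_transfer_args` routes both `args` and `kwargs` through `transfer_batch_to_device`, tensors passed as keyword arguments (or nested inside tuples and dicts) are moved as well. Below is a minimal sketch of that behaviour, assuming a CUDA device is available; `LitModel` here is a hypothetical module for illustration, not part of this PR.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    @auto_move_data
    def forward(self, x, mask=None):
        # x and mask both arrive on self.device, wherever the caller created them
        if mask is not None:
            x = x * mask
        return self.layer(x)


model = LitModel().to('cuda')
# both tensors are created on the CPU at the call site ...
out = model(torch.zeros(1, 3), mask=torch.ones(1, 3))
# ... and both are moved to the GPU before forward runs
assert out.device.type == 'cuda'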
3 changes: 3 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -112,6 +112,9 @@ def forward(self, *args, **kwargs):
This makes it easy to write a complex system for training with the outputs
you'd want in a prediction setting.

You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful
when using the module outside Lightning in a production setting.

Args:
*args: Whatever you decide to pass into the forward method.
**kwargs: Keyword arguments are also possible.
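
The production note added above can be made concrete with a short sketch. This is illustrative only: the checkpoint path and the `LitMNIST` definition are hypothetical, while `load_from_checkpoint` is the standard LightningModule loader.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data


class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    @auto_move_data
    def forward(self, x):
        return torch.relu(self.l1(x))


# hypothetical checkpoint from an earlier training run
model = LitMNIST.load_from_checkpoint('path/to/checkpoint.ckpt')
model.eval()
model.to('cuda')

# inputs can stay on the CPU; the decorated forward moves them
with torch.no_grad():
    preds = model(torch.zeros(4, 28 * 28))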
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -101,6 +101,11 @@ def forward(self, x):
out = pretrained_model(x)
api_write({'response': out})


You may wish to run the model on a variety of devices. Instead of moving the data
manually to the correct device, decorate the forward method (or any other method you use for inference)
with :func:`~pytorch_lightning.core.decorators.auto_move_data` and Lightning will take care of the rest.

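A sketch of that multi-device usage, assuming a hypothetical `TinyModel` whose `forward` is decorated as described; it falls back to CPU-only when no GPU is present.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    @auto_move_data
    def forward(self, x):
        return self.layer(x)


model = TinyModel()
for device in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
    model = model.to(device)
    # the caller never moves the input; the decorator re-targets it per device
    out = model(torch.zeros(1, 3))
    assert out.device.type == device
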
------------

Reproducibility
33 changes: 33 additions & 0 deletions tests/core/test_decorators.py
@@ -0,0 +1,33 @@
import pytest
import torch

from tests.base import EvalModelTemplate
from pytorch_lightning.core.decorators import auto_move_data


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize(['src_device', 'dest_device'], [
pytest.param(torch.device('cpu'), torch.device('cpu')),
pytest.param(torch.device('cpu', 0), torch.device('cuda', 0)),
pytest.param(torch.device('cuda', 0), torch.device('cpu')),
pytest.param(torch.device('cuda', 0), torch.device('cuda', 0)),
])
def test_auto_move_data(src_device, dest_device):
""" Test that the decorator moves the data to the device the model is on. """

class CurrentModel(EvalModelTemplate):
pass

# apply the decorator
CurrentModel.forward = auto_move_data(CurrentModel.forward)

model = CurrentModel()
model = model.to(dest_device)
model.prepare_data()
loader = model.train_dataloader()
x, y = next(iter(loader))
x = x.flatten(1)

# test that data on source device gets moved to destination device
x = x.to(src_device)
assert model(x).device == dest_device, "Automoving data to same device as model failed"