[blocked by #1756] Add decorator to auto-move data for inference #1526

Closed · wants to merge 18 commits
Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Enable `NeptuneLogger` to work with `distributed_backend=ddp` ([#1753](https://github.com/PyTorchLightning/pytorch-lightning/pull/1753))

- Added automatic GPU data transfer for single-GPU and CPU inference ([#1526](https://github.com/PyTorchLightning/pytorch-lightning/pull/1526))

### Changed

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
33 changes: 32 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -20,7 +20,7 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn

from pytorch_lightning.utilities import transfer_data_to_device
try:
    import torch_xla.core.xla_model as xm
except ImportError:
@@ -94,11 +94,42 @@ def forward(self, x):
        if self.trainer.proc_rank == 0:
            print(*args, **kwargs)

    def __call__(self, *data, **kwargs):
        r"""
        Automatically moves the data to the correct device if possible, then calls
        ``torch.nn.Module.__call__``. Lightning will warn you if it automatically moves any data.

        Args:
            *data: Any positional arguments for ``torch.nn.Module.__call__``. These are typically input data.
            **kwargs: Any keyword arguments for ``torch.nn.Module.__call__``.

        Example:

            .. code-block:: python

                model = model.cuda(0)
                model.prepare_data()
                loader = model.train_dataloader()
                for x, y in loader:
                    output = model(x)  # Lightning will auto-move the data here and warn you about it

        """
        devices = [p.device for p in self.parameters()]
Member:
I'm not sure if looping over all the params every time is a good idea. Can we maybe cache the devices somehow?
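One way to cache them, sketched here purely for illustration (this mixin is not part of the diff): invalidate a cached device list whenever the module is moved, by hooking `nn.Module._apply`, which `.to()`, `.cuda()`, `.cpu()` and the dtype casts all go through.

from torch import nn


class DeviceCacheMixin:
    """Illustrative sketch: cache parameter devices; any move or cast invalidates the cache."""

    _cached_devices = None

    def _apply(self, fn):
        # .to()/.cuda()/.cpu()/.half() all funnel through nn.Module._apply,
        # so this is a single safe point of invalidation
        self._cached_devices = None
        return super()._apply(fn)

    @property
    def devices(self):
        if self._cached_devices is None:
            self._cached_devices = list({p.device for p in self.parameters()})
        return self._cached_devices


class TinyModel(DeviceCacheMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


model = TinyModel()
print(model.devices)  # [device(type='cpu')], computed once and cached until the next move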

Contributor:
@mcarilli are we missing a simpler/cleaner way of doing this?

x = x.cpu()
model.cuda()

# this works in lightning
out = model(x)

# out is cuda tensor now

@mcarilli (Apr 24, 2020):

Maybe I don't understand what you're trying to do here, but it looks like you're only using devices[0], so why collect them all? Also, if the model params do reside on multiple devices, it's hard to predict which device the user actually wants the input data to reside on.

Contributor Author:
So, I only apply automatic data transfer in the simple case where the model resides on a single device. Trying to auto-transfer data when the model is spread across multiple devices is very non-trivial and heavily dependent on the model structure.

        # all parameters must be on the same device to auto-move the data;
        # otherwise we just do what nn.Module does normally
        if len(set(devices)) == 1:
            device = devices[0]
            data = transfer_data_to_device(data, device.type, device.index, warn_on_transfer=True)
            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
        return super(LightningModule, self).__call__(*data, **kwargs)
Member:
@Borda and I discussed this and we both agree that we shouldn't do this in the Module (at least not by default). In our opinion, you should always be able to use a LightningModule as a plain nn.Module.

What I propose is the following:

We change this part to a decorator that can be added to forward and is applied automatically by the trainer (sorry for that coming so late).

In that case you could make the decorator a class that also caches the devices eventually :)

Member:
Love the idea of adding the decorator dynamically there!
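For reference, a minimal sketch of what such a decorator could look like, reusing this PR's `transfer_data_to_device`; the name `auto_move_data`, the example module, and the exact wiring are assumptions rather than part of this diff:

import functools

import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import transfer_data_to_device


def auto_move_data(fn):
    """Sketch: move args/kwargs to the module's device before calling ``fn``.

    Mirrors the check in ``__call__`` above: data is only moved when all
    parameters sit on a single device.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        devices = {p.device for p in self.parameters()}
        if len(devices) == 1:
            device = devices.pop()
            args = transfer_data_to_device(args, device.type, device.index, warn_on_transfer=True)
            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
        return fn(self, *args, **kwargs)
    return wrapper


class MyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    @auto_move_data
    def forward(self, x):
        return self.layer(x)

The trainer could then wrap `forward` with the same decorator dynamically instead of requiring users to apply it themselves, as proposed above.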


    @abstractmethod
    def forward(self, *args, **kwargs):
        r"""
        Same as :meth:`torch.nn.Module.forward()`; however, in Lightning you want this to define
        the operations you want to use for prediction (i.e. on a server or as a feature extractor).
        The LightningModule will also automatically copy inference data to the same device as the
        model when the model sits on the CPU or on a single GPU.

        Normally you'd call ``self()`` from your :meth:`training_step` method.
        This makes it easy to write a complex system for training with the outputs
47 changes: 3 additions & 44 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -353,6 +353,7 @@
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.data import transfer_data_to_device

try:
    from apex import amp
@@ -435,52 +436,10 @@ def copy_trainer_model_properties(self, model):
            m.device = self.device

    def transfer_batch_to_tpu(self, batch):
        return self.__transfer_data_to_device(batch, device='tpu')  # removed in this PR
        return transfer_data_to_device(batch, device_type='tpu')  # added in this PR

    def transfer_batch_to_gpu(self, batch, gpu_id):
        return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)  # removed in this PR
        return transfer_data_to_device(batch, device_type='cuda', idx=gpu_id)  # added in this PR

    # the private helper below is removed in this PR in favor of the shared
    # utility in pytorch_lightning/utilities/data.py
    def __transfer_data_to_device(self, batch, device, gpu_id=None):
        if device == 'tpu' and XLA_AVAILABLE:
            # base case: object can be directly moved using `to`
            if callable(getattr(batch, 'to', None)):
                return batch.to(xm.xla_device())

        if device == 'gpu':
            # base case: object can be directly moved using `cuda` or `to`
            if callable(getattr(batch, 'cuda', None)):
                return batch.cuda(gpu_id)

            if callable(getattr(batch, 'to', None)):
                return batch.to(torch.device('cuda', gpu_id))

        # when list
        if isinstance(batch, list):
            for i, x in enumerate(batch):
                batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
            return batch

        # when tuple
        if isinstance(batch, tuple):
            # when namedtuple
            if hasattr(batch, '_fields'):
                elem_type = type(batch)
                return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
            else:
                batch = list(batch)
                for i, x in enumerate(batch):
                    batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
                return tuple(batch)

        # when dict
        if isinstance(batch, dict):
            for k, v in batch.items():
                batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
            return batch

        # nothing matches, return the value as is without transform
        return batch

    def single_gpu_train(self, model):
        model.cuda(self.root_gpu)
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -1,3 +1,4 @@
"""General utilities"""

from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.data import transfer_data_to_device
56 changes: 56 additions & 0 deletions pytorch_lightning/utilities/data.py
@@ -0,0 +1,56 @@
import torch

from pytorch_lightning.utilities import rank_zero_warn

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


def transfer_data_to_device(batch, device_type, idx=None, warn_on_transfer=False):
    """
    Utility function to copy data to a given device.
    Works for any form of nested lists, tuples or dictionaries containing tensors.
    TPUs are dealt with separately since they don't use device indexes.
    """
    if device_type == 'tpu' and XLA_AVAILABLE:
        if callable(getattr(batch, 'to', None)):
            if warn_on_transfer:
                rank_zero_warn('Auto transferred data to device {}'.format(xm.xla_device()))
            return batch.to(xm.xla_device())

    # base case: nothing to do (guard against a None index, e.g. for CPU devices)
    device = torch.device(device_type) if idx is None else torch.device(device_type, idx)
    if torch.is_tensor(batch) and batch.device == device:
        return batch

    # object can be directly moved using `cuda` or `to`
    if callable(getattr(batch, 'cuda', None)) and device_type == 'cuda':
        if warn_on_transfer:
            rank_zero_warn('Auto transferred data to device {}'.format(device))
        return batch.cuda(device=device)

    if callable(getattr(batch, 'to', None)):
        if warn_on_transfer:
            rank_zero_warn('Auto transferred data to device {}'.format(device))
Member @awaelchli (Apr 22, 2020):
I would move the warning out to ``__call__``, because (1) this utility function is more general (it is used in other parts), and (2) this function is recursive, so if a dict of tensors is passed in, the warning would be shown multiple times.

Contributor Author @HenryJia (Apr 22, 2020):
@awaelchli I thought about this. The problem is that I'd run into the same code-duplication issue: to detect which device the data is on in ``__call__``, I'd need to recurse into whatever format the data is in, in almost exactly the same way.
Also, I believe ``rank_zero_warn`` will only warn once anyway, so that is not an issue.

        return batch.to(device=device)

    # when list or tuple
    if isinstance(batch, (list, tuple)):
        # when namedtuple, preserve its type
        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
            elem_type = type(batch)
            return elem_type(*(transfer_data_to_device(x, device_type, idx, warn_on_transfer) for x in batch))
        was_tuple = isinstance(batch, tuple)
        batch = list(batch)
        for i, x in enumerate(batch):
            batch[i] = transfer_data_to_device(x, device_type, idx, warn_on_transfer)
        return tuple(batch) if was_tuple else batch

    # when dict
    if isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = transfer_data_to_device(v, device_type, idx, warn_on_transfer)
        return batch

    # nothing matches, return the value as is without transform
    return batch
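A quick usage sketch of the utility above (assuming a CUDA device is available): nested containers are traversed recursively, and tensors already on the target device are returned unchanged.

import torch

from pytorch_lightning.utilities import transfer_data_to_device

batch = {
    'inputs': [torch.randn(4, 3), torch.randn(4, 3)],  # nested list of tensors
    'target': torch.zeros(4, dtype=torch.long),
}
moved = transfer_data_to_device(batch, 'cuda', idx=0, warn_on_transfer=True)
assert moved['target'].device == torch.device('cuda', 0)
assert moved['inputs'][0].device == torch.device('cuda', 0)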
17 changes: 17 additions & 0 deletions tests/models/test_cpu.py
@@ -338,3 +338,20 @@ def test_single_gpu_model(tmpdir):

    model = EvalModelTemplate()
    tutils.run_model_test(trainer_options, model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_auto_move_data(tmpdir):
    """Make sure auto-moving data works for a CPU model, both when nothing has to
    move and when GPU data has to be moved back to the CPU."""

    tutils.reset_seed()
    tutils.set_random_master_port()

    model, hparams = tutils.get_default_model()
    model.prepare_data()
    loader = model.train_dataloader()
    for x, y in loader:
        x = x.view(x.size(0), -1)
        assert model(x).device == torch.device('cpu'), "Auto-moving data to the same device as the model failed"
        x = x.cuda(0)
        assert model(x).device == torch.device('cpu'), "Auto-moving data to the same device as the model failed"
18 changes: 18 additions & 0 deletions tests/models/test_gpu.py
@@ -147,6 +147,24 @@ def test_multi_gpu_none_backend(tmpdir):
    tutils.run_model_test(trainer_options, model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_auto_move_data(tmpdir):
    """Make sure auto-moving data works for a single-GPU model, both for CPU data
    and for data already on the right GPU."""

    tutils.reset_seed()
    tutils.set_random_master_port()

    model, hparams = tutils.get_default_model()
    model = model.cuda(0)
    model.prepare_data()
    loader = model.train_dataloader()
    for x, y in loader:
        x = x.view(x.size(0), -1)
        assert model(x).device == torch.device('cuda:0'), "Auto-moving data to the same device as the model failed"
        x = x.cuda(0)
        assert model(x).device == torch.device('cuda:0'), "Auto-moving data to the same device as the model failed"


@pytest.fixture
def mocked_device_count(monkeypatch):
    def device_count():