[blocked by #1756] Add decorator to auto-move data for inference #1526
First file in the diff:

```diff
@@ -20,7 +20,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn
-
+from pytorch_lightning.utilities import transfer_data_to_device
 try:
     import torch_xla.core.xla_model as xm
 except ImportError:
```
```diff
@@ -94,11 +94,42 @@ def forward(self, x):
         if self.trainer.proc_rank == 0:
             print(*args, **kwargs)
 
+    def __call__(self, *data, **kwargs):
+        r"""
+        Automatically moves data to the correct device if possible, then calls
+        ``torch.nn.Module.__call__``. Lightning will warn you if it automatically
+        moves any data.
+
+        Args:
+            *data: Any positional arguments for ``torch.nn.Module.__call__``.
+                These are typically input data.
+            **kwargs: Any keyword arguments for ``torch.nn.Module.__call__``.
+
+        Example:
+
+        .. code-block:: python
+
+            model = model.cuda(0)
+            model.prepare_data()
+            loader = model.train_dataloader()
+            for x, y in loader:
+                output = model(x)  # Lightning will auto-move the data here and warn you
+
+        """
+        devices = [p.device for p in self.parameters()]
+        # All parameters must be on the same device to auto-move data.
+        # Otherwise we just do what nn.Module does normally.
+        if len(set(devices)) == 1:
+            device = devices[0]
+            data = transfer_data_to_device(data, device.type, device.index, warn_on_transfer=True)
+            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
+        return super(LightningModule, self).__call__(*data, **kwargs)
```
> **Review comment:** @Borda and I discussed this, and we both agree that we shouldn't do this in the Module (at least not by default). In our opinion, you should always be able to use a LightningModule as a plain nn.Module. What I propose is the following: we change this part to a decorator that can be added to `forward` and is added automatically by the trainer (sorry for that coming so late). In that case you could make the decorator a class that also caches devices eventually :)

> **Reply:** Love the idea of adding the decorator dynamically there!
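To make the proposal concrete, here is a minimal sketch of what such a decorator might look like. This is an assumption, not code from this PR: the name `auto_move_data` is invented, and it reuses `transfer_data_to_device` from this diff.

```python
import functools

from pytorch_lightning.utilities import transfer_data_to_device


def auto_move_data(fn):
    """Hypothetical decorator sketch (not part of this PR): moves the
    inputs of ``fn`` to the module's device before calling it."""
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        devices = {p.device for p in self.parameters()}
        # Same guard as ``__call__`` above: only auto-move in the
        # unambiguous case where all parameters share one device.
        if len(devices) == 1:
            device = devices.pop()
            args = transfer_data_to_device(args, device.type, device.index, warn_on_transfer=True)
            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
        return fn(self, *args, **kwargs)
    return wrapper
```

A trainer could then attach it dynamically, e.g. `model.forward = auto_move_data(type(model).forward).__get__(model)`, rather than baking the behavior into `LightningModule.__call__`.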
```diff
 
     @abstractmethod
     def forward(self, *args, **kwargs):
         r"""
         Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
         the operations you want to use for prediction (i.e.: on a server or as a feature extractor).
+        LightningModule will also automatically copy data to the same device as the model if the model
+        is on the CPU or a single GPU for inference.
 
         Normally you'd call ``self()`` from your :meth:`training_step` method.
         This makes it easy to write a complex system for training with the outputs
```
Second file in the diff:

```diff
@@ -1,3 +1,4 @@
 """General utilities"""
 
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.data import transfer_data_to_device
```
Third file in the diff, a new file (`@@ -0,0 +1,56 @@`):

```python
import torch

from pytorch_lightning.utilities import rank_zero_warn

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


def transfer_data_to_device(batch, device_type, idx=None, warn_on_transfer=False):
    """
    Utility function to copy data to a given device.
    Works for any form of nested lists, tuples or dictionaries containing tensors.
    TPUs are handled separately: they don't use device indexes.
    """
    if device_type == 'tpu' and XLA_AVAILABLE:
        if callable(getattr(batch, 'to', None)):
            if warn_on_transfer:
                rank_zero_warn('Auto transferred data to device {}'.format(xm.xla_device()))
            return batch.to(xm.xla_device())

    # base case: nothing to do
    device = torch.device(device_type, idx)
    if torch.is_tensor(batch) and batch.device == device:
        return batch

    # object can be directly moved using `cuda` or `to`
    if callable(getattr(batch, 'cuda', None)) and device_type == 'cuda':
        if warn_on_transfer:
            rank_zero_warn('Auto transferred data to device {}'.format(device))
        return batch.cuda(device=device)

    if callable(getattr(batch, 'to', None)):
        if warn_on_transfer:
            rank_zero_warn('Auto transferred data to device {}'.format(device))
        return batch.to(device=device)

    # when list or tuple
    if isinstance(batch, (list, tuple)):
        if isinstance(batch, tuple):
            batch = list(batch)
        for i, x in enumerate(batch):
            batch[i] = transfer_data_to_device(x, device_type, idx, warn_on_transfer)
        return batch

    # when dict
    if isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = transfer_data_to_device(v, device_type, idx, warn_on_transfer)
        return batch

    # nothing matches, return the value as is without transform
    return batch
```

> **Review comment** (on the `rank_zero_warn` call in the `to` branch): I would move the warning out to the …

> **Reply:** @awaelchli I thought about this; a problem with that is I effectively run into the same code duplication problem by trying to detect which device it's on in …
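To show what the recursion does, here is a small, hypothetical usage sketch of `transfer_data_to_device`; the batch structure and shapes are made up, and it assumes a CUDA device is available.

```python
import torch

from pytorch_lightning.utilities import transfer_data_to_device

# Hypothetical nested batch: a dict holding a tensor, a list, and a tuple.
batch = {
    'images': torch.randn(4, 3, 32, 32),
    'targets': [torch.tensor([1, 0, 2, 1]), (torch.zeros(4), torch.ones(4))],
}

# Recursively moves every tensor to cuda:0, warning on each transfer
# (assumes a CUDA device is available).
moved = transfer_data_to_device(batch, 'cuda', 0, warn_on_transfer=True)
assert moved['images'].device == torch.device('cuda', 0)
```

One side effect visible in the code above: tuples come back as lists, because the recursion converts them in place rather than rebuilding the tuple.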
> **Review comment:** I'm not sure if looping over all the params on every call is a good idea. Can we maybe cache the devices somehow?

> **Reply:** @mcarilli are we missing a simpler/cleaner way of doing this?

> **Reply:** Maybe I don't understand what you're trying to do here, but it looks like you're only using `devices[0]`, so why collect them all? Also, if the model params do reside on multiple devices, it's hard to predict which device the user actually wants the input data to reside on.

> **Reply:** @HenryJia …

> **Reply:** So, I only apply automatic data transfer if we are dealing with the simple case of the model residing on one device, as trying to auto-transfer data when the model is spread across multiple devices is very non-trivial and heavily dependent on model structure.
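As a hedged sketch of the caching idea raised above (none of these names exist in the library): `nn.Module._apply` underlies `.to()`, `.cuda()` and `.cpu()`, so overriding it gives a single place to invalidate a cached device set instead of scanning every parameter on each forward pass.

```python
import torch


class DeviceCacheMixin(torch.nn.Module):
    """Hypothetical sketch (not part of this PR) of caching parameter devices."""

    def __init__(self):
        super().__init__()
        self._cached_devices = None

    def _apply(self, fn, *args, **kwargs):
        # ``.to()``, ``.cuda()`` and ``.cpu()`` all route through ``_apply``,
        # so moving or casting the module invalidates the cache here.
        self._cached_devices = None
        return super()._apply(fn, *args, **kwargs)

    @property
    def single_device(self):
        """The unique parameter device, or None if params span devices."""
        if self._cached_devices is None:
            self._cached_devices = {p.device for p in self.parameters()}
        if len(self._cached_devices) == 1:
            return next(iter(self._cached_devices))
        return None
```

The `__call__` above could then test `self.single_device` instead of rebuilding the device list on every invocation.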