make a decorator
awaelchli committed May 20, 2020
1 parent b0245e2 commit ecb0dd3
Showing 2 changed files with 35 additions and 30 deletions.
36 changes: 35 additions & 1 deletion pytorch_lightning/core/decorators.py
@@ -1,4 +1,5 @@
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn, transfer_batch_to_device


def data_loader(fn):
@@ -12,3 +13,36 @@ def data_loader(fn):
    def inner_fx(self):
        return fn(self)
    return inner_fx


+def auto_move_data(fn):
+    """
+    Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
+    input arguments should be moved automatically to the correct device.
+    It has no effect if applied to a method of an object that is not an instance of
+    :class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__``
+    or ``forward``.
+
+    Args:
+        fn: A LightningModule method for which the arguments should be moved to the device
+            the parameters are on.
+
+    Example:
+
+        .. code-block:: python
+
+            model = model.cuda(0)
+            model.prepare_data()
+            loader = model.train_dataloader()
+            for x, y in loader:
+                output = model(x)
+    """
+    def auto_transfer_args(self, *args, **kwargs):
+        if not isinstance(self, LightningModule):
+            return fn(self, *args, **kwargs)
+
+        args = transfer_batch_to_device(args, self.device)
+        kwargs = transfer_batch_to_device(kwargs, self.device)
+        return fn(self, *args, **kwargs)
+
+    return auto_transfer_args
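
For context, a minimal usage sketch of the new decorator (not part of this commit; LitModel and the layer sizes are made up for illustration, and it assumes LightningModule exposes the device property the decorator body relies on, and that forward is the only hook you must override for direct inference):

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    @auto_move_data
    def forward(self, x):
        return self.layer(x)

model = LitModel().cuda(0)
x = torch.zeros(4, 3)  # created on the CPU
output = model(x)      # auto_move_data moves x to cuda:0 before forward runs

Because nn.Module.__call__ dispatches to self.forward, decorating forward is enough to cover the common model(x) call path.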
29 changes: 0 additions & 29 deletions pytorch_lightning/core/lightning.py
@@ -100,35 +100,6 @@ def forward(self, x):
        if self.trainer.proc_rank == 0:
            print(*args, **kwargs)

-    def __call__(self, *data, **kwargs):
-        r"""
-        Automatically moves data to correct device if possible, then call torch.nn.Module.__call__
-        Lightning will warn you if it automatically moves any data
-        Args:
-            *data: Any positional arguments for torch.nn.Module.__call__. These are typically input data
-            **kwargs: Any keyword arguments for torch.nn.Module.__call__
-        Example:
-            .. code-block:: python
-                model = model.cuda(0)
-                model.prepare_data()
-                loader = model.train_dataloader()
-                for x, y in loader:
-                    output = model(x)  # Lightning will automove data here and warn you of it
-        """
-        devices = [p.device for p in self.parameters()]
-        # All parameters must be on same device to automove data
-        # Otherwise we just do what nn.Module does normally
-        if len(set(devices)) == 1:
-            device = devices[0]
-            data = transfer_data_to_device(data, device.type, device.index, warn_on_transfer=True)
-            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
-        return super(LightningModule, self).__call__(*data, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        r"""
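
Both the removed __call__ override and the new decorator delegate the actual moving to a transfer utility (transfer_data_to_device before, transfer_batch_to_device now), which this diff does not show. As a rough mental model only, not the library's actual implementation, such a utility recursively walks common containers and calls .to(device) on each tensor:

import torch

def transfer_to_device_sketch(batch, device):
    """Simplified stand-in for pytorch_lightning's transfer utility."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(transfer_to_device_sketch(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: transfer_to_device_sketch(v, device) for k, v in batch.items()}
    return batch  # non-tensor leaves pass through unchanged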
