Commit 93a742e: doctest

awaelchli committed Jun 6, 2020
1 parent d622104
Showing 1 changed file with 14 additions and 9 deletions.
pytorch_lightning/core/decorators.py (23 changes: 14 additions & 9 deletions)
@@ -1,8 +1,10 @@
 from functools import wraps
 from typing import Callable
 
+import torch
+
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn, transfer_batch_to_device
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 def data_loader(fn):
@@ -32,21 +34,24 @@ def auto_move_data(fn: Callable) -> Callable:
     Example:
 
         .. code-block:: python
 
-            model = model.cuda(0)
-            model.prepare_data()
-            loader = model.train_dataloader()
-            for x, y in loader:
-                output = model(x)
+            >>> class LitModel(LightningModule):
+            ...     @auto_move_data
+            ...     def forward(self, x):
+            ...         return x
+            >>> LitModel.forward = auto_move_data(LitModel.forward)
+            >>> model = LitModel()
+            >>> model = model.to('cuda')
+            >>> model(torch.zeros(1, 3))
+            tensor([[0., 0., 0.]], device='cuda:0')
     """
     @wraps(fn)
     def auto_transfer_args(self, *args, **kwargs):
         if not isinstance(self, LightningModule):
             return fn(self, *args, **kwargs)
 
-        args = transfer_batch_to_device(args, self.device)
-        kwargs = transfer_batch_to_device(kwargs, self.device)
+        args = self.transfer_batch_to_device(args, self.device)
+        kwargs = self.transfer_batch_to_device(kwargs, self.device)
         return fn(self, *args, **kwargs)
 
     return auto_transfer_args
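
For context, a short sketch (my own illustration, not part of the commit) of what routing the transfer through self.transfer_batch_to_device buys: the decorator now passes the forward arguments through the module's own hook, so a per-module override is honored. The LitModel below is a hypothetical minimal module.

    import torch

    from pytorch_lightning.core.decorators import auto_move_data
    from pytorch_lightning.core.lightning import LightningModule


    class LitModel(LightningModule):

        @auto_move_data
        def forward(self, x):
            return x

        def transfer_batch_to_device(self, batch, device):
            # After this commit the decorator calls self.transfer_batch_to_device,
            # so this override sees the args tuple and the kwargs dict of forward.
            print(f'moving a {type(batch).__name__} to {device}')
            return super().transfer_batch_to_device(batch, device)


    model = LitModel()
    output = model(torch.zeros(1, 3))  # prints once for the args tuple, once for kwargs

For plain tensor inputs the behavior is unchanged; the point of the change is that custom transfer logic defined on the module now takes effect when inputs are auto-moved.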
