diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index cc74cfc71d83f8..15843492293cba 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -1,8 +1,10 @@
 from functools import wraps
 from typing import Callable
 
+import torch
+
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn, transfer_batch_to_device
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 def data_loader(fn):
@@ -32,21 +34,24 @@ def auto_move_data(fn: Callable) -> Callable:
 
     Example:
 
-        .. code-block:: python
+        >>> class LitModel(LightningModule):
+        ...     @auto_move_data
+        ...     def forward(self, x):
+        ...         return x
+        >>> LitModel.forward = auto_move_data(LitModel.forward)
+        >>> model = LitModel()
+        >>> model = model.to('cuda')
+        >>> model(torch.zeros(1, 3))
+        tensor([[0., 0., 0.]], device='cuda:0')
 
-            model = model.cuda(0)
-            model.prepare_data()
-            loader = model.train_dataloader()
-            for x, y in loader:
-                output = model(x)
     """
     @wraps(fn)
     def auto_transfer_args(self, *args, **kwargs):
         if not isinstance(self, LightningModule):
             return fn(self, *args, **kwargs)
 
-        args = transfer_batch_to_device(args, self.device)
-        kwargs = transfer_batch_to_device(kwargs, self.device)
+        args = self.transfer_batch_to_device(args, self.device)
+        kwargs = self.transfer_batch_to_device(kwargs, self.device)
         return fn(self, *args, **kwargs)
 
     return auto_transfer_args
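
For context, the functional change is that `auto_move_data` now routes argument transfer through the module's own `transfer_batch_to_device` hook rather than the standalone utility function, so any custom transfer logic defined on a `LightningModule` is respected by decorated methods. Below is a minimal sketch (not part of the diff; `MyModel`, the override, and the CUDA device are illustrative assumptions) of what that enables:

```python
import torch
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.core.lightning import LightningModule


class MyModel(LightningModule):
    @auto_move_data
    def forward(self, x):
        return x

    def transfer_batch_to_device(self, batch, device):
        # Custom per-module transfer logic now applies to decorated methods too,
        # because the decorator calls self.transfer_batch_to_device(...).
        return super().transfer_batch_to_device(batch, device)


model = MyModel().to('cuda')      # assumes a CUDA device is available
out = model(torch.zeros(1, 3))    # input is moved via the module's hook
print(out.device)                 # expected: cuda:0
```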