diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index 3979a4fc6f7ee4..0553615e4d2046 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -1,4 +1,5 @@
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn, transfer_batch_to_device
 
 
 def data_loader(fn):
@@ -12,3 +13,37 @@ def data_loader(fn):
     def inner_fx(self):
         return fn(self)
     return inner_fx
+
+
+def auto_move_data(fn):
+    """
+    Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
+    input arguments should be moved automatically to the correct device.
+    It has no effect if applied to a method of an object that is not an instance of
+    :class:`~pytorch_lightning.core.lightning.LightningModule`. It is typically applied to
+    ``__call__`` or ``forward``.
+
+    Args:
+        fn: A LightningModule method for which the arguments should be moved to the device
+            the parameters are on.
+
+    Example:
+
+        .. code-block:: python
+
+            # assuming ``forward`` is decorated with ``@auto_move_data``
+            model = model.cuda(0)
+            model.prepare_data()
+            loader = model.train_dataloader()
+            for x, y in loader:
+                output = model(x)  # x is moved to cuda:0 automatically
+    """
+    def auto_transfer_args(self, *args, **kwargs):
+        if not isinstance(self, LightningModule):
+            return fn(self, *args, **kwargs)
+
+        args = transfer_batch_to_device(args, self.device)
+        kwargs = transfer_batch_to_device(kwargs, self.device)
+        return fn(self, *args, **kwargs)
+
+    return auto_transfer_args
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 1eafb4fb03f63c..568af7e0ddad36 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -100,35 +100,6 @@ def forward(self, x):
         if self.trainer.proc_rank == 0:
             print(*args, **kwargs)
 
-    def __call__(self, *data, **kwargs):
-        r"""
-        Automatically moves data to correct device if possible, then call torch.nn.Module.__call__
-        Lightning will warn you if it automatically moves any data
-
-        Args:
-            *data: Any positional arguments for torch.nn.Module.__call__. These are typically input data
-            **kwargs: Any keyword arguments for torch.nn.Module.__call__
-
-        Example:
-
-        .. code-block:: python
-
-            model = model.cuda(0)
-            model.prepare_data()
-            loader = model.train_dataloader()
-            for x, y in loader:
-                output = model(x)  # Lightning will automove data here and warn you of it
-
-        """
-        devices = [p.device for p in self.parameters()]
-        # All parameters must be on same device to automove data
-        # Otherwise we just do what nn.Module does normally
-        if len(set(devices)) == 1:
-            device = devices[0]
-            data = transfer_data_to_device(data, device.type, device.index, warn_on_transfer=True)
-            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
-        return super(LightningModule, self).__call__(*data, **kwargs)
-
     @abstractmethod
     def forward(self, *args, **kwargs):
         r"""
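For reviewers who want to try the new decorator, here is a minimal usage sketch (not part of the diff above). The `LitClassifier` module, its layer sizes, and the target device are invented for illustration; the sketch assumes the decorator is importable from `pytorch_lightning.core.decorators` as introduced above and that a CUDA device is available.

```python
import torch
from torch import nn

from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data


class LitClassifier(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    # With forward decorated, tensors passed to model(x) are transferred
    # to the device the parameters live on before the call runs.
    @auto_move_data
    def forward(self, x):
        return self.layer(x)


model = LitClassifier().cuda(0)
x = torch.zeros(4, 28 * 28)   # created on the CPU
out = model(x)                # x is moved to cuda:0 by the decorator
print(out.device)             # cuda:0
```

An already-defined class can be patched the same way, e.g. `SomeModule.forward = auto_move_data(SomeModule.forward)`, since the decorator simply wraps the method it is given.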