Commit eaccee7: call model hook by default
awaelchli committed May 31, 2020
1 parent 0199148
Showing 3 changed files with 18 additions and 13 deletions.
12 changes: 8 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -3,6 +3,8 @@
 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities import transfer_batch_to_device
+

 try:
     from apex import amp
@@ -158,16 +160,15 @@ def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
Lightning only calls the hook if it does not recognize the data type of your batch as one of
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:`torch.Tensor`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- ``torchtext.data.Batch`` (COMING SOON)
These data types (and any arbitrary nesting of them) are supported out of the box
(see :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`).
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Example::
@@ -177,6 +178,8 @@ def transfer_batch_to_device(self, batch, device)
                     # move all tensors in your custom data structure to the device
                     batch.samples = batch.samples.to(device)
                     batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
                 return batch

         Args:
@@ -188,11 +191,12 @@ def transfer_batch_to_device(self, batch, device)
         Note:
             This hook should only transfer the data and not modify it, nor should it move the data to
-            any other device than the one passed in as argument.
+            any other device than the one passed in as argument (unless you know what you are doing).
             The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
             batch and determines the target devices.

         See Also:
             - :func:`~pytorch_lightning.utilities.apply_func.transfer_batch_to_device`
             - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
         """
+        return transfer_batch_to_device(batch, device)
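With this change the base hook delegates to the utility imported above instead of leaving transfer undefined, so overriding is purely opt-in. A minimal usage sketch, not part of the diff (`CustomBatch` and `LitModel` are illustrative names mirroring the docstring example):

    import torch
    from pytorch_lightning import LightningModule


    class CustomBatch:
        """A wrapper the default transfer logic cannot traverse."""

        def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
            self.samples = samples
            self.targets = targets


    class LitModel(LightningModule):
        def transfer_batch_to_device(self, batch, device):
            if isinstance(batch, CustomBatch):
                # Move the wrapped tensors to the target device ourselves.
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
            else:
                # Fall back to the new default, which handles tensors,
                # lists, dicts, and tuples (and any nesting of them).
                batch = super().transfer_batch_to_device(batch, device)
            return batch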
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -346,7 +346,6 @@
 from typing import Union, Callable, Any, Optional

 from pytorch_lightning import _logger as log
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
@@ -474,9 +473,9 @@ def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
         return self.__transfer_batch_to_device(batch, device)

     def __transfer_batch_to_device(self, batch: Any, device: torch.device):
-        if self.is_overridden(LightningModule.transfer_batch_to_device.__name__):
-            # user-override for custom batch types
-            return self.get_model().transfer_batch_to_device(batch, device)
+        model = self.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
         return transfer_batch_to_device(batch, device)

     def single_gpu_train(self, model):
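Both code paths now end in the same utility. As a rough sketch of what that utility does for the supported types (the real implementation lives in pytorch_lightning.utilities.apply_func and is built on apply_to_collection; this standalone version is a simplified approximation, not the library code):

    import torch

    def move_to_device(batch, device):
        # Simplified stand-in for pytorch_lightning.utilities.transfer_batch_to_device.
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, (list, tuple)):
            # Rebuild the container with every element moved recursively.
            return type(batch)(move_to_device(x, device) for x in batch)
        if isinstance(batch, dict):
            return {k: move_to_device(v, device) for k, v in batch.items()}
        # Unknown containers are returned as-is; that is exactly the case
        # the transfer_batch_to_device hook exists to cover.
        return batch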
12 changes: 7 additions & 5 deletions tests/models/test_hooks.py
@@ -85,12 +85,14 @@ class CurrentTestModel(EvalModelTemplate):

         hook_called = False

-        def transfer_batch_to_device(self, batch, device):
+        def transfer_batch_to_device(self, data, device):
             self.hook_called = True
-            if isinstance(batch, CustomBatch):
-                batch.samples = batch.samples.to(device)
-                batch.targets = batch.targets.to(device)
-            return batch
+            if isinstance(data, CustomBatch):
+                data.samples = data.samples.to(device)
+                data.targets = data.targets.to(device)
+            else:
+                data = super().transfer_batch_to_device(data, device)
+            return data

     model = CurrentTestModel()
     batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))
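The test exercises the override path. The other half of "call model hook by default" is that a model which never overrides the hook still gets its batches moved, because the base hook now delegates to the utility. A small sanity check along those lines (the dict batch is illustrative):

    import torch
    from pytorch_lightning.utilities import transfer_batch_to_device

    # A plain dict of tensors needs no override: the default hook moves it.
    batch = {"samples": torch.zeros(5, 28), "targets": torch.ones(5, 1, dtype=torch.long)}
    moved = transfer_batch_to_device(batch, torch.device("cpu"))
    assert moved["samples"].device.type == "cpu"  # nested tensors were moved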
