Add transfer_batch_to_device hook to DataModule (#3038)
* ✨ add dm to_device logic in trainer

* 🔥 remove unnecessary comment

* ✨ add to_device logic to datamodule

* ✅ add test

* updated docs

Co-authored-by: William Falcon <waf2107@columbia.edu>
nateraw and williamFalcon committed Aug 20, 2020
1 parent cee5eaf commit bab89b8
Showing 4 changed files with 110 additions and 6 deletions.
16 changes: 16 additions & 0 deletions docs/source/datamodules.rst
@@ -290,6 +290,22 @@ Use this method to generate the test dataloader. This is also a good place to pl
])
return DataLoader(self.test_dataset, transform=transforms, batch_size=64)

transfer_batch_to_device
^^^^^^^^^^^^^^^^^^^^^^^^

Override to define how you want to move an arbitrary batch to a device.

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):

        def transfer_batch_to_device(self, batch, device):
            x = batch['x']
            x = CustomDataWrapper(x)
            batch['x'] = x.to(device)
            return batch
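
Once the hook is overridden, no extra wiring is needed: pass the datamodule to the ``Trainer`` and the override is picked up automatically. A minimal usage sketch (``LitModel`` is a placeholder for your own LightningModule, and ``gpus=1`` assumes a single-GPU machine):

.. code-block:: python

    dm = MNISTDataModule()
    model = LitModel()  # placeholder LightningModule

    # the trainer detects the overridden hook on the datamodule
    # and uses it when moving each batch to the device
    trainer = pl.Trainer(gpus=1)
    trainer.fit(model, datamodule=dm)
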
------------------

Using a DataModule
51 changes: 51 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -18,6 +18,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn
@@ -306,6 +307,56 @@ def test_dataloader(self):
return loader
"""

    @abstractmethod
    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """
        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
        wrapped in a custom data structure.

        The data types listed below (and any arbitrary nesting of them) are supported out of the box:

        - :class:`torch.Tensor` or anything that implements `.to(...)`
        - :class:`list`
        - :class:`dict`
        - :class:`tuple`
        - :class:`torchtext.data.batch.Batch`

        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

        Example::

            def transfer_batch_to_device(self, batch, device):
                if isinstance(batch, CustomBatch):
                    # move all tensors in your custom data structure to the device
                    batch.samples = batch.samples.to(device)
                    batch.targets = batch.targets.to(device)
                else:
                    batch = super().transfer_batch_to_device(batch, device)
                return batch

        Args:
            batch: A batch of data that needs to be transferred to a new device.
            device: The target device as defined in PyTorch.

        Returns:
            A reference to the data on the new device.

        Note:
            This hook should only transfer the data and not modify it, nor should it move the data to
            any other device than the one passed in as argument (unless you know what you are doing).

        Note:
            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
            for your custom batch objects, you need to define your custom
            :class:`~torch.nn.parallel.DistributedDataParallel` or
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
        """

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes.
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1113,11 +1113,6 @@ def __attach_datamodule(self, model, datamodule, stage):
        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule:

            # If datamodule.setup('test') has not been called yet, call it
            # if stage == 'test':
            #     if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
            #         datamodule.setup('test')

            # Override loader hooks
            if self.is_overridden('train_dataloader', datamodule):
                model.train_dataloader = datamodule.train_dataloader
@@ -1126,6 +1121,10 @@ def __attach_datamodule(self, model, datamodule, stage):
            if self.is_overridden('test_dataloader', datamodule):
                model.test_dataloader = datamodule.test_dataloader

            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
            if self.is_overridden('transfer_batch_to_device', datamodule):
                model.transfer_batch_to_device = datamodule.transfer_batch_to_device

        self.datamodule = datamodule

    def run_pretrain_routine(self, model: LightningModule):
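
The override in ``__attach_datamodule`` is plain attribute patching: the datamodule's bound method is assigned onto the model, so anything downstream that calls ``model.transfer_batch_to_device`` runs the datamodule's logic. A minimal, self-contained sketch of that pattern (the ``Toy*`` classes are illustrative, not Lightning's):

.. code-block:: python

    import torch


    class ToyBatch:
        def __init__(self, samples):
            self.samples = samples


    class ToyDataModule:
        def transfer_batch_to_device(self, batch, device):
            # datamodule-specific logic: move the wrapped tensor
            batch.samples = batch.samples.to(device)
            return batch


    class ToyModel:
        def transfer_batch_to_device(self, batch, device):
            # default behaviour: assume a plain tensor
            return batch.to(device)


    dm = ToyDataModule()
    model = ToyModel()

    # same pattern as the trainer hunk above: the datamodule's bound method
    # replaces the model's hook, so later calls through the model run the
    # datamodule's logic
    model.transfer_batch_to_device = dm.transfer_batch_to_device

    moved = model.transfer_batch_to_device(ToyBatch(torch.zeros(2, 3)), torch.device("cpu"))
    assert moved.samples.device.type == "cpu"
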
40 changes: 39 additions & 1 deletion tests/core/test_datamodules.py
@@ -1,10 +1,11 @@
import pickle
from argparse import ArgumentParser
from unittest.mock import MagicMock

import pytest
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed
@@ -317,3 +318,40 @@ def test_full_loop_ddp_spawn(tmpdir):
    result = trainer.test(datamodule=dm)
    result = result[0]
    assert result['test_acc'] > 0.8


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
def test_dm_transfer_batch_to_device(tmpdir):
    class CustomBatch:

        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):

        hook_called = False

        def transfer_batch_to_device(self, data, device):
            self.hook_called = True
            if isinstance(data, CustomBatch):
                data.samples = data.samples.to(device)
                data.targets = data.targets.to(device)
            else:
                data = super().transfer_batch_to_device(data, device)
            return data

    model = EvalModelTemplate()
    dm = CurrentTestDM()
    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

    trainer = Trainer()
    # running .fit() would require us to implement custom data loaders; we mock the model reference instead
    trainer.get_model = MagicMock(return_value=model)
    if trainer.is_overridden('transfer_batch_to_device', dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
    expected = torch.device('cuda', 0)
    assert dm.hook_called
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
