Do not pass non_blocking=True if it does not support this argument (Lightning-AI#2910)

* add docs

* non blocking only on tensor

* changelog

* add test case

* add test comment

* update changelog


changelog


chlog
awaelchli authored and atee committed Aug 17, 2020
1 parent e9ad6db commit 6efccda
Showing 4 changed files with 33 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828))

- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910))

## [0.8.5] - 2020-07-09

### Added
9 changes: 7 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).
The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
batch and determines the target devices.
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
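For reference, this is roughly how the hook documented above can be overridden for a batch object that moves itself between devices. This is a minimal sketch, not part of the commit: the `CustomBatch` and `MyModel` classes are hypothetical, and only the hook signature and the fallback to the default implementation come from the hunk above.

```python
import torch
from pytorch_lightning import LightningModule


class CustomBatch:
    """Hypothetical batch wrapper whose ``to`` takes only a device."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def to(self, device):
        # No non_blocking parameter here -- with this commit, Lightning
        # no longer tries to pass one to objects like this.
        self.inputs = self.inputs.to(device)
        self.targets = self.targets.to(device)
        return self


class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # only move the data; do not modify it or pick a different device
            return batch.to(device)
        # fall back to the default behaviour for tensors and collections
        return super().transfer_batch_to_device(batch, device)
```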
3 changes: 2 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
@@ -104,6 +104,7 @@ def batch_to(data):
setattr(device_data, field, device_field)
return device_data

return data.to(device, non_blocking=True)
kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
return data.to(device, **kwargs)

return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
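In effect, `batch_to` now forwards `non_blocking=True` only when the object being moved is a `torch.Tensor`; anything else that merely implements a `to(device)` method is called without extra keyword arguments instead of failing with a `TypeError`. A minimal sketch of the behaviour, using `move_data_to_device` from the same module (the `MinimalBatch` class is hypothetical and exists only for illustration):

```python
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device


class MinimalBatch:
    """Hypothetical batch object whose ``to`` accepts only a device."""

    def __init__(self):
        self.data = torch.zeros(2, 3)

    def to(self, device):
        # no **kwargs here, so passing non_blocking=True would raise a TypeError
        self.data = self.data.to(device)
        return self


# Tensors still take the fast path: internally .to(device, non_blocking=True).
tensor = move_data_to_device(torch.zeros(2, 3), torch.device("cpu"))

# Before this fix the next call failed with
# "TypeError: to() got an unexpected keyword argument 'non_blocking'";
# now it simply calls MinimalBatch.to(device).
batch = move_data_to_device(MinimalBatch(), torch.device("cpu"))
```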
22 changes: 22 additions & 0 deletions tests/models/test_gpu.py
@@ -1,4 +1,5 @@
from collections import namedtuple
from unittest.mock import patch

import pytest
import torch
@@ -384,3 +385,24 @@ def to(self, *args, **kwargs):

assert batch.text.type() == 'torch.cuda.LongTensor'
assert batch.label.type() == 'torch.cuda.LongTensor'


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_non_blocking():
""" Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """
trainer = Trainer()

batch = torch.zeros(2, 3)
with patch.object(batch, 'to', wraps=batch.to) as mocked:
trainer.transfer_batch_to_gpu(batch, 0)
mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True)

class BatchObject(object):

def to(self, *args, **kwargs):
pass

batch = BatchObject()
with patch.object(batch, 'to', wraps=batch.to) as mocked:
trainer.transfer_batch_to_gpu(batch, 0)
mocked.assert_called_with(torch.device('cuda', 0))
