
Do not pass non_blocking=True if it does not support this argument #2910

Merged · 6 commits · Aug 11, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828))

- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910))

## [0.8.5] - 2020-07-09

### Added
9 changes: 7 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -309,8 +309,13 @@ def transfer_batch_to_device(self, batch, device)
        Note:
            This hook should only transfer the data and not modify it, nor should it move the data to
            any other device than the one passed in as argument (unless you know what you are doing).
            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
            batch and determines the target devices.

        Note:
            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
            for your custom batch objects, you need to define your custom
            :class:`~torch.nn.parallel.DistributedDataParallel` or
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
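For context only (not part of this diff): a minimal sketch of how the hook documented above might be overridden for a user-defined batch class. `CustomBatch` and `LitModel` are hypothetical names, and the `super()` fallback assumes the default hook implementation, which moves everything else via `move_data_to_device`.

```python
import torch
from pytorch_lightning import LightningModule


class CustomBatch:
    """Hypothetical batch wrapper whose ``to`` accepts only a device, no ``non_blocking``."""

    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets

    def to(self, device):
        # move the wrapped tensors; note there is no non_blocking argument here
        self.samples = self.samples.to(device)
        self.targets = self.targets.to(device)
        return self


class LitModel(LightningModule):

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # handle the custom container ourselves
            return batch.to(device)
        # fall back to the default behaviour for all other batch types
        return super().transfer_batch_to_device(batch, device)
```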
3 changes: 2 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
@@ -104,6 +104,7 @@ def batch_to(data):
                setattr(device_data, field, device_field)
            return device_data

-        return data.to(device, non_blocking=True)
+        kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
+        return data.to(device, **kwargs)

    return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
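For illustration only (not part of the PR): with the guarded kwargs above, `move_data_to_device` should still move any object that exposes a `to` method, but only genuine tensors receive `non_blocking=True`. `LabelContainer` below is a hypothetical class.

```python
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device


class LabelContainer:
    """Hypothetical object whose ``to`` takes a device but no ``non_blocking``."""

    def __init__(self, labels):
        self.labels = labels

    def to(self, device):
        self.labels = self.labels.to(device)
        return self


if torch.cuda.is_available():
    device = torch.device('cuda', 0)
    # a plain tensor is transferred with .to(device, non_blocking=True)
    tensor = move_data_to_device(torch.zeros(2, 3), device)
    # the container only receives .to(device), so its narrower signature no longer raises a TypeError
    container = move_data_to_device(LabelContainer(torch.ones(4)), device)
```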
22 changes: 22 additions & 0 deletions tests/models/test_gpu.py
@@ -1,4 +1,5 @@
from collections import namedtuple
from unittest.mock import patch

import pytest
import torch
@@ -384,3 +385,24 @@ def to(self, *args, **kwargs):
    assert batch.text.type() == 'torch.cuda.LongTensor'
    assert batch.label.type() == 'torch.cuda.LongTensor'


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_non_blocking():
    """ Tests that non_blocking=True only gets passed to torch.Tensor.to, not to other objects. """
    trainer = Trainer()

    batch = torch.zeros(2, 3)
    with patch.object(batch, 'to', wraps=batch.to) as mocked:
        trainer.transfer_batch_to_gpu(batch, 0)
        mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True)

    class BatchObject(object):

        def to(self, *args, **kwargs):
            pass

    batch = BatchObject()
    with patch.object(batch, 'to', wraps=batch.to) as mocked:
        trainer.transfer_batch_to_gpu(batch, 0)
        mocked.assert_called_with(torch.device('cuda', 0))