diff --git a/CHANGELOG.md b/CHANGELOG.md
index f1382f621d2eb..2f519c7de9b38 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,6 +43,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))
 
+- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 6f9b8e176ffe1..75130b297ddcc 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -6,6 +6,7 @@
 import torch
 import importlib
 
+
 TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
 if TORCHTEXT_AVAILABLE:
     from torchtext.data import Batch
@@ -92,6 +93,7 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :meth:`torch.Tensor.to`
         - :class:`torch.device`
     """
+
     def batch_to(data):
         # try to move torchtext data first
         if TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@@ -99,11 +101,10 @@ def batch_to(data):
             # Shallow copy because each Batch has a reference to Dataset which contains all examples
             device_data = copy(data)
             for field in data.fields:
-                # Batch contains output of Field.process(...) which is tensor hence .to(...) exists
-                device_field = getattr(data, field).to(device, non_blocking=True)
+                device_field = move_data_to_device(getattr(data, field), device)
                 setattr(device_data, field, device_field)
             return device_data
-        else:
-            return data.to(device, non_blocking=True)
+
+        return data.to(device, non_blocking=True)
 
     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py
index 23c07f93d4697..b1ac0cfe4d7cd 100644
--- a/tests/base/datamodules.py
+++ b/tests/base/datamodules.py
@@ -13,7 +13,7 @@ def __init__(self, data_dir: str = './'):
     def prepare_data(self):
         TrialMNIST(self.data_dir, train=True, download=True)
         TrialMNIST(self.data_dir, train=False, download=True)
-    
+
     def setup(self):
         mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
         self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 4cb52a54610e3..f3e4f113784dc 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -28,7 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
         max_epochs=2,
     )
     trainer.fit(model)
-    assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'
+    assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'
 
 
 @pytest.mark.parametrize(
diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
new file mode 100644
index 0000000000000..9ea29420788d7
--- /dev/null
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+import torchtext
+from torchtext.data.example import Example
+
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+
+
+def _get_torchtext_data_iterator(include_lengths=False):
+    text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
+                                      init_token="<sos>", eos_token="<eos>",  # nosec
+                                      include_lengths=include_lengths)  # nosec
+
+    example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
+    example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
+    example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
+
+    dataset = torchtext.data.Dataset(
+        [example1, example2, example3],
+        {"text": text_field},
+    )
+    text_field.build_vocab(dataset)
+
+    iterator = torchtext.data.Iterator(dataset, batch_size=3,
+                                       sort_key=None, device=None,
+                                       batch_size_fn=None,
+                                       train=True, repeat=False, shuffle=None,
+                                       sort=None, sort_within_batch=None)
+    return iterator, text_field
+
+
+@pytest.mark.parametrize('include_lengths', [False, True])
+@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
+def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
+    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
+    data_iter = iter(data_iterator)
+    batch = next(data_iter)
+    batch_on_device = move_data_to_device(batch, device)
+
+    if include_lengths:
+        # tensor with data
+        assert (batch_on_device.text[0].device == device)
+        # tensor with length of data
+        assert (batch_on_device.text[1].device == device)
+    else:
+        assert (batch_on_device.text.device == device)
+
+
+@pytest.mark.parametrize('include_lengths', [False, True])
+def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
+    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
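
Note (not part of the patch): with `include_lengths=True`, `torchtext.data.Field.process(...)` returns a `(padded_sequences, lengths)` tuple rather than a single tensor, so the old `getattr(data, field).to(device, non_blocking=True)` raised `AttributeError: 'tuple' object has no attribute 'to'`. Delegating to `move_data_to_device` fixes this because `apply_to_collection` recurses into tuples and moves each contained tensor. A minimal sketch of the behavior the patched `batch_to` relies on; the tensors below are hypothetical stand-ins for `Field.process` output, not taken from the PR:

    import torch

    from pytorch_lightning.utilities.apply_func import move_data_to_device

    # Hypothetical stand-ins for what Field(include_lengths=True) produces:
    # a padded token-id tensor and a tensor of per-example sequence lengths.
    padded = torch.tensor([[2, 5, 4], [3, 4, 1]])
    lengths = torch.tensor([3, 2])
    field_value = (padded, lengths)

    # move_data_to_device recurses through the tuple and moves each tensor;
    # this is what the patched batch_to() now delegates to for every field.
    moved = move_data_to_device(field_value, torch.device('cpu'))
    assert isinstance(moved, tuple)
    assert all(t.device.type == 'cpu' for t in moved)

The same recursion also covers nested containers, so a field whose processed output is a list or dict of tensors would be moved correctly as well.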