Bugfix/torchtext include lengths (#2689)
* Test using torchtext.data.Field with include_lengths=True/False

* Fix transfer to device of Tensors in a Batch generated by torchtext with torchtext.data.Field configured as include_lengths=True (see the sketch below the commit metadata)

* Add description for fix of issue #2688

* changes to accommodate CodeFactor issues

* Another attempt to make the last CodeFactor issue pass (it's a false alarm)

* temporarily disable the test_grad_tracking test to check if testing will pass

* reenable test in test_grad_norm

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator as suggested by @Borda

* Update pytorch_lightning/utilities/apply_func.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* adding tests more specific to batch_move_data_to_device with a torchtext Batch

* added check that Tensors were moved to target device

* removed tests using RNN models to be moved into a separate PR

* fixing flake8 errors that showed up after merging from the master branch
	modified:   tests/base/datamodules.py
	modified:   tests/callbacks/test_model_checkpoint.py

* parameterized test to reduce code duplication

* Added check only if the length tensor exists. Removed leftover comments.

* rearranged device parameterization and added pytest.param

* Try to figure out why only one device is tested on Linux machines

* Testing on CPU and GPU devices (GPU test is skipped if no CUDA device is available).

* added test for TPU device (experimental)

* Adding test parameterization for TPU test (experimental)

* change import statement to limit what is imported for a TPU environment

* made test work with TPU

* Change to trigger CI

* Change to trigger CI

* uncommented TPU test to check CI

* reenabling TPU test

* small change to trigger CI build

* small change to trigger CI build

* small change to trigger CI build

* adding tests/utilities/test_apply_func_torchtext.py to CI TPU test

* try to prevent the test from being skipped on CI with TPU

* remove testing on TPU

* undo an accidental change to test_tpu.py (file should not have been touched)

* small change to trigger CI build

* small change to trigger CI build

* Update tests/utilities/test_apply_func_torchtext.py

* Revert to previous version

* Apply suggestions from code review

* Change to trigger CI

Co-authored-by: Thomas Schaaf <tschaaf@mmm.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Thomas Schaaf <tschaaf@cs.cmu.edu>
5 people committed Jul 31, 2020
1 parent b88fc43 commit a6719f0
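A minimal sketch of the underlying problem, for context (the tensor values and field contents below are hypothetical; the per-field `.to(...)` call and the recursive fix come from this commit): with include_lengths=True, torchtext's Field.process returns a (data, lengths) tuple of tensors instead of a single tensor, so calling .to(device) on the field value directly no longer works.

import torch

device = torch.device("cpu")

# include_lengths=False: the processed field is a single padded tensor, so .to(device) works.
text = torch.tensor([[2, 4, 5], [2, 5, 1]])
text = text.to(device, non_blocking=True)

# include_lengths=True: the processed field is a (data, lengths) tuple of tensors,
# so the old per-field call `getattr(data, field).to(device, ...)` raises
# AttributeError: 'tuple' object has no attribute 'to'.
text_and_lengths = (torch.tensor([[2, 4, 5], [2, 5, 1]]), torch.tensor([3, 3]))

# The fix routes each field through move_data_to_device, which recurses into the
# tuple and moves every tensor it finds, effectively:
moved = tuple(t.to(device, non_blocking=True) for t in text_and_lengths)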
Showing 3 changed files with 59 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

## [0.8.5] - 2020-07-09

### Added
9 changes: 5 additions & 4 deletions pytorch_lightning/utilities/apply_func.py
@@ -6,6 +6,7 @@
import torch

import importlib

TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
if TORCHTEXT_AVAILABLE:
    from torchtext.data import Batch
@@ -92,18 +93,18 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :meth:`torch.Tensor.to`
         - :class:`torch.device`
     """
 
     def batch_to(data):
+        # try to move torchtext data first
         if TORCHTEXT_AVAILABLE and isinstance(data, Batch):
 
             # Shallow copy because each Batch has a reference to Dataset which contains all examples
             device_data = copy(data)
             for field in data.fields:
-                # Batch contains output of Field.process(...) which is tensor hence .to(...) exists
-                device_field = getattr(data, field).to(device, non_blocking=True)
+                device_field = move_data_to_device(getattr(data, field), device)
                 setattr(device_data, field, device_field)
             return device_data
-        else:
-            return data.to(device, non_blocking=True)
+
+        return data.to(device, non_blocking=True)
 
     return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
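A short usage sketch of the patched helper (the sample data and device selection are made up; only `move_data_to_device` and its import path are taken from this change): tensors nested in ordinary collections are moved as before, and a torchtext `Batch` is now shallow-copied with each field, including a `(data, lengths)` tuple, moved recursively.

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

# Ordinary nested collections: every tensor found is moved to `device`.
sample = {'tokens': torch.zeros(2, 5, dtype=torch.long), 'lengths': torch.tensor([5, 3])}
moved = move_data_to_device(sample, device)
assert moved['tokens'].device == device and moved['lengths'].device == device

# A torchtext Batch (for example one yielded by the iterator built in the test below)
# now also works with include_lengths=True, because each field is routed through
# move_data_to_device:
#     batch_on_device = move_data_to_device(batch, device)
#     data, lengths = batch_on_device.text   # both tensors end up on `device`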
52 changes: 52 additions & 0 deletions tests/utilities/test_apply_func_torchtext.py
@@ -0,0 +1,52 @@
import pytest
import torch
import torchtext
from torchtext.data.example import Example

from pytorch_lightning.utilities.apply_func import move_data_to_device


def _get_torchtext_data_iterator(include_lengths=False):
    text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
                                      init_token="<s>", eos_token="</s>",  # nosec
                                      include_lengths=include_lengths)  # nosec

    example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
    example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
    example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})

    dataset = torchtext.data.Dataset(
        [example1, example2, example3],
        {"text": text_field},
    )
    text_field.build_vocab(dataset)

    iterator = torchtext.data.Iterator(dataset, batch_size=3,
                                       sort_key=None, device=None,
                                       batch_size_fn=None,
                                       train=True, repeat=False, shuffle=None,
                                       sort=None, sort_within_batch=None)
    return iterator, text_field


@pytest.mark.parametrize('include_lengths', [False, True])
@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
    data_iter = iter(data_iterator)
    batch = next(data_iter)
    batch_on_device = move_data_to_device(batch, device)

    if include_lengths:
        # tensor with data
        assert (batch_on_device.text[0].device == device)
        # tensor with length of data
        assert (batch_on_device.text[1].device == device)
    else:
        assert (batch_on_device.text.device == device)


@pytest.mark.parametrize('include_lengths', [False, True])
def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
