Bugfix/torchtext include lengths #2689

Merged
Changes from 41 commits (44 commits total)

Commits
924e4c1
Test using torchtext.data.Field with include_lengths=True/False
Jul 24, 2020
1fcbe36
Fix issue that Tensors in a Batch generated by torchtext with torchte…
Jul 24, 2020
59e97b2
Add description for fix of issue #2688
Jul 24, 2020
3e3fbbe
changes to accommodate CodeFactor issues
Jul 24, 2020
fe9816d
Another attempt to make the last CodeFactor issue pass (it's a false alarm)
Jul 24, 2020
957ee89
temporarily disable test of test_grad_tracking to check if testing wil…
Jul 24, 2020
7971e7d
reenable test in test_grad_norm
Jul 25, 2020
4d0a849
Update CHANGELOG.md
thschaaf Jul 26, 2020
c994e88
Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator a…
Jul 26, 2020
f60613c
Update pytorch_lightning/utilities/apply_func.py
thschaaf Jul 26, 2020
c9fdf50
adding tests more specific to batch_move_data_to_device with torchtext…
Jul 27, 2020
5e568ea
added check that Tensors were moved to target device
Jul 27, 2020
a6b96b0
removed tests using RNN models to be moved into a separate PR
Jul 27, 2020
0eabe91
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 27, 2020
398ab54
fixing FLAKE8 errors that showed up after merge from master branch
Jul 27, 2020
8d56dc8
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 27, 2020
a99fc7d
parameterized test to reduce code duplication
Jul 28, 2020
61e692f
Added check only if length tensor exists. Removed leftover comments.
Jul 28, 2020
0c25f43
rearranged device parameterization and added pytest.param
Jul 28, 2020
f08dd78
Try to figure out why only one device is tested on Linux machines
Jul 28, 2020
d2c4598
Testing on CPU and GPU devices (GPU test is skipped if no cuda device is…
Jul 28, 2020
9bd3854
added test for TPU device (experimental)
Jul 28, 2020
d04c288
Adding test parameterization for TPU test (experimental)
Jul 28, 2020
cca6ff3
change import statement to limit what is imported for a TPU environment
Jul 28, 2020
5f3680d
made test work with TPU
Jul 28, 2020
08ebb6d
Change to trigger CI
Jul 28, 2020
fa6b2f9
Change to trigger CI
Jul 28, 2020
f9d9887
Merge branch 'bugfix/torchtext-include_lengths' of https://github.com…
Jul 28, 2020
940c34d
uncommented TPU test to check CI
Jul 28, 2020
584328a
reenabling TPU test
Jul 29, 2020
ae71b14
small change to trigger CI build
Jul 29, 2020
34201bc
small change to trigger CI build
Jul 29, 2020
a53a469
small change to trigger CI build
Jul 29, 2020
647e44b
adding tests/utilities/test_apply_func_torchtext.py to CI TPU test
Jul 29, 2020
ff080da
try to make test not skipped on CI with TPU
Jul 29, 2020
43a5ea9
remove testing on TPU
Jul 29, 2020
73583c1
undo an accidental change to test_tpu.py (file should not have been t…
Jul 29, 2020
b929711
small change to trigger CI build
Jul 29, 2020
68e2152
small change to trigger CI build
Jul 29, 2020
c97cd69
Update tests/utilities/test_apply_func_torchtext.py
awaelchli Jul 29, 2020
1685077
Revert to previous version
Jul 29, 2020
8a7d68b
Apply suggestions from code review
Borda Jul 29, 2020
72f64ad
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 29, 2020
3c04090
Change to trigger CI
Jul 30, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

## [0.8.5] - 2020-07-09

### Added
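For context, a minimal sketch of the failure this changelog entry refers to (assuming torchtext's legacy `Field` API, i.e. torchtext < 0.9): with `include_lengths=True`, `Field.process(...)` returns a `(data, lengths)` tuple rather than a single tensor, so calling `.to(device)` on the field value fails.

```python
# Minimal sketch of the reported failure (assumes torchtext's legacy Field API).
# With include_lengths=True a processed field is a tuple, not a tensor:
import torchtext

field = torchtext.data.Field(sequential=True, include_lengths=True)
field.build_vocab([["a", "b", "c"]])
processed = field.process([["a", "b"], ["c"]])  # -> (data_tensor, lengths_tensor)
assert isinstance(processed, tuple)
# processed.to("cuda")  # would fail: a tuple has no .to(...), hence this fix
```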
9 changes: 5 additions & 4 deletions pytorch_lightning/utilities/apply_func.py
@@ -6,6 +6,7 @@
import torch

import importlib

TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
if TORCHTEXT_AVAILABLE:
from torchtext.data import Batch
@@ -92,18 +93,18 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""

def batch_to(data):
# try to move torchtext data first
if TORCHTEXT_AVAILABLE and isinstance(data, Batch):

# Shallow copy because each Batch has a reference to Dataset which contains all examples
device_data = copy(data)
for field in data.fields:
# Batch contains output of Field.process(...) which is tensor hence .to(...) exists
device_field = getattr(data, field).to(device, non_blocking=True)
device_field = move_data_to_device(getattr(data, field), device)
setattr(device_data, field, device_field)
return device_data
else:
return data.to(device, non_blocking=True)

return data.to(device, non_blocking=True)

return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
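As a rough sketch (not Lightning's exact code path), the reason recursing through `move_data_to_device` fixes the bug is that the recursion walks tuples and lists, so the `(data, lengths)` pair produced by `include_lengths=True` has each of its tensors moved individually:

```python
# Rough sketch of the recursion idea (not the library's exact implementation):
from typing import Any
import torch

def _move(obj: Any, device: torch.device) -> Any:
    # Recurse into tuples/lists so a (data, lengths) pair is handled element-wise.
    if isinstance(obj, (tuple, list)):
        return type(obj)(_move(o, device) for o in obj)
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)
    return obj

text = torch.randint(0, 10, (7, 3))
lengths = torch.tensor([7, 6, 5])
moved = _move((text, lengths), torch.device("cpu"))  # works where tuple.to(...) would not
```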
2 changes: 1 addition & 1 deletion tests/base/datamodules.py
@@ -13,7 +13,7 @@ def __init__(self, data_dir: str = './'):
def prepare_data(self):
TrialMNIST(self.data_dir, train=True, download=True)
TrialMNIST(self.data_dir, train=False, download=True)

def setup(self):
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
2 changes: 1 addition & 1 deletion tests/callbacks/test_model_checkpoint.py
@@ -28,7 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
max_epochs=2,
)
trainer.fit(model)
assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'
assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'


@pytest.mark.parametrize(
51 changes: 51 additions & 0 deletions tests/utilities/test_apply_func_torchtext.py
@@ -0,0 +1,51 @@
import pytest
import torch
import torchtext
from torchtext.data.example import Example

from pytorch_lightning.utilities.apply_func import move_data_to_device


def _get_torchtext_data_iterator(include_lengths=False):
text_field = torchtext.data.Field(sequential=True, pad_first=False, # nosec
init_token="<s>", eos_token="</s>", # nosec
include_lengths=include_lengths) # nosec

example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})

dataset = torchtext.data.Dataset([example1, example2, example3],
{"text": text_field}
)
text_field.build_vocab(dataset)

iterator = torchtext.data.Iterator(dataset, batch_size=3,
sort_key=None, device=None,
batch_size_fn=None,
train=True, repeat=False, shuffle=None,
sort=None, sort_within_batch=None)
return iterator, text_field


@pytest.mark.parametrize('include_lengths', [False, True])
@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test assumes GPU machine")
def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths)
data_iter = iter(data_iterator)
batch = next(data_iter)
batch_on_device = move_data_to_device(batch, device)

if include_lengths:
# tensor with data
assert (batch_on_device.text[0].device == device)
# tensor with length of data
assert (batch_on_device.text[1].device == device)
else:
assert (batch_on_device.text.device == device)


@pytest.mark.parametrize('include_lengths', [False, True])
def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu'))
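As a quick illustration of what the parametrised test asserts (a sketch reusing the helper defined above; assumes the legacy torchtext API): with `include_lengths=True` the field value on the batch is a `(data, lengths)` tuple of two tensors, and both must end up on the target device.

```python
# Hypothetical interactive check, reusing the helper from this test module:
iterator, _ = _get_torchtext_data_iterator(include_lengths=True)
batch = next(iter(iterator))
data, lengths = batch.text          # tuple of (token ids, sequence lengths)
print(data.shape, lengths.shape)    # e.g. torch.Size([7, 3]) torch.Size([3])
```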