
Training on GPU failed with Torchtext when using include_lengths=True in torchtext.data.Field #2688

Closed
thschaaf opened this issue Jul 24, 2020 · 0 comments · Fixed by #2689
Labels
bug (Something isn't working), help wanted (Open to be worked on)

Comments

@thschaaf
Contributor

🐛 Bug

The issue arises in pytorch_lightning/utilities/apply_func.py, which assumes that the attributes of a torchtext Batch are Tensors. However, if torchtext.data.Field is configured with include_lengths=True, the attribute is a tuple of (sequence tensor, lengths tensor).

A bugfix is prepared and a PR can be submitted soon.
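For context, this is what the offending attribute looks like when include_lengths=True, together with a minimal sketch of how such a tuple could be moved to a device (the helper name below is illustrative only and not taken from the actual fix):

import torch

# With include_lengths=True, a torchtext Batch attribute such as batch.text is a
# tuple (padded_sequences, lengths) rather than a single Tensor, so calling
# .to(device) on it directly raises AttributeError: 'tuple' object has no attribute 'to'.

def move_field_to_device(field_value, device):
    # Hypothetical helper: handle both a plain Tensor and a (Tensor, Tensor) tuple.
    if isinstance(field_value, tuple):
        return tuple(t.to(device, non_blocking=True) for t in field_value)
    return field_value.to(device, non_blocking=True)

# Dummy tensors standing in for batch.text = (sequences, lengths):
text = (torch.randint(0, 10, (7, 3)), torch.tensor([7, 5, 4]))
text_on_device = move_field_to_device(text, torch.device("cpu"))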

To Reproduce

Steps to reproduce the behavior:

  1. Use a torchtext Field with include_lengths=True on a GPU machine and fit the model.
  2. Training works on CPU but fails on GPU with AttributeError: 'tuple' object has no attribute 'to' (the trailing TypeError: cannot unpack non-iterable NoneType object comes from tqdm during teardown).

Full Error Message

Traceback (most recent call last):
 File "debug_torchtext.py", line 105, in <module>
  trainer.fit(model)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in fit
  results = self.single_gpu_train(model)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 186, in single_gpu_train
  results = self.run_pretrain_routine(model)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1213, in run_pretrain_routine
  self.train()
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 370, in train
  self.run_training_epoch()
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 452, in run_training_epoch
  batch_output = self.run_training_batch(batch, batch_idx)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 632, in run_training_batch
  self.hiddens
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 776, in optimizer_closure
  hiddens)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 944, in training_forward
  batch = self.transfer_batch_to_gpu(batch, gpu_id)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 159, in transfer_batch_to_gpu
  return self.__transfer_batch_to_device(batch, device)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 164, in __transfer_batch_to_device
  return model.transfer_batch_to_device(batch, device)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 242, in transfer_batch_to_device
  return move_data_to_device(batch, device)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/utilities/apply_func.py", line 128, in move_data_to_device
  return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/utilities/apply_func.py", line 35, in apply_to_collection
  return function(data, *args, **kwargs)
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/pytorch_lightning/utilities/apply_func.py", line 103, in batch_to
  device_field = getattr(data, field).to(device, non_blocking=True)
AttributeError: 'tuple' object has no attribute 'to'
Exception ignored in: <function tqdm.__del__ at 0x7fcb5e0b2680>
Traceback (most recent call last):
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/tqdm/std.py", line 1086, in __del__
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/tqdm/std.py", line 1293, in close
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/tqdm/std.py", line 1471, in display
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/tqdm/std.py", line 1089, in __repr__
 File "/home1/thschaaf/miniconda3/envs/p37/lib/python3.7/site-packages/tqdm/std.py", line 1433, in format_dict
TypeError: cannot unpack non-iterable NoneType object

Code sample

import torch
from torch import nn, Tensor
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from torchtext import data
seed_everything(1234)
def get_debug_data_loader():
    text_field = data.Field(sequential=True, pad_first=False,
                            init_token="<s>", eos_token="</s>", include_lengths=True)
    example1 = data.example.Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
    example2 = data.example.Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
    example3 = data.example.Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
    dataset = data.Dataset([example1, example2, example3], {"text": text_field})
    text_field.build_vocab(dataset)
    iterator = data.Iterator(dataset, batch_size=3,
                             sort_key=None, device=None, batch_size_fn=None,
                             train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
    return iterator, text_field
class DebugModel(pl.LightningModule):
    def __init__(self):
        super(DebugModel, self).__init__()
        # setup data loader
        self.debug_data_loader, self.text_field = get_debug_data_loader()
        self.learning_rate = 0.001
        self.hid_dim = 4
        pad_idx = self.text_field.vocab.stoi['<pad>']
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        self.INPUT_DIM = len(self.text_field.vocab)
        self.ENC_EMB_DIM = 4  # keep it small for debugging
        self.embedding = nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM)
        self.rnn = nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False)
        self.out = nn.Linear(self.hid_dim, self.embedding.num_embeddings)
        self.OUTPUT_DIM = len(self.text_field.vocab)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    def forward(self, input_seq, length):
        embedded: Tensor = self.embedding(input_seq)
        packed_embedded: Tensor = torch.nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=False,
                                                                          enforce_sorted=False)
        packed_outputs, hidden = self.rnn(packed_embedded)  # [sent len, batch size, emb dim]
        outputs, length = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs)
        # outputs -> [sent len, batch size, hid dim * n directions]
        # hidden -> [n layers * n directions, batch size, hid dim]
        output = outputs.squeeze(0)
        prediction = self.out(output)
        return prediction
    @staticmethod
    def _parse_batch(batch):
        source = batch.text[0]
        source_length = batch.text[1]
        return source, source_length
    def training_step(self, batch, batch_nb):
        x = self._parse_batch(batch)
        target, target_length = x
        output = self.forward(target, target_length)
        loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1))
        prefix = 'train'
        tensorboard_logs = {f'{prefix}_loss': loss.item()}
        result = {'loss': loss, 'log': tensorboard_logs}
        return result
    def train_dataloader(self):
        return self.debug_data_loader
model = DebugModel()
cuda_device_cnt = torch.cuda.device_count()
if cuda_device_cnt > 0:
    use_num_cuda_devices = 1
else:
    use_num_cuda_devices = None
trainer = Trainer(fast_dev_run=False, max_steps=None,
                  gradient_clip_val=10,
                  weights_summary='full', gpus=use_num_cuda_devices,
                  show_progress_bar=True)
trainer.fit(model)
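Until the fix is merged, a possible workaround (a sketch, assuming torchtext's Batch exposes its field names via batch.fields; this is not the fix from the PR) is to override the transfer_batch_to_device hook seen in the traceback so tuple-valued attributes are moved element by element:

class DebugModelWithWorkaround(DebugModel):
    def transfer_batch_to_device(self, batch, device):
        # Move every field of the torchtext Batch to the target device,
        # handling the (sequences, lengths) tuples produced by include_lengths=True.
        for field in batch.fields:
            value = getattr(batch, field)
            if isinstance(value, tuple):
                setattr(batch, field, tuple(t.to(device) for t in value))
            else:
                setattr(batch, field, value.to(device))
        return batch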

Expected behavior

Should not raise an error :-)

Environment

* CUDA:
    - GPU:
        - TITAN X (Pascal)
    - available: True
    - version: 10.2
* Packages:
    - numpy: 1.17.3
    - pyTorch_debug: False
    - pyTorch_version: 1.5.1
    - pytorch-lightning: 0.8.5
    - tensorboard: 2.2.2
    - tqdm: 4.47.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
    - processor: x86_64
    - python: 3.7.4
    - version: #1 SMP Tue Mar 17 23:49:17 UTC 2020

Additional context

@thschaaf thschaaf added the bug and help wanted labels on Jul 24, 2020
thschaaf pushed a commit to thschaaf/pytorch-lightning that referenced this issue Jul 24, 2020
williamFalcon pushed a commit that referenced this issue Jul 31, 2020
* Test using torchtext.data.Field with include_lengths=True/False

* Fix handling of Tensors in a Batch generated by torchtext when torchtext.data.Field is configured with include_lengths=True

* Add description for fix of issue #2688

* changes to accommodate CodeFactor issues

* Another attempt to make the last CodeFactor issue pass (it's a false alarm)

* temporarily disable test of test_grad_tracking to check if testing will pass

* reenable test in test_grad_norm

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator as suggested by @Borda

* Update pytorch_lightning/utilities/apply_func.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* adding tests more specific to batch_move_data_to_device with torchtext Batch

* added check that Tensors were moved to target device

* removed tests using RNN models to be moved into a separate PR

* fixing FLAKE8 errors that showed up after merge from master branch
	modified:   tests/base/datamodules.py
	modified:   tests/callbacks/test_model_checkpoint.py

* parameterized test to reduce code duplication

* Added check only if length tensor exists. Removed left-over comments.

* rearranged device parameterization and added pytest.param

* Try to figure out why only one device is tested on Linux machines

* Testing on CPU and GPU devices (GPU test is skipped if no CUDA device is available).

* added test for TPU device (experimental)

* Adding test parameterization for TPU test (experimental)

* change import statement to limit what is imported for a TPU environment

* made test work with TPU

* Change to trigger CI

* Change to trigger CI

* uncommented TPU test to check CI

* reenabling TPU test

* small change to trigger CI build

* small change to trigger CI build

* small change to trigger CI build

* adding tests/utilities/test_apply_func_torchtext.py to CI TPU test

* try to make test not skipped on CI with TPU

* remove testing on TPU

* undo an accidental change to test_tpu.py (file should not have been touched)

* small change to trigger CI build

* small change to trigger CI build

* Update tests/utilities/test_apply_func_torchtext.py

* Revert to previous version

* Apply suggestions from code review

* Change to trigger CI

Co-authored-by: Thomas Schaaf <tschaaf@mmm.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Thomas Schaaf <tschaaf@cs.cmu.edu>