_has_len does not handle NotImplementedError (raised by torchtext) #2277

Closed · thschaaf opened this issue Jun 19, 2020 · 3 comments · Fixed by #2293 or #2307
Labels: bug (Something isn't working), help wanted (Open to be worked on)

thschaaf commented Jun 19, 2020

🐛 Bug

When using torchtext.data.Iterator with a batch_size_fn function, calling len raises a NotImplementedError which is not caught by the _has_len function.

A bug-fix is very simple: just return False if a NotImplementedError is raised. This is unlikely to have any negative side effects, since it corresponds to what _has_len is expected to do. The fix allowed me to train my model using torchtext.

I plan to submit a pull request with the fix above.
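
For concreteness, a minimal sketch of the fix, assuming the current structure of _has_len in pytorch_lightning/trainer/data_loading.py (the exact wording of the error message is illustrative):

def _has_len(dataloader) -> bool:
    """Return True if the dataloader implements a usable __len__, else False."""
    try:
        # raises TypeError if the dataloader defines no __len__ at all
        if len(dataloader) == 0:
            raise ValueError('Dataloader returned 0 length. Please make sure'
                             ' that it returns at least 1 batch')
        return True
    except TypeError:
        # iterable-style dataloader without __len__
        return False
    except NotImplementedError:
        # e.g. torchtext's Iterator raises this when a batch_size_fn is set
        return False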

There are no additional dependencies required; however, this problem occurred when using torchtext.

Example stack trace:

Traceback (most recent call last):
  File "/Users/thomas/scm/OakDataPrep/oakSkipThoughtTrainer.py", line 18, in <module>
    trainer.fit(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 952, in fit
    self.run_pretrain_routine(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 334, in train
    self.reset_train_dataloader(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 201, in reset_train_dataloader
    if not _has_len(self.train_dataloader):
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 49, in _has_len
    if len(dataloader) == 0:
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/torchtext/data/iterator.py", line 136, in __len__
    raise NotImplementedError
NotImplementedError

To Reproduce

Sorry, I currently don't have a minimal example. The issue always occurs when a batch_size_fn is passed to torchtext.data.Iterator. If the fix is not convincing, I can take the time to construct a full code example; I hope this is not necessary.
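
For illustration, a hypothetical minimal sketch that should hit the same code path (the toy field, dataset, and batch_size_fn are invented for this example; torchtext 0.6 API):

import torchtext
from torchtext import data

# Toy field and dataset, just enough to construct an Iterator.
TEXT = data.Field()
fields = [("text", TEXT)]
examples = [data.Example.fromlist([s], fields)
            for s in ["a b c", "d e f g", "h i"]]
dataset = data.Dataset(examples, fields)
TEXT.build_vocab(dataset)

# Any batch_size_fn triggers the problem; this one simply counts examples.
it = data.Iterator(dataset, batch_size=2,
                   batch_size_fn=lambda new, count, sofar: count)

len(it)  # raises NotImplementedError, which _has_len does not catch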

Code sample

I created my own Iterator for a Skip-Thought model that dynamically batches sentences together. This might be unnecessarily complex, or even not particularly useful, but it revealed the issue described above when using torchtext. For context, here is a code excerpt that triggers the issue:

import torchtext
from torchtext import data
import pytorch_lightning as pl
...
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.current))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.next) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

class MyIterator(torchtext.data.Iterator):

    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                            sorted(p, key=self.sort_key),
                            self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)

        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

...
class SkipThoughts(pl.LightningModule):
...
    @pl.data_loader
    def train_dataloader(self):
        train_iter = MyIterator(self.my_train_dataloader, batch_size=self.batch_size, repeat=False,
                                   sort_key=lambda x:
                                   data.interleave_keys(len(x.current),
                                                        data.interleave_keys(len(x.prev),
                                                                             len(x.next))),
                                   batch_size_fn=batch_size_fn, train=True,
                                   shuffle=True)
        return train_iter

But this happens whenever a batch_size_fn is used in torchtext. Because it is unknown how many batches the dataset will yield, torchtext's __len__ method raises a NotImplementedError. See the code snippet below (from torchtext/data/iterator.py):

def __len__(self):
    if self.batch_size_fn is not None:
        raise NotImplementedError
    return math.ceil(len(self.dataset) / self.batch_size)

Expected behavior

The function _has_len tests whether len is available and returns True if so, otherwise False. It should also return False if NotImplementedError is raised.
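
A quick sanity check of that expectation, reusing the toy iterator it from the sketch above (the fixed _has_len is assumed to be in scope):

import torch.utils.data

# A plain DataLoader implements __len__, so _has_len should return True.
loader = torch.utils.data.DataLoader(list(range(4)), batch_size=2)
assert _has_len(loader)

# torchtext's Iterator with a batch_size_fn raises NotImplementedError in
# __len__; the fixed _has_len should return False instead of crashing.
assert not _has_len(it)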

Environment

/Users/thomas/virtualenv/Python3/PyTorch/env/bin/python /Users/thomas/scm/OakDataPrep/collect_env_details.py

  • CUDA:
    • GPU:
    • available: False
    • version: None
  • Packages:
    • numpy: 1.18.2
    • pyTorch_debug: False
    • pyTorch_version: 1.5.0
    • pytorch-lightning: 0.8.0
    • tensorboard: 2.2.0
    • tqdm: 4.45.0
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: i386
    • python: 3.7.7
    • version: Darwin Kernel Version 19.5.0: Tue May 26 20:41:44 PDT 2020; root:xnu-6153.121.2~2/RELEASE_X86_64


Additional context

The issue occurs with PyTorch Lightning 0.8 and torchtext 0.6.

@github-actions commented:

Hi! Thanks for your contribution, great first issue!

@williamFalcon commented:

Good catch! Mind submitting a PR?

@thschaaf (author) commented:

Absolutely. The fix is easy. How important is a test for that case?
