
Improve moving data / model to GPU using torchtext #1245

Closed · elkotito opened this issue Mar 26, 2020 · 28 comments · Fixed by #2379
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)

@elkotito (Contributor)

🚀 Feature

Improve moving data and model to GPU if torchtext is used.

Motivation

Case 1:

The Batch object generated by torchtext.data.Iterator doesn't follow the rules described here: https://github.com/PyTorchLightning/pytorch-lightning/blob/45d671a4a81788b9d97fd6b47763816926e58e95/pytorch_lightning/trainer/distrib_parts.py#L420

As a result, the data is not moved to the GPU when a torchtext.data.Iterator is returned by the train_dataloader method. Keep in mind that torchtext.data.Iterator has a device argument that is not properly utilized by pytorch-lightning.

    @ptl.data_loader
    def train_dataloader(self):
        ...  
        return Iterator(dataset=dataset, batch_size=self.batch_size, shuffle=False, device=DEVICE)

Partially reported in #226.

Case 2:

Using torchtext, you can read pre-trained embeddings and create an nn.Embedding object as follows:

    def train_dataloader(self):
        ...
        self.text_field.build_vocab(
            dataset,
            vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
        )

        self.embeddings = nn.Embedding(
            ...
            padding_idx=self.text_field.vocab.stoi[PAD_TOKEN],
            _weight=self.text_field.vocab.vectors.to(DEVICE),
        )

nn.Embedding clearly depends on self.text_field.vocab, which in turn depends on the dataset used by train_dataloader. Currently, any part of the model that is not fully created in __init__ of the ptl.LightningModule is not moved to the GPU. This still requires a global variable that determines the device, i.e. DEVICE, which makes Trainer(gpus=...) useless.

Pitch

I would like not to have to worry about moving data to the GPU when using torchtext combined with pytorch-lightning.

elkotito added the feature and help wanted labels on Mar 26, 2020
@Borda (Member) commented Mar 26, 2020

@mateuszpieniak which version of lightning are you actually using?
@PyTorchLightning/core-contributors suggestions? ^^

@elkotito (Contributor, Author)

@Borda Version == 0.6.0, but I see that there is a new release ^^ Let me check it with the new version.

  • A solution for Case 1 could be to check not only for a .to() or .cuda() attribute, but also for a .device field inside the class (see the sketch after this list).
  • A solution for Case 2 could be moving the model to the GPU a bit later in the code, but that is not so straightforward, since

https://github.com/PyTorchLightning/pytorch-lightning/blob/45d671a4a81788b9d97fd6b47763816926e58e95/pytorch_lightning/trainer/distrib_parts.py#L457

moves the model to the GPU and configures AMP straight away.
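
A rough sketch of what such duck-typed detection could look like (purely illustrative; the function name and checks are my own, not actual Lightning code):

    def can_be_moved_to_device(batch):
        """Hypothetical check: treat objects exposing .to()/.cuda() or a
        .device attribute as device-aware / movable."""
        return (
            callable(getattr(batch, "to", None))
            or callable(getattr(batch, "cuda", None))
            or hasattr(batch, "device")
        )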

@elkotito (Contributor, Author) commented Mar 26, 2020

@Borda Well, for version == 0.7.1 I have a slightly different issue.

As I mentioned before, I build a vocabulary using the training dataset in train_dataloader.

    def train_dataloader(self):
        ...
        self.text_field.build_vocab(
            dataset,
            vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
        )

This creates a self.text_field.vocab attribute. I use the same self.text_field (class torchtext.data.Field) for val_dataloader, which is expected since I want the same preprocessing / tokenization / embeddings for the validation and test sets. This time I got

AttributeError: 'Field' object has no attribute 'vocab'

because the sanity check for the validation set runs before training happens (or, to be precise, before train_dataloader is called) - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L883

I thought that I could disable this using fast_dev_run, but it seems that the condition doesn't work properly - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py#L879
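
A possible workaround for the sanity check itself, assuming a Lightning version where the argument is called num_sanity_val_steps, is to disable the check explicitly:

    from pytorch_lightning import Trainer

    # skip the pre-training validation sanity check so that val_dataloader
    # is not touched before train_dataloader has built the vocabulary
    trainer = Trainer(num_sanity_val_steps=0)  # plus your usual Trainer arguments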

@Borda (Member) commented Mar 26, 2020

@neggert @srush pls ^^

@zeeshansayyed commented Apr 26, 2020

@Borda I am still not able to get torchtext iterators to work with pytorch-lightning, even on a single GPU, as @mateuszpieniak mentioned in Case 1 above and in #226.

The (very bad) workaround I have been using so far is to manually move the data to a GPU while creating the iterator (by passing the device to the torchtext iterator). This prevents me from utilizing multiple GPUs and also unnecessarily loads all the data onto the GPU at the start. Has there been any progress on this issue?

Thanks

@awaelchli (Member) commented Apr 30, 2020

Sorry, I don't know torchtext well, but if I understand correctly, the problem is that the internal method that moves the data to the GPU checks for some known types (tuple, list, dict), but when a custom batch object is passed, it does not know how to move it.

My pitch:
Introduce a model hook (or callback, not sure?).
__transfer_data_to_device will call the hook first:

    def __transfer_data_to_device(self, batch, device):
        batch = model.transfer_batch_to_device(batch, device)  # call the hook first
        if batch is not None:  # for example, it returns None if the user did not override it
            return batch

        # otherwise continue with the built-in detection of tuples, lists, dicts, etc.

This allows anyone to implement their own transfers for custom objects.
Maybe we should still support torchtext out of the box, but if we can't, this could be an option.
@mateuszpieniak would such a model hook help you solve the problem with the data loading and moving?
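
For the torchtext case, a user-side override of the pitched hook might look roughly like this (a sketch only, following the hook signature proposed above and assuming every field on the Batch holds a plain tensor):

    import pytorch_lightning as pl
    from torchtext.data import Batch


    class MyModel(pl.LightningModule):
        def transfer_batch_to_device(self, batch, device):
            # move every field tensor of a torchtext Batch to the target device
            if isinstance(batch, Batch):
                for field in batch.input_fields + batch.target_fields:
                    setattr(batch, field, getattr(batch, field).to(device))
                return batch
            return None  # fall back to the built-in handling of tuples, lists, dicts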

@elkotito (Contributor, Author) commented May 3, 2020

@awaelchli You understood Case 1 correctly. Well, it should work 😃 I wasn't aware of the hook concept in PyTorch Lightning. I'll let you know whether it works. Thanks!

@elkotito (Contributor, Author) commented May 6, 2020

Case 1) @awaelchli I ended up creating my own custom Trainer and overriding a private method:

    def _TrainerDPMixin__transfer_data_to_device(self, batch, device, gpu_id=None):
        """ Handles torchtext.data.Batch. """
        ...
        # when torchtext.data.Batch
        if isinstance(batch, Batch):
            for field in batch.input_fields + batch.target_fields:
                data = self._TrainerDPMixin__transfer_data_to_device(getattr(batch, field), device, gpu_id)
                setattr(batch, field, data)
            return batch

Case 2) I just build the vocabulary in the __init__ method, roughly as sketched below.
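
For reference, a minimal sketch of what that looks like (the class name and constructor arguments are hypothetical; PAD_TOKEN and the vectors path are reused from the snippets above):

    import torch.nn as nn
    import pytorch_lightning as pl
    from torchtext.vocab import Vectors

    PAD_TOKEN = "<pad>"  # assumed to match the Field's pad token


    class TextClassifier(pl.LightningModule):
        def __init__(self, text_field, train_dataset):
            super().__init__()
            # build the vocabulary eagerly, so the embedding layer exists
            # before the trainer moves the model to the GPU
            text_field.build_vocab(
                train_dataset,
                vectors=Vectors("/data/embeddings/glove/glove.840B.300d.txt"),
            )
            self.text_field = text_field
            vectors = text_field.vocab.vectors
            self.embeddings = nn.Embedding(
                num_embeddings=vectors.size(0),
                embedding_dim=vectors.size(1),
                padding_idx=text_field.vocab.stoi[PAD_TOKEN],
                _weight=vectors,  # no .to(DEVICE) needed anymore
            )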

If we don't plan to support torchtext, then I think we can consider this issue closed.

@awaelchli (Member)

Case 1: that certainly works. If we introduced a hook for that, the user of PL would not need to modify the Trainer source code, which is not meant to be modified.

Case 2: If you want a tensor in your model to be moved with the model (when .to(...) is called on the model), register it as a buffer:
https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer
It can be accessed just like a regular attribute of the model, and there is no need to pass in the device.
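
A minimal sketch of the buffer approach, assuming the pre-trained vectors are only looked up and not trained:

    import torch
    import torch.nn as nn


    class BagOfEmbeddings(nn.Module):
        def __init__(self, vectors: torch.Tensor):
            super().__init__()
            # buffers are moved by model.to(device) / model.cuda() and saved in
            # the state_dict, but they are not treated as trainable parameters
            self.register_buffer("pretrained_vectors", vectors)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            # the buffer already lives on the model's device, no DEVICE global needed
            embedded = self.pretrained_vectors[token_ids]  # (batch, seq_len, dim)
            return embedded.mean(dim=1)  # simple bag-of-embeddings pooling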

@awaelchli (Member) commented May 6, 2020

If we don't plan to support torchtext, then I think we can consider this issue closed.

I would not close it yet. @PyTorchLightning/core-contributors should torchtext be supported directly in PL, or do we provide a hook as I suggested, or both?

@Borda (Member) commented May 6, 2020

@williamFalcon ^^

@williamFalcon (Contributor)

  1. let’s introduce a formal hook
  2. let’s formally support torchtext!

good ideas everyone!

@jeremyjordan (Contributor)

@awaelchli I would prefer a hook since that will be generally useful and will likely reduce the long-term maintenance burden.

@awaelchli (Member) commented Jun 7, 2020

@mateuszpieniak the model hook I added should make it easier to work with the Batch object in torchtext, but I have not actually tested it yet in a full example.
Do you by chance already have a mini example for torchtext with PL? I have zero experience with torchtext, but I will look into it to provide better support.

@elkotito (Contributor, Author) commented Jun 8, 2020

@awaelchli Here is a super simple classification example with a bag of embeddings. It should work on both GPU and CPU.

https://gist.github.com/mateuszpieniak/f290b3a727db7e94b9da0bd3bd2e33c1

@awaelchli (Member)

Great! I can start with that!

@Borda (Member) commented Jun 18, 2020

Nice. @mateuszpieniak, if you want, you can take it, and @awaelchli may assist you :]

@elkotito (Contributor, Author)

@Borda Sure, I will do this over the weekend unless @awaelchli has already started working on it :]

@awaelchli (Member)

I have not started yet, and I don't mind you taking it if you are interested.

@celsofranssa

Great discussion.
Currently, is it possible or not to use PyTorch Lightning for multi-GPU training?

@williamFalcon (Contributor)

? Of course it's possible... this is at the core of Lightning...

If you specifically mean torchtext, then it sounds like there's some edge case it introduces that we are working on.

@williamFalcon (Contributor)

@mateuszpieniak we're preparing 0.8.2 now. Want to submit this PR?

@celsofranssa

? Of course it's possible... this is at the core of Lightning...

If you specifically mean torchtext, then it sounds like there's some edge case it introduces that we are working on.

Sorry, indeed I meant multi-GPU training with torchtext.

@elkotito (Contributor, Author)

@williamFalcon @ceceu Sorry for the delay. I've just submitted a PR, so @awaelchli please take a look.

@ceceu I don't think it works for multi-GPU due to a different issue, #2350.

@celsofranssa

Thanks @mateuszpieniak,

But it’s sad to hear that.

@awaelchli (Member)

We might be able to override / extend the scatter and gather in DataParallel to make use of the transfer-data hook, but I'm not sure if that's safe... I haven't had the time to look into that. I agree the hook right now is not very useful :(

@awaelchli (Member)

@mateuszpieniak maybe it was a bit premature to add an extra implementation to move torchtext.Batch data to the device. It seems they are dropping the Batch class from torchtext; it is outdated.
pytorch/text#861
It seems they have a different, simpler data loading pipeline now, and the batches can be moved to the GPU directly.
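
For illustration, with the newer pipeline a batch can come out of a plain torch.utils.data.DataLoader as tensors, which the existing transfer logic already handles (a rough sketch; the dataset is assumed to yield (token_ids, label) pairs):

    import torch
    from torch.utils.data import DataLoader


    def pad_collate(batch, pad_idx=0):
        """Pad variable-length token id lists into a single LongTensor batch."""
        token_ids, labels = zip(*batch)
        max_len = max(len(ids) for ids in token_ids)
        padded = torch.full((len(batch), max_len), pad_idx, dtype=torch.long)
        for i, ids in enumerate(token_ids):
            padded[i, : len(ids)] = torch.as_tensor(ids, dtype=torch.long)
        return padded, torch.as_tensor(labels)


    # `dataset` is hypothetical; the point is that the loader yields plain tensors
    # train_loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)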

@elkotito (Contributor, Author) commented Jul 7, 2020

@awaelchli Perhaps it was. The actual need for Batch comes from BucketIterator. In newer versions of PyTorch, they introduced the batch_sampler and sampler arguments for DataLoader, so it feels natural to implement bucket iteration with that abstraction (a rough sketch follows below). For example, AllenNLP moved from its Iterator class to PyTorch's DataLoader. Well, we will see how they resolve it :)
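
As an illustration of that idea, here is a rough sketch of length-bucketing with a plain PyTorch batch sampler (the class name and details are hypothetical, not any library's actual API):

    import random

    from torch.utils.data import Sampler


    class BucketBatchSampler(Sampler):
        """Group indices of similar length into batches, then shuffle the batches."""

        def __init__(self, lengths, batch_size):
            self.lengths = lengths
            self.batch_size = batch_size

        def __iter__(self):
            # sort indices by sequence length so padding inside a batch is minimal
            order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
            batches = [order[i:i + self.batch_size]
                       for i in range(0, len(order), self.batch_size)]
            random.shuffle(batches)  # shuffle the order of batches, not their contents
            yield from batches

        def __len__(self):
            return (len(self.lengths) + self.batch_size - 1) // self.batch_size


    # usage sketch: DataLoader(dataset, batch_sampler=BucketBatchSampler(lengths, 32), collate_fn=...)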
