DDP not moving batch to device? #4987

Closed
cccntu opened this issue Dec 6, 2020 · 14 comments
Labels
distributed: Generic distributed-related topic
duplicate: This issue or pull request already exists
won't fix: This will not be worked on

Comments

cccntu (Contributor) commented Dec 6, 2020

Hi, I am using 1.0.8. I encountered an error saying the input is not on the same device as the model. I printed the inputs and found that they are on the CPU. I noticed that the code below does not move the inputs to the device.
https://github.com/PyTorchLightning/pytorch-lightning/blob/0979e2ce0f04ffa4facc13f08dc8d1612cdeae3e/pytorch_lightning/accelerators/ddp_accelerator.py#L153-L159
I added two lines of code and it seems to run:

batch = args[0]
args[0] = self.batch_to_device(batch, self.trainer.model.device)

But the loss doesn't update during the epoch and becomes NaN after one epoch.
Is there a simple ddp example I can run?
Thanks!

github-actions bot commented Dec 6, 2020

Hi! Thanks for your contribution, great first issue!

SeanNaren (Contributor) commented Dec 6, 2020

Hi @cccntu, DDP moves the batch to the device internally, which is why this step is missing from the accelerator code.

The examples all work with DDP; have a look here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py

If you're able to replicate the error using the bug report model, we'll be able to help you get to a solution! https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py
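For reference, a minimal self-contained sketch in the spirit of bug_report_model.py (the RandomDataset/BoringModel names and sizes are placeholders; it assumes 2 GPUs and the Trainer API of this release line, where DDP is selected via accelerator='ddp'):

```python
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # With plain tensors, the batch already arrives on the right GPU under DDP.
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(gpus=2, accelerator="ddp", max_epochs=1)
    trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=8))
```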

SeanNaren added the "distributed" label Dec 6, 2020
cccntu (Contributor, Author) commented Dec 6, 2020

Hi @SeanNaren, thanks for the reply.

DDP moves the batch to the device internally, hence why this is missing from the accelerator code.

What do you mean by internally? Can you give me some pointers? I think there should be a section in the documentation for questions like "How does PL do this internally?", with explanations and links to the actual code.

My current guess is that I am using Hugging Face's BatchEncoding as the input to training_step(). BatchEncoding supports the .to() method, but "internally" PL does not use this method, and that's why I needed to move it myself. @SeanNaren, can you help answer this question?
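To illustrate what I mean (assuming the transformers tokenizer API; the model name is just an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["a sentence", "another one"], padding=True, return_tensors="pt")

print(type(batch).__name__)       # BatchEncoding
print(batch["input_ids"].device)  # cpu

batch = batch.to("cuda:0")        # works, because BatchEncoding defines .to()
print(batch["input_ids"].device)  # cuda:0
```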

Also, I think I found a bug with DDP:
DDP always uses the first k GPUs, despite other GPUs being specified:
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=ddp # this uses gpu 0,1
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=dp # this uses gpu 3,4

SeanNaren (Contributor) commented

That is strange! Internally we just use standard PyTorch DDP, which scatters inputs before passing them into the forward function: https://github.com/pytorch/pytorch/blob/v1.7.0/torch/nn/parallel/distributed.py#L617

This moves them to the current GPU device automatically before the forward pass. Strange that this isn't happening automatically for BatchEncoding. Could you give me some sample code to reproduce the error?
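In case it helps, this is roughly what that scatter does to plain tensors (a sketch, assuming at least one visible GPU; under DDP each process owns a single GPU, so device_ids contains just that one device):

```python
import torch
from torch.nn.parallel import scatter

batch = torch.randn(8, 4)                    # arrives from the DataLoader on the CPU
(moved,) = scatter(batch, target_gpus=[0])   # what DDP applies to forward() args per process
print(moved.device)                          # cuda:0
```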

awaelchli (Contributor) commented Dec 7, 2020

@SeanNaren The DDP built-in scatter only moves data to the devices for tensors and for tensors nested in collections like lists, tuples, etc.
https://github.com/pytorch/pytorch/blob/e85d494707b835c12165976b8442af54b9afcb26/torch/nn/parallel/scatter_gather.py#L5

In our accelerator base class we have this method
https://github.com/PyTorchLightning/pytorch-lightning/blob/e952dee2921506c80cc9fb93e8731d7ee137ce59/pytorch_lightning/accelerators/accelerator.py#L69
which calls our utility function
https://github.com/PyTorchLightning/pytorch-lightning/blob/471ca375babad9093abf60683a8d0647ac33d4a8/pytorch_lightning/utilities/apply_func.py#L92

This is currently only called for the single-GPU and TPU accelerators, not for the distributed accelerators. If @cccntu runs with those accelerators, I am sure their BatchEncoding objects will be moved correctly, because we automatically move everything that defines a .to method.

Related: #1206, #2350
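Roughly, the contrast looks like this (a sketch, not the exact implementation; FakeBatchEncoding stands in for the transformers class, which is a UserDict rather than a plain dict, and a visible GPU is assumed):

```python
import torch
from collections import UserDict
from torch.nn.parallel import scatter


class FakeBatchEncoding(UserDict):
    """Stand-in for transformers.BatchEncoding: a UserDict that defines .to()."""

    def to(self, device):
        return FakeBatchEncoding({k: v.to(device) for k, v in self.items()})


batch = FakeBatchEncoding(input_ids=torch.ones(8, 4, dtype=torch.long))

# DDP's scatter only recurses into tensors, lists, tuples and plain dicts;
# anything else is replicated as-is, so the tensors stay on the CPU.
(scattered,) = scatter(batch, target_gpus=[0])
print(scattered["input_ids"].device)  # cpu

# The single-GPU/TPU path instead calls .to() on anything that defines it
# (a simplified version of what move_data_to_device / apply_to_collection do).
def move_if_possible(obj, device):
    return obj.to(device) if hasattr(obj, "to") else obj

moved = move_if_possible(batch, torch.device("cuda:0"))
print(moved["input_ids"].device)      # cuda:0
```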

awaelchli (Contributor) commented

Also, I think I found a bug with DDP:
DDP always uses the first k GPUs, despite other GPUs being specified:
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=ddp # this uses gpu 0,1
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=dp # this uses gpu 3,4

This should be accelerator=ddp.
How are you verifying which GPUs it runs on? Please use nvidia-smi. If you just look at the global rank, it will always be in the range 0, ..., gpus-1.

cccntu (Contributor, Author) commented Dec 7, 2020

I am working on a minimal example, but I think @awaelchli is right. Still, batch_to_device doesn't scatter the inputs in a BatchEncoding across devices, right?

awaelchli (Contributor) commented Dec 7, 2020

You don't necessarily need to work on a reproducible example, since this is not a bug but rather a limitation of scattering; we are aware of it. But if you want to, it is certainly appreciated.

Still, batch_to_device doesn't scatter the inputs in a BatchEncoding across devices, right?

No, how could it? Custom Python objects like BatchEncoding don't have a batch size / batch dimension, so scattering is not defined for them. As proposed in #1206 and #2350, we need a way for the user to define scatter and gather for these objects.
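In the meantime, one possible workaround (a sketch only, not an official recommendation; the "text" key and the tokenizer/dataset names are hypothetical) is to hand DDP a plain dict of tensors, which its built-in scatter does know how to split:

```python
from functools import partial
from torch.utils.data import DataLoader


def collate_to_plain_dict(examples, tokenizer):
    # `examples` are raw dataset items; the "text" key is hypothetical.
    enc = tokenizer([ex["text"] for ex in examples], padding=True, return_tensors="pt")
    return dict(enc)  # plain dict of tensors, handled by DDP's built-in scatter


# loader = DataLoader(dataset, batch_size=8,
#                     collate_fn=partial(collate_to_plain_dict, tokenizer=tokenizer))
```

training_step then receives a plain dict instead of a BatchEncoding.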

awaelchli added the "duplicate" label Dec 7, 2020
cccntu (Contributor, Author) commented Dec 7, 2020

@awaelchli Thanks for the explanation and links. Here is the reproducible example I wrote for future reference. https://gist.github.com/cccntu/967d9624d37024875e6cd094d2bf13ae

Also, I think I found a bug with DDP:
DDP always uses the first k GPUs, despite other GPUs being specified:
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=ddp # this uses gpu 0,1
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --acce=dp # this uses gpu 3,4

This should be accelerator=ddp.
How are you verifying which GPUs it runs on? Please use nvidia-smi. If you just look at the global rank, it will always be in the range 0, ..., gpus-1.

I just checked again using nvidia-smi. It seems the computation does run on GPUs 3 and 4; however, it also occupies approximately the same amount of memory on GPUs 0 and 1.
I had a process running on GPU 0, which caused an OOM, so I assumed it was running on GPUs 0 and 1.
python pl_examples/basic_examples/simple_image_classifier.py --gpus=3,4 --accelerator=ddp

awaelchli (Contributor) commented

The OOM may be related to #4705.
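A generic mitigation in the meantime (just a sketch and an assumption on my part, not something verified here) is to restrict the visible devices before CUDA is initialized, so the other GPUs are never touched:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"  # set before the first CUDA call

import pytorch_lightning as pl
# trainer = pl.Trainer(gpus=2, accelerator="ddp")  # GPUs 3,4 now appear as cuda:0, cuda:1
```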

stale bot commented Jan 6, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the "won't fix" label Jan 6, 2021
SeanNaren (Contributor) commented Jan 6, 2021

This is fixed by #5195 for single device/single process (like DDP).

stale bot removed the "won't fix" label Jan 6, 2021
stale bot commented Feb 6, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the "won't fix" label Feb 6, 2021
SeanNaren (Contributor) commented

This should be fixed via #5195!
