
Process runs on more GPUs than specified #958

Closed
sahnimanas opened this issue Feb 26, 2020 · 7 comments · Fixed by #1349, #2029 or #6934
Labels: help wanted

Comments

@sahnimanas

sahnimanas commented Feb 26, 2020

I have a single 8-GPU machine with a faulty GPU0.
I'm running imagenet_example.py on 7 GPUs on this machine by specifying gpus=[1,2,3,4,5,6,7] in the Trainer, i.e. I do not want to use GPU0.

However, when I run nvidia-smi, I see the Trainer's PID on all 8 GPUs, just with lower memory usage on GPU0 (see output below). I also find it to be about 4x slower than the equivalent non-Lightning code. I don't see this behavior if I manually set CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 and pass gpus=7 to the Trainer. Similarly, it works fine when using a single GPU with, say, gpus=[1].
I'm not sure if it's relevant, but I also see gpu=0 in the tqdm progress bar.
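
For reference, a minimal sketch of the two configurations being compared (assuming the 8-GPU machine above; the rest of the ImageNet example is omitted):

import pytorch_lightning as pl

# Case 1: select GPUs 1-7 via the Trainer only.
# The Trainer's PID still shows up on GPU 0 in nvidia-smi (see below).
trainer = pl.Trainer(gpus=[1, 2, 3, 4, 5, 6, 7])

# Case 2: hide GPU 0 from the process before launch, then request 7 GPUs.
# Run as: CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 python imagenet_example.py
trainer = pl.Trainer(gpus=7)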

nvidia-smi with Trainer(gpus=[1,2,3,4,5,6,7]) and CUDA_VISIBLE_DEVICES unset

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     40155      C   python                                       719MiB |
|    1     40155      C   python                                      6003MiB |
|    2     40155      C   python                                      6019MiB |
|    3     40155      C   python                                      6019MiB |
|    4     40155      C   python                                      6019MiB |
|    5     40155      C   python                                      6019MiB |
|    6     40155      C   python                                      6019MiB |
|    7     40155      C   python                                      6019MiB |
+-----------------------------------------------------------------------------+

nvidia-smi with Trainer(gpus=7) and CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    1     34452      C   python                                      6003MiB |
|    2     34452      C   python                                      6019MiB |
|    3     34452      C   python                                      6019MiB |
|    4     34452      C   python                                      6019MiB |
|    5     34452      C   python                                      6019MiB |
|    6     34452      C   python                                      6019MiB |
|    7     34452      C   python                                      6019MiB |
+-----------------------------------------------------------------------------+

Expected behavior

The process should run only on the specified GPUs, without needing to manually set CUDA_VISIBLE_DEVICES.

Environment

PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti

Nvidia driver version: 418.87.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] pytorch-lightning==0.6.0
[pip] torch==1.4.0
[pip] torch-lr-finder==0.1.2
[pip] torchvision==0.5.0
[conda] blas                      1.0                         mkl
[conda] mkl                       2020.0                      166
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.0.15           py38ha843d7b_0
[conda] mkl_random                1.1.0            py38h962f231_0
[conda] pytorch                   1.4.0           py3.8_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] pytorch-lightning         0.6.0                    pypi_0    pypi
[conda] torch-lr-finder           0.1.2                    pypi_0    pypi
[conda] torchvision               0.5.0                py38_cu101    pytorch
sahnimanas added the bug and help wanted labels on Feb 26, 2020
@github-actions
Contributor

Hey, thanks for your contribution! Great first issue!

@Borda
Member

Borda commented Feb 26, 2020

Thanks for the report. I do not think the training is actually running on GPU0, just some memory is being allocated there... Could you also share the GPU utilization during training?

Borda added the question and information needed labels and removed the bug label on Feb 26, 2020
@sahnimanas
Author

I also think the training is likely not running on GPU0, but I'm not sure why the PID shows up on it.
Here's the full output of nvidia-smi during training:

Wed Feb 26 18:44:17 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 208...  On   | 00000000:3D:00.0 Off |                  N/A |
| 31%   41C    P8    33W / 250W |    730MiB / 10989MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  On   | 00000000:3E:00.0 Off |                  N/A |
| 31%   48C    P2    99W / 250W |   6014MiB / 10989MiB |     14%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce RTX 208...  On   | 00000000:60:00.0 Off |                  N/A |
| 31%   50C    P2    98W / 250W |   6030MiB / 10989MiB |     13%      Default |
+-------------------------------+----------------------+----------------------+
|   3  GeForce RTX 208...  On   | 00000000:61:00.0 Off |                  N/A |
| 30%   45C    P2    73W / 250W |   6030MiB / 10989MiB |     14%      Default |
+-------------------------------+----------------------+----------------------+
|   4  GeForce RTX 208...  On   | 00000000:B1:00.0 Off |                  N/A |
| 32%   45C    P2    73W / 250W |   6030MiB / 10989MiB |     14%      Default |
+-------------------------------+----------------------+----------------------+
|   5  GeForce RTX 208...  On   | 00000000:B2:00.0 Off |                  N/A |
| 32%   45C    P2    68W / 250W |   6030MiB / 10989MiB |     14%      Default |
+-------------------------------+----------------------+----------------------+
|   6  GeForce RTX 208...  On   | 00000000:DA:00.0 Off |                  N/A |
| 31%   51C    P2    92W / 250W |   6030MiB / 10989MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   7  GeForce RTX 208...  On   | 00000000:DB:00.0 Off |                  N/A |
| 31%   44C    P2    99W / 250W |   6030MiB / 10989MiB |     13%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     35809      C   python                                       719MiB |
|    1     35809      C   python                                      6003MiB |
|    2     35809      C   python                                      6019MiB |
|    3     35809      C   python                                      6019MiB |
|    4     35809      C   python                                      6019MiB |
|    5     35809      C   python                                      6019MiB |
|    6     35809      C   python                                      6019MiB |
|    7     35809      C   python                                      6019MiB |
+-----------------------------------------------------------------------------+

@sahnimanas
Author

BTW, I switched the distributed backend from dp (the default) to ddp and the issue went away. No PID is shown on GPU0 and its memory usage is 11MiB (the same as any other idle GPU).
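
A minimal sketch of that switch, assuming the Trainer API of pytorch-lightning 0.6.x/0.7.x, where distributed_backend was the argument name:

import pytorch_lightning as pl

# With the default dp backend an extra ~700MiB allocation appeared on GPU 0;
# with ddp it did not.
trainer = pl.Trainer(gpus=[1, 2, 3, 4, 5, 6, 7], distributed_backend='ddp')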

@Borda
Member

Borda commented Feb 27, 2020

OK, in that case I will consider it resolved, but feel free to reopen it if you need to 🤖

@Borda Borda closed this as completed Feb 27, 2020
shubhamagarwal92 added a commit to shubhamagarwal92/pytorch-lightning that referenced this issue Mar 8, 2020
shubhamagarwal92 added a commit to shubhamagarwal92/pytorch-lightning that referenced this issue Mar 12, 2020
shubhamagarwal92 added a commit to shubhamagarwal92/pytorch-lightning that referenced this issue Mar 15, 2020
williamFalcon added a commit that referenced this issue Apr 3, 2020
* SA: for #958: set torch cuda device when finding root

* SA: for #958: removing root gpu hack in trainer/evaluation_loop

* SA: setting torch cuda device

* comment line too long

* check if root gpu exists or available

* Incorporating suggestions on #1094

* since root gpu returns none instead of -1 for cpu

* undo changes

* fixed dp memory thing

Co-authored-by: Shubham Agarwal <shubhamagarwal92@gmail.com>
alexeykarnachev pushed a commit to alexeykarnachev/pytorch-lightning that referenced this issue Apr 4, 2020
@yakobyd

yakobyd commented Apr 14, 2020

I cloned the repository yesterday (pytorch-lightning==0.7.4.dev0) and there are still some edge cases that are not fixed by #1349. Below is a minimal reproduction:

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl


class Model(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        return {'avg_loss': avg_loss}

    def train_dataloader(self):
        data = torch.rand(4096, 1000)
        labels = torch.randint(high=10, size=(4096,))
        return DataLoader(list(zip(data, labels)), batch_size=64, pin_memory=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


trainer = pl.Trainer(gpus=[3])
model = Model()

trainer.fit(model)

After running the above code, nvidia-smi outputs the following:

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     26657      C   python                                       479MiB |
|    3     26657      C   python                                       482MiB |
+-----------------------------------------------------------------------------+

I have tested a few scenarios, and found that this is caused by two factors:

  • In the DataLoader, pin_memory==True.
  • Presence of training_epoch_end.

These factors must happen together for the problem to arise. For example, if pin_memory==True, but training_epoch_end is not implemented, then the GPU memory does not leak.

Similarly, the problem happens if the validation phase is defined together with validation_epoch_end and the corresponding validation DataLoader has pin_memory==True. Moreover, the GPU memory also leaks when the training DataLoader sets pin_memory==True and validation_epoch_end is defined.
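
For example, a minimal sketch of a workaround following from the factors above: keep the Model from the reproduction unchanged, but disable pin_memory in its train_dataloader.

    def train_dataloader(self):
        data = torch.rand(4096, 1000)
        labels = torch.randint(high=10, size=(4096,))
        # With pin_memory=False the extra allocation on GPU0 does not appear,
        # even though training_epoch_end is still implemented.
        return DataLoader(list(zip(data, labels)), batch_size=64, pin_memory=False)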

I am afraid I will not have time to dig deeper here, but hopefully the maintainers will find this useful.

@Borda Borda reopened this Apr 14, 2020
@jiahuei

jiahuei commented May 10, 2020

This issue occurred for me during the validation sanity check even with pin_memory==False.
My validation_epoch_end is defined.

I had to use this as a temporary fix:

import os

import pytorch_lightning as pl

if __name__ == "__main__":
    args = parse_arguments()
    # Hide the unwanted GPUs before anything initializes CUDA,
    # and make the device ordering match nvidia-smi.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, args.gpus))
    # Re-index the requested GPUs, since only the visible ones remain.
    args.gpus = list(range(len(args.gpus)))
    ...
    trainer = pl.Trainer(gpus=args.gpus)

However, if I call torch.cuda.device_count() beforehand, like so, the issue still occurs. Something seems to happen when device_count() is called (presumably CUDA reads CUDA_VISIBLE_DEVICES when it is first initialized, so setting the variable afterwards has no effect).

    if len(args.gpus) != torch.cuda.device_count():
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, args.gpus))
        args.gpus = list(range(len(args.gpus)))
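
A minimal sketch of the ordering constraint, assuming an 8-GPU machine like the one in this issue: the environment variable has to be set before anything initializes CUDA.

import os
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6,7'
# CUDA has not been initialized yet, so the mask takes effect:
print(torch.cuda.device_count())  # 7

# If torch.cuda.device_count() had been called before setting the variable,
# CUDA would already have been initialized and the mask would be ignored.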

tullie pushed a commit to tullie/pytorch-lightning that referenced this issue Jun 7, 2020
Borda removed the question label on Dec 23, 2020