
Using the Trainer class more than once fails with "Address already in use" with the DDP backend #2537

Closed
JanSellner opened this issue Jul 7, 2020 · 14 comments · Fixed by #2997
Labels
bug (Something isn't working), help wanted (Open to be worked on)

Comments

@JanSellner
Contributor

🐛 Bug

It is not possible to create and use the Trainer class more than once with the DDP backend since the program crashes the second time with RuntimeError: Address already in use.

To Reproduce

Here is a minimal code example which reproduces the issue:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __getitem__(self, index: int):
        return torch.rand(10), torch.rand(10).type(torch.int64)

    def __len__(self) -> int:
        return 10


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y, y_hat)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MyDataset()
        loader = DataLoader(dataset, batch_size=2, shuffle=True)
        return loader


def train():
    model = LitModel()

    trainer = Trainer(gpus=2, distributed_backend="ddp", max_epochs=2)
    trainer.fit(model)


train()
train()  # This second call fails

Expected behavior

It should be possible to use the Trainer object multiple times. In my case, it breaks my k-fold validation loop since I create a new Trainer for each fold.

Environment

* CUDA:
        - GPU:
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
        - available:         True
        - version:           10.2
* Packages:
        - numpy:             1.18.5
        - pyTorch_debug:     False
        - pyTorch_version:   1.5.0
        - pytorch-lightning: 0.8.4
        - tensorboard:       2.2.2
        - tqdm:              4.46.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.3
        - version:           #109-Ubuntu SMP Fri Jun 19 11:33:10 UTC 2020
@JanSellner added the bug and help wanted labels on Jul 7, 2020
@SkafteNicki
Member

Your problem seems similar to #401.
Could you try the solution proposed in that issue, i.e. setting the master port yourself?

@JanSellner
Contributor Author

You mean something like this?

import os

os.environ['MASTER_PORT'] = "44513"
train()
os.environ['MASTER_PORT'] = "44514"
train()

I don't get the RuntimeError: Address already in use anymore, but now the program hangs completely. There is no error message at all, yet the program is still running and both ports are occupied:

(base) ➜  ~ sudo netstat -tulpn | grep 4451
tcp        0      0 0.0.0.0:44513           0.0.0.0:*               LISTEN      21791/python
tcp        0      0 0.0.0.0:44514           0.0.0.0:*               LISTEN      21791/python

@awaelchli
Contributor

Check master: this PR #2512 chooses the port randomly. I think that would solve your issue, but I'm not sure :)

@JanSellner
Contributor Author

Using the latest master (and the latest PyTorch), I now get yet another error:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
----------------------------------------------------------------------------------------------------
distributed_backend=ddp
All DDP processes registered. Starting ddp with 2 processes
----------------------------------------------------------------------------------------------------

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 110
/home/*****/Downloads/pytorch-lightning/pytorch_lightning/utilities/distributed.py:25: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 671.98it/s, loss=0.058, v_num=6]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
Traceback (most recent call last):
  File "/home/*****/Downloads/ddp_issue.py", line 50, in <module>
    train()
  File "/home/*****/Downloads/ddp_issue.py", line 46, in train
    trainer.fit(model)
  File "/home/*****/Downloads/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 955, in fit
    self.ddp_train(process_idx=task, q=None, model=model)
  File "/home/*****/Downloads/pytorch-lightning/pytorch_lightning/trainer/distrib_data_parallel.py", line 553, in ddp_train
    model = model.configure_ddp(model, device_ids)
  File "/home/*****/Downloads/pytorch-lightning/pytorch_lightning/core/lightning.py", line 896, in configure_ddp
    model = LightningDistributedDataParallel(
  File "/home/*****/miniconda3/envs/pl/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 327, in __init__
    self._distributed_broadcast_coalesced(
  File "/home/*****/miniconda3/envs/pl/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 546, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:530, unhandled system error, NCCL version 2.4.8

Are you sure that the socket used for ddp is properly closed when the Trainer object gets destroyed?

@yippp

yippp commented Aug 2, 2020

I also run into the unhandled system error problem. Does anyone know how to solve it?

@JanSellner
Contributor Author

Not really a solution, but maybe some hints towards the bug: this seems to be an issue with ddp specifically, not with ddp_spawn. If you watch for open sockets with watch -n0.1 "sudo netstat -tulpn | grep 12355" during program execution, you will notice that with ddp_spawn the socket is properly closed after the first run. This is not the case with ddp: the socket remains open for some reason, which leads to the errors above.
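
As a rough Python equivalent of the netstat check (assuming the port 12355 used above), something like this should tell you whether the ddp socket is still bound:

import socket


def port_in_use(port: int, host: str = "0.0.0.0") -> bool:
    # If binding fails, some process (here: the stale ddp socket) still holds the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return False
        except OSError:
            return True


print(port_in_use(12355))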

@dthiagarajan

This also seems to be the root cause of an issue when trying to do LR find on distributed compute.

@awaelchli
Contributor

Fixed it in #2790.
It should work already; I'm just cleaning things up and adding tests.

@dthiagarajan

Thanks! Just to confirm, is the crux of the fix here the DistributedConnection.reset_connection conditional you've added in that PR?

@awaelchli
Contributor

awaelchli commented Aug 10, 2020

Yes! The PR solves many interconnected issues, but one case is the following situation:

  • .fit() initializes nccl on port X on all ranks, then finishes and destroys the ddp connection on ranks > 0
  • the process returns to rank 0
  • rank 0 starts a second .fit() (or .test()), launching new processes that initialize a ddp connection
  • now the processes with rank > 0 complain that port X is already in use, because rank 0 is still connected to it

The solution is to find a new free port and connect all processes to that (this applies only to single-node training). Previously (on master), this was solved by choosing a random port, but that yields a different random port in each subprocess, and the chosen port could already be occupied. Therefore we need to pick the port on rank 0 only and then broadcast it to all other subprocesses.
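
Roughly, the idea for picking the port on rank 0 looks like this (just a sketch, not the exact code in the PR):

import os
import socket


def find_free_port() -> int:
    # Bind to port 0 so the OS assigns a currently free port, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


# On rank 0 only, before the other processes are launched; the chosen port
# is then passed on to the subprocesses via the environment.
os.environ["MASTER_PORT"] = str(find_free_port())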

@awaelchli
Contributor

awaelchli commented Aug 16, 2020

@JanSellner @dthiagarajan Just a quick follow-up on why this issue got closed:
In your case, you are calling trainer.fit() multiple times, or even instantiating the Trainer multiple times. In DDP mode this cannot work, since the Trainer will call the same script multiple times. When that happens, we have no way of controlling which trainer.fit() is executed. I added a note in the docs.

For your use case it means:

  • if you want to call trainer.fit() multiple times, use distributed_backend="ddp_spawn" (see the sketch below)
  • if you want to use distributed_backend="ddp", you must make sure your script calls trainer.fit() (or trainer.test()) only once

It is a tradeoff between the two backends; both have their advantages and disadvantages, as outlined in the docs.
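
For the k-fold use case, a rough sketch of this pattern, reusing the LitModel from the reproduction above (the actual per-fold data split is omitted for brevity):

from pytorch_lightning import Trainer


def run_kfold(num_folds: int = 5):
    for fold in range(num_folds):
        # Build a fresh model for every fold; the fold-specific train/val
        # split would go into the model's dataloaders (omitted here).
        model = LitModel()
        # ddp_spawn launches fresh processes for every fit() call, so creating
        # a new Trainer in each iteration of the loop works.
        trainer = Trainer(gpus=2, distributed_backend="ddp_spawn", max_epochs=2)
        trainer.fit(model)


run_kfold()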

(and ignore my previous post here, #2537 (comment))

@xtzd

xtzd commented Sep 12, 2020

Hi, I am still stuck here. So the final solution is to change 'ddp' to 'ddp_spawn' mode? I got stuck when using ddp to fit k-fold trainers.

@williamFalcon
Contributor

exactly

@ShaneTian

I often run into this issue when using deepspeed.
