
Unable to launch multiple GPU nodes #2578

Closed
mortonjt opened this issue Jul 10, 2020 · 12 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

mortonjt (Contributor) commented Jul 10, 2020

🐛 Bug

I'm having trouble launching training across multiple GPU nodes with pytorch-lightning 0.8.5-dev. I'm getting the following error:

Traceback (most recent call last):
  File "/home/jmorton/miniconda3/envs/alignment/bin/deepblast-train", line 7, in <module>
    exec(compile(f.read(), __file__, 'exec'))
  File "/home/jmorton/research/gert/deepblast/scripts/deepblast-train", line 67, in <module>
    main(hparams)
  File "/home/jmorton/research/gert/deepblast/scripts/deepblast-train", line 47, in main
    trainer.fit(model)
  File "/home/jmorton/miniconda3/envs/alignment/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 964, in fit
    self.set_random_port()
  File "/home/jmorton/miniconda3/envs/alignment/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 392, in set_random_port
    assert self.num_nodes == 1, 'random port can only be called from single node training'
AssertionError: random port can only be called from single node training

To Reproduce

I've set up my model roughly as follows:

    # `parser` is an argparse.ArgumentParser; LightningAligner is the user's LightningModule
    parser = LightningAligner.add_model_specific_args(parser)
    args = parser.parse_args()
    model = LightningAligner(args)

    trainer = Trainer(
        max_epochs=10,
        gpus=4,
        num_nodes=2,
        accumulate_grad_batches=10,
        distributed_backend='ddp',
        precision=32,
        check_val_every_n_epoch=1,
        fast_dev_run=False
    )
    trainer.fit(model)

Environment

Output of python collect_env_details.py

* CUDA:
        - GPU:
                - Tesla V100-PCIE-32GB
                - Tesla V100-PCIE-32GB
                - Tesla V100-PCIE-32GB
                - Tesla V100-PCIE-32GB
        - available:         True
        - version:           10.1
* Packages:
	- numpy:             1.17.5
	- pyTorch_debug:     False
	- pyTorch_version:   1.5.1
	- pytorch-lightning: 0.8.5-dev
	- tensorboard:       2.2.2
	- tqdm:              4.47.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.1
        - version:           #1 SMP Wed Jun 3 14:28:03 UTC 2020

Additional context

Only 4 out of 8 GPUs are recognized.

I'm curious why the assert statement is there.

mortonjt added the "bug" and "help wanted" labels Jul 10, 2020
github-actions (Contributor):

Hi! Thanks for your contribution, great first issue!

awaelchli added the "priority: 0" (high priority task) label Jul 10, 2020
awaelchli (Member):
@williamFalcon was the assertion maybe supposed to go somewhere else?

awaelchli (Member):

@mortonjt does simply removing the assertion solve the problem for you? (Sorry, I can't test this myself since I don't have a multi-node setup.)

mortonjt (Contributor, Author):

I've commented out the offending line, and now it's been hanging for 16 minutes (the single-node version is able to boot within 1 minute).

williamFalcon (Contributor):

For multi-node training you have to set the master port yourself, since the port can't be random: all the nodes need to know where to connect.

MASTER_PORT=1234 MASTER_ADDRESS=some.ip python main.py ...
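
A minimal sketch of the same idea from the Python side; the address and port values below are placeholders, and MASTER_ADDR / MASTER_PORT are the environment variable names PyTorch distributed conventionally reads:

import os

from pytorch_lightning import Trainer

# Every node must agree on these values, so set them explicitly (here or in the
# job script) instead of letting a random port be chosen.
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")  # placeholder: address of the rank-0 node
os.environ.setdefault("MASTER_PORT", "1234")      # placeholder: any free port, same on all nodes

trainer = Trainer(gpus=4, num_nodes=2, distributed_backend="ddp")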

blackwer:

This fails first because python needs to be launched by srun; once that is done, the next failure shows up.

The second failure comes from how Lightning parses SLURM environment variables. The SLURM_NODELIST variable is not trivially formatted: sometimes it looks like server01,server02, sometimes server[01,02], and sometimes a more complex combination of the two.
The offending line is in core/lightning.py: root_node = os.environ['SLURM_NODELIST'].split(' ')[0]. With a comma-separated list of hosts you just end up with server01,server02 as your host, which obviously fails. Oddly, presumably because of something srun does to the environment before handing off to python, the server[00,01] variant works fine while the comma-separated list fails. As a hack for this particular use case, you can change the offending line to root_node = os.environ['SLURM_NODELIST'].split(',')[0]. That will likely fail if you end up with the other variant, but it works on this particular cluster.

I'm not sure of the most portable way to extract hosts from SLURM_JOB_NODELIST / SLURM_NODELIST. In bash, we often just run scontrol show hostnames $SLURM_NODELIST.
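
A sketch of one portable option from Python, assuming scontrol is available wherever the training script runs; first_slurm_host is a hypothetical helper name, not part of Lightning:

import os
import subprocess

def first_slurm_host() -> str:
    # Let SLURM expand the node list itself rather than hand-parsing the
    # server01,server02 / server[01,02] formats.
    nodelist = os.environ["SLURM_JOB_NODELIST"]
    hostnames = subprocess.run(
        ["scontrol", "show", "hostnames", nodelist],
        capture_output=True, text=True, check=True,
    ).stdout.split()
    return hostnames[0]

# Export the first host as the master address before DDP initialization.
os.environ.setdefault("MASTER_ADDR", first_slurm_host())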

rakhimovv commented Jul 30, 2020:

It also seems that multi-node testing is not supported, right? I tried setting MASTER_ADDR and MASTER_PORT, with no effect:

File "/trinity/home/r.rakhimov/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1411, in __test_using_best_weights
    self.set_random_port(force=True)
File "/trinity/home/r.rakhimov/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 411, in set_random_port
    assert self.num_nodes == 1, 'random port can only be called from single node training'

haideraltahan:
Even with the edits suggested by @blackwer, the problem persists.

asrafulashiq (Contributor):

I am facing the same problem. Even if I set MASTER_PORT or comment out the assertion line, the problem remains.

vladisai commented Aug 6, 2020:

Are there any updates on this? I'm having the same problem.

williamFalcon (Contributor):

Yes, probably a bug on our end. Let me push a fix tonight.

@Borda @luiscape can we set up multi-node testing for GPUs?

mortonjt (Contributor, Author) commented Aug 7, 2020:

Just as a heads-up, I took @blackwer's solution and now have a multi-GPU example working on SLURM. Below is a high-level example of how to launch it.

#SBATCH -N 2
#SBATCH --ntasks-per-node=4

# Expand the compressed node list (e.g. server[01-02]) into space-separated hostnames
export SLURM_JOB_NODELIST=$(scontrol show hostnames $SLURM_JOB_NODELIST | tr '\n' ' ')
export SLURM_NODELIST=$SLURM_JOB_NODELIST
slurm_nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST)

# Use the first host as the master address
export MASTER_ADDRESS=$(echo $slurm_nodes | cut -d' ' -f1)

srun my_script.py --nodes 2 --gpus 4 --backend ddp
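
For completeness, a sketch of what the Python side (my_script.py above) might look like; the --nodes/--gpus/--backend flags simply mirror the srun command, and MyModel is a placeholder for the actual LightningModule:

from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser.add_argument("--nodes", type=int, default=1)
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--backend", type=str, default="ddp")
args = parser.parse_args()

model = MyModel()  # placeholder: your LightningModule

trainer = Trainer(
    gpus=args.gpus,
    num_nodes=args.nodes,
    distributed_backend=args.backend,
)
trainer.fit(model)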

mortonjt closed this as completed Aug 7, 2020