
CUDA device index error in distributed training #5111

Closed
qiningonline opened this issue Oct 9, 2021 · 5 comments · Fixed by #5114

Labels
question Further information is requested

Comments

@qiningonline
Contributor

Question

When running the distributed training, the following line is reporting a CUDA device index error

dist.barrier(device_ids=[local_rank])

Sample script when launching the training

# On master machine 0
$ python -m torch.distributed.launch --nproc_per_node G --nnodes N --node_rank 0 --master_addr "192.168.1.1" --master_port 1234 train.py --batch 64 --data coco.yaml --cfg yolov5s.yaml --weights sample_weight.pt --device 0,...,G-1
# On machine R
$ python -m torch.distributed.launch --nproc_per_node G --nnodes N --node_rank R --master_addr "192.168.1.1" --master_port 1234 train.py --batch 64 --data coco.yaml --cfg yolov5s.yaml --weights sample_weight.pt --device 0,...,G-1

All instances have the same number of GPUs per machine.

When running the scripts above,

  • the master instance runs with no error
  • machine R reports the CUDA device index error

Proposed fix: change the following lines from RANK to LOCAL_RANK (see the sketch after the list below for why):

[1]

yolov5/train.py

Line 102 in 276b674

with torch_distributed_zero_first(RANK):

[2]

yolov5/train.py

Line 114 in 276b674

with torch_distributed_zero_first(RANK):

[3]

yolov5/train.py

Line 211 in 276b674

hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
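For illustration only (a simplified sketch, not the exact YOLOv5 code): with N nodes of G GPUs each, the global RANK runs from 0 to N*G-1, while each node only exposes CUDA devices 0 to G-1. Passing RANK into dist.barrier(device_ids=[...]) therefore asks for a nonexistent device on every non-master node, while LOCAL_RANK always stays within the node's valid range.

# Simplified sketch of the context manager involved (assumed shape, based on
# the utils/torch_utils.py frame in the traceback below; not a verbatim copy).
from contextlib import contextmanager
import torch.distributed as dist

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Every process except (local) rank 0 waits here; device_ids must name a
    # CUDA device that exists on *this* node, i.e. an index in 0..G-1.
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])

# With 2 nodes x 2 GPUs: RANK is 0..3, but LOCAL_RANK is 0..1 on each node.
# Calling torch_distributed_zero_first(RANK) on node 1 passes device_ids=[2]
# or [3], which raises "CUDA error: invalid device ordinal"; LOCAL_RANK does not.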

Context info:

  • all instances running in docker container, with
    • NVIDIA Release 21.03 (build 21060478)
    • PyTorch Version 1.9.0a0+df837d0

Question:

  • does the proposed fix above look good?
qiningonline added the question label on Oct 9, 2021
@github-actions
Contributor

github-actions bot commented Oct 9, 2021

👋 Hello @qiningonline, thank you for your interest in YOLOv5 🚀! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python>=3.6.0 with all requirements.txt installed including PyTorch>=1.7. To get started:

$ git clone https://github.com/ultralytics/yolov5
$ cd yolov5
$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled).

Status

[CI CPU testing badge]

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), validation (val.py), inference (detect.py) and export (export.py) on macOS, Windows, and Ubuntu every 24 hours and on every commit.

@glenn-jocher
Member

glenn-jocher commented Oct 9, 2021

@qiningonline thanks for the bug report!

Distributed DDP may have issues, as our own internal Ultralytics DDP trainings are limited to a single node so far. This means I can't verify multi-node fix results on our side, so we will have to rely on your results. If the above changes solve the problem for you in a multi-node environment, then please submit a PR with these changes.

Also please verify the changes have no impact on single-node trainings. The simplest way to do this would be to train 2 single-node DDP models (with master branch and PR branch), i.e. using the command below and comparing final mAPs and training times.

$ python -m torch.distributed.launch --nproc_per_node 4 --master_port 1 train.py --batch 64 --data coco.yaml --cfg yolov5s.yaml --weights '' --epochs 10 --device 0,1,2,3

NOTE: RNG seeds are very important for DDP (they should be different across all RANKs, and in your case possibly across all nodes too; otherwise augmentation will be identical, causing overfitting and reduced final mAP). I think this line is fine the way it is, but you should also verify seeds are set differently everywhere:

yolov5/train.py

Line 101 in 276b674

init_seeds(1 + RANK)
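For illustration only (a minimal sketch of the idea, not a verbatim copy of YOLOv5's init_seeds): seeding with 1 + RANK gives every DDP worker, on every node, a different RNG stream, so their data augmentation differs.

# Minimal sketch, assuming the standard Python/NumPy/PyTorch RNGs are the ones
# driving augmentation; not YOLOv5's exact implementation.
import os
import random
import numpy as np
import torch

def init_seeds_sketch(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

RANK = int(os.getenv('RANK', -1))  # global rank: unique per process across all nodes
init_seeds_sketch(1 + RANK)        # different seed per process => different augmentation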

@glenn-jocher
Member

@qiningonline also you might want to migrate from torch.distributed.launch to torch.distributed.run
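For reference, a sketch of the equivalent launch command (assuming PyTorch >= 1.9, where torch.distributed.run is available and exports RANK/LOCAL_RANK environment variables for each worker); flags keep the same shape as above:

$ python -m torch.distributed.run --nproc_per_node G --nnodes N --node_rank R --master_addr "192.168.1.1" --master_port 1234 train.py --batch 64 --data coco.yaml --cfg yolov5s.yaml --weights sample_weight.pt --device 0,...,G-1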

@qiningonline
Contributor Author

@glenn-jocher

Thank you for your comments!

The changes proposed in the issue description were tested and fixed the problem.

Given your suggestion above, I ran the following 4 tests:

Test-0

  • master, commit 276b674
  • 2 nodes
  • error raised
Traceback (most recent call last):
  File "train.py", line 620, in <module>
    main(opt)
  File "train.py", line 517, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 102, in train
    with torch_distributed_zero_first(RANK):
  File "/opt/conda/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File ".../yolov5/utils/torch_utils.py", line 37, in torch_distributed_zero_first
    dist.barrier(device_ids=[local_rank])
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2420, in barrier
    work = default_pg.barrier(opts=opts)
RuntimeError: CUDA error: invalid device ordinal
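For context (an illustrative snippet, not part of the original report): "invalid device ordinal" means the requested CUDA index does not exist on that machine; on a 2-GPU node, any index >= 2 triggers it.

# Illustrative only: the same class of error on a machine with 2 GPUs.
import torch

print(torch.cuda.device_count())      # e.g. 2 on machine R
x = torch.zeros(1, device='cuda:1')   # valid: indices 0..1 exist
y = torch.zeros(1, device='cuda:3')   # RuntimeError: CUDA error: invalid device ordinal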

Test-1

  • master, commit 276b674
  • 1 node
  • v1.0
python -m torch.distributed.launch --nproc_per_node 2 --master_port 1 train.py --batch 64 --data coco.yaml --weights yolov5l.pt --epochs 10 --device 0,1 --project {project_name} --name v1.0

Test-2

  • commit with fix, 4206414
  • 1 node
  • v2.0
python -m torch.distributed.launch --nproc_per_node 2 --master_port 1 train.py --batch 64 --data coco.yaml --weights yolov5l.pt --epochs 10 --device 0,1 --project {project_name} --name v2.0

Test-3

  • commit with fix, 4206414
  • 2 nodes
  • v3.0
# machine 0
python -m torch.distributed.launch --nproc_per_node 2 --nnodes 2 --node_rank 0 --master_addr {master_address} --master_port 1 train.py --batch 64 --data coco.yaml --weights yolov5l.pt --epochs 10 --device 0,1 --project {project_name} --name v3.0 


# machine 1
python -m torch.distributed.launch --nproc_per_node 2 --nnodes 2 --node_rank 1 --master_addr {master_address} --master_port 1 train.py --batch 64 --data coco.yaml --weights yolov5l.pt --epochs 10 --device 0,1 --project {project_name} --name v3.0 

The metrics in comparison:

[Screenshot from 2021-10-09 19-38-44: results comparing the v1.0, v2.0, and v3.0 runs]

PR added here: #5114. Please help review it when you have a moment. Thank you!

@glenn-jocher
Member

@qiningonline awesome, thanks for the tests. The results look good. P and R noise is expected, since these metrics are evaluated at a specific confidence and so carry fewer statistics than the mAPs, which are evaluated at all confidences.

Small mAP differences are also expected, as results are typically not perfectly reproducible in PyTorch.

glenn-jocher removed the TODO label on Nov 5, 2021