Skip to content

Commit

Permalink
SA: for #958: set torch cuda device when finding root
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamagarwal92 authored and williamFalcon committed Apr 3, 2020
1 parent 868b172 commit 5fc8165
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,4 +638,10 @@ def determine_root_gpu_device(gpus):
# set root gpu
root_gpu = gpus[0]

# set cuda device to root gpu
# related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
# Refer solution: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
root_device = torch.device("cuda", root_gpu)
torch.cuda.set_device(root_device)

return root_gpu

0 comments on commit 5fc8165

Please sign in to comment.