Fix device placement when .cuda() called without specifying index (L…
awaelchli authored and jessecambon committed Aug 16, 2022
1 parent a88848a commit 1f07ea4
Showing 3 changed files with 32 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
 
 
+- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))
+
+
 ## [1.7.0] - 2022-08-02
 
 ### Added
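The root cause is the expression `device or 0` in `DeviceDtypeModuleMixin.cuda()`, patched in the file below: when `device` is `None`, it silently resolves to index 0 instead of the current CUDA device. A minimal sketch of the two resolutions — not part of the diff, and assuming a machine with at least two CUDA devices so that `torch.cuda.set_device(1)` is valid:

```python
import torch

torch.cuda.set_device(1)  # make cuda:1 the current device

device = None  # what .cuda() receives when called without an index

# Old resolution: `device or 0` collapses None to 0, ignoring the current device.
old = torch.device("cuda", index=(device or 0))

# New resolution: fall back to the current CUDA device instead.
new = torch.device("cuda", torch.cuda.current_device())

print(old)  # cuda:0
print(new)  # cuda:1
```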
10 changes: 6 additions & 4 deletions src/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -116,14 +116,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # ty
         while being optimized.
 
         Arguments:
-            device: if specified, all parameters will be
-                copied to that device
+            device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device
+                index will be used.
 
         Returns:
             Module: self
         """
-        if device is None or isinstance(device, int):
-            device = torch.device("cuda", index=(device or 0))
+        if device is None:
+            device = torch.device("cuda", torch.cuda.current_device())
+        elif isinstance(device, int):
+            device = torch.device("cuda", index=device)
         self.__update_properties(device=device)
         return super().cuda(device=device)

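For reference, a hedged user-level sketch of the corrected behaviour — not part of the diff; `TinyModule` is an illustrative name, and running it requires `pytorch_lightning` with this patch plus at least two CUDA devices:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)


torch.cuda.set_device(1)     # cuda:1 is now the current device
model = TinyModule().cuda()  # no index given

# Before this fix the parameters landed on cuda:0; with it they follow the
# current device, matching plain `torch.nn.Module.cuda()`.
assert model.device == torch.device("cuda", 1)
assert model.layer.weight.device == torch.device("cuda", 1)
```

The test added below exercises the same behaviour directly at the `DeviceDtypeModuleMixin` level.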
24 changes: 23 additions & 1 deletion tests/tests_pytorch/utilities/test_dtype_device_mixin.py
@@ -113,7 +113,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     ],
 )
 @RunIf(min_cuda_gpus=1)
-def test_gpu_cuda_device(device):
+def test_cuda_device(device):
     model = TopModule()
 
     model.cuda(device)
@@ -122,3 +122,25 @@ def test_gpu_cuda_device(device):
     assert device.type == "cuda"
     assert device.index is not None
     assert device.index == torch.cuda.current_device()
+
+
+@RunIf(min_cuda_gpus=2)
+def test_cuda_current_device():
+    """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
+
+    class CudaModule(DeviceDtypeModuleMixin):
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(1, 1)
+
+    model = CudaModule()
+
+    torch.cuda.set_device(0)
+    model.cuda(1)
+    assert model.device == torch.device("cuda", 1)
+    assert model.layer.weight.device == torch.device("cuda", 1)
+
+    torch.cuda.set_device(1)
+    model.cuda()  # model is already on device 1, and calling .cuda() without device index should not move model
+    assert model.device == torch.device("cuda", 1)
+    assert model.layer.weight.device == torch.device("cuda", 1)
