Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dtype/device property not getting updated in submodules #2657

Merged
merged 16 commits into from
Jul 21, 2020
36 changes: 20 additions & 16 deletions pytorch_lightning/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional
from typing import Union, Optional, Any, Callable

import torch
from torch.nn import Module
Expand Down Expand Up @@ -82,14 +82,7 @@ def to(self, *args, **kwargs) -> Module:
"""
# there is diff nb vars in PT 1.5
out = torch._C._nn._parse_to(*args, **kwargs)
device = out[0]
dtype = out[1]
if device is not None:
self._device = device

if dtype is not None:
self._dtype = dtype

self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)

def cuda(self, device: Optional[int] = None) -> Module:
Expand All @@ -105,16 +98,15 @@ def cuda(self, device: Optional[int] = None) -> Module:
Returns:
Module: self
"""

self._device = torch.device('cuda', index=device)
self.__update_properties(device=torch.device('cuda', index=device))
return super().cuda(device=device)

def cpu(self) -> Module:
"""Moves all model parameters and buffers to the CPU.
Returns:
Module: self
"""
self._device = torch.device('cpu')
self.__update_properties(device=torch.device('cpu'))
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> Module:
Expand All @@ -126,7 +118,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module:
Returns:
Module: self
"""
self._dtype = dst_type
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)

def float(self) -> Module:
Expand All @@ -135,7 +127,7 @@ def float(self) -> Module:
Returns:
Module: self
"""
self._dtype = torch.float
self.__update_properties(dtype=torch.float)
return super().float()

def double(self) -> Module:
Expand All @@ -144,7 +136,7 @@ def double(self) -> Module:
Returns:
Module: self
"""
self._dtype = torch.double
self.__update_properties(dtype=torch.double)
return super().double()

def half(self) -> Module:
Expand All @@ -153,5 +145,17 @@ def half(self) -> Module:
Returns:
Module: self
"""
self._dtype = torch.half
self.__update_properties(dtype=torch.half)
return super().half()

def __update_properties(self, device=None, dtype=None):
    """Store a new device and/or dtype on this module and its submodules.

    The update is delivered through ``self.apply``, so the callback below is
    invoked for each module that ``apply`` visits. Modules that are not
    instances of ``DeviceDtypeModuleMixin`` are left untouched.

    Args:
        device: value to assign to ``_device``; ``None`` leaves it unchanged.
        dtype: value to assign to ``_dtype``; ``None`` leaves it unchanged.
    """

    def _sync_cached_attributes(submodule):
        # Only mixin instances carry the ``_device`` / ``_dtype`` attributes.
        if isinstance(submodule, DeviceDtypeModuleMixin):
            if device is not None:
                submodule._device = device
            if dtype is not None:
                submodule._dtype = dtype

    self.apply(_sync_cached_attributes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for using `apply`?

Copy link
Member Author

@awaelchli awaelchli Jul 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part of the answer is here in my comment.
`apply`, in contrast to `to`, works recursively on all modules and allows us to update our custom properties.
I'm writing a test right now to make sure it fixes what failed before.