From 499ca6451b9829d82381a8e02b4073840fc8a8c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 16:17:12 +0200 Subject: [PATCH 01/16] recursive dtype device apply --- .../utilities/device_dtype_mixin.py | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 48ccad5307552..ec552a38138e4 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -85,10 +85,10 @@ def to(self, *args, **kwargs) -> Module: device = out[0] dtype = out[1] if device is not None: - self._device = device + self.apply(device_apply_fn(device)) if dtype is not None: - self._dtype = dtype + self.apply(dtype_apply_fn(dtype)) return super().to(*args, **kwargs) @@ -105,8 +105,7 @@ def cuda(self, device: Optional[int] = None) -> Module: Returns: Module: self """ - - self._device = torch.device('cuda', index=device) + self.apply(device_apply_fn(torch.device('cuda', index=device))) return super().cuda(device=device) def cpu(self) -> Module: @@ -114,7 +113,7 @@ def cpu(self) -> Module: Returns: Module: self """ - self._device = torch.device('cpu') + self.apply(device_apply_fn(torch.device('cpu'))) return super().cpu() def type(self, dst_type: Union[str, torch.dtype]) -> Module: @@ -126,7 +125,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module: Returns: Module: self """ - self._dtype = dst_type + self.apply(dtype_apply_fn(dst_type)) return super().type(dst_type=dst_type) def float(self) -> Module: @@ -135,7 +134,7 @@ def float(self) -> Module: Returns: Module: self """ - self._dtype = torch.float + self.apply(dtype_apply_fn(torch.float)) return super().float() def double(self) -> Module: @@ -144,7 +143,7 @@ def double(self) -> Module: Returns: Module: self """ - self._dtype = torch.double + self.apply(dtype_apply_fn(torch.double)) return super().double() def half(self) -> Module: @@ -153,5 +152,32 @@ def half(self) -> Module: Returns: Module: self """ - self._dtype = torch.half + self.apply(dtype_apply_fn(torch.half)) return super().half() + + + + + +def dtype_apply_fn(dtype): + return apply_attr('_dtype', dtype) + + +def device_apply_fn(device): + return apply_attr('_device', device) + + +# def update_attributes(module, **kwargs): +# if not isinstance(module, DeviceDtypeModuleMixin): +# return +# for k, v in kwargs: +# module.__setattr__(k, v) + + +def apply_attr(name: str, value): + + def apply_fn(module): + if isinstance(module, DeviceDtypeModuleMixin): + module.__setattr__(name, value) + + return apply_fn From 9de24ed2295fe9af18330f76e8bffaca57ed7a81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 16:48:02 +0200 Subject: [PATCH 02/16] simplify --- .../utilities/device_dtype_mixin.py | 56 ++++++------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index ec552a38138e4..28c6428df8d37 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Union, Optional, Any, Callable import torch from torch.nn import Module @@ -82,14 +82,7 @@ def to(self, *args, **kwargs) -> Module: """ # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs) - device = out[0] - dtype = 
out[1] - if device is not None: - self.apply(device_apply_fn(device)) - - if dtype is not None: - self.apply(dtype_apply_fn(dtype)) - + self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) def cuda(self, device: Optional[int] = None) -> Module: @@ -105,7 +98,7 @@ def cuda(self, device: Optional[int] = None) -> Module: Returns: Module: self """ - self.apply(device_apply_fn(torch.device('cuda', index=device))) + self.__update_properties(device=torch.device('cuda', index=device)) return super().cuda(device=device) def cpu(self) -> Module: @@ -113,7 +106,7 @@ def cpu(self) -> Module: Returns: Module: self """ - self.apply(device_apply_fn(torch.device('cpu'))) + self.__update_properties(device=torch.device('cpu')) return super().cpu() def type(self, dst_type: Union[str, torch.dtype]) -> Module: @@ -125,7 +118,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module: Returns: Module: self """ - self.apply(dtype_apply_fn(dst_type)) + self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) def float(self) -> Module: @@ -134,7 +127,7 @@ def float(self) -> Module: Returns: Module: self """ - self.apply(dtype_apply_fn(torch.float)) + self.__update_properties(dtype=torch.float) return super().float() def double(self) -> Module: @@ -143,7 +136,7 @@ def double(self) -> Module: Returns: Module: self """ - self.apply(dtype_apply_fn(torch.double)) + self.__update_properties(dtype=torch.double) return super().double() def half(self) -> Module: @@ -152,32 +145,17 @@ def half(self) -> Module: Returns: Module: self """ - self.apply(dtype_apply_fn(torch.half)) + self.__update_properties(dtype=torch.half) return super().half() + def __update_properties(self, device=None, dtype=None): + def apply_fn(module): + if not isinstance(module, DeviceDtypeModuleMixin): + return + if device is not None: + module._device = device + if dtype is not None: + module._dtype = dtype - - -def dtype_apply_fn(dtype): - return apply_attr('_dtype', dtype) - - -def device_apply_fn(device): - return apply_attr('_device', device) - - -# def update_attributes(module, **kwargs): -# if not isinstance(module, DeviceDtypeModuleMixin): -# return -# for k, v in kwargs: -# module.__setattr__(k, v) - - -def apply_attr(name: str, value): - - def apply_fn(module): - if isinstance(module, DeviceDtypeModuleMixin): - module.__setattr__(name, value) - - return apply_fn + self.apply(apply_fn) From 95478ce1ff73db0552230abaccbdab3c6f97c087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 16:59:03 +0200 Subject: [PATCH 03/16] simple test --- tests/utilities/test_dtype_device_mixin.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/utilities/test_dtype_device_mixin.py diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py new file mode 100644 index 0000000000000..6842666d438ab --- /dev/null +++ b/tests/utilities/test_dtype_device_mixin.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from tests.base import EvalModelTemplate + + +class Model(EvalModelTemplate): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.acc = Accuracy() + + +def test_submodules_device(tmpdir): + + model = Model() + assert model.device == torch.device('cpu') + model = model.to('cuda') + assert model.device == model.acc.device 
== torch.device('cuda') \ No newline at end of file From cc5bf5be74574ee3eae156575b17142b955ccffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 17:23:33 +0200 Subject: [PATCH 04/16] submodule test --- tests/utilities/test_dtype_device_mixin.py | 41 ++++++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index 6842666d438ab..bb595c5867b81 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -1,21 +1,48 @@ +import pytest import torch import torch.nn as nn -from pytorch_lightning.metrics import Accuracy from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from tests.base import EvalModelTemplate -class Model(EvalModelTemplate): +class SubSubModule(DeviceDtypeModuleMixin, nn.Module): + pass + + +class SubModule(nn.Module): + + def __init__(self): + super().__init__() + self.sub_sub_module = SubSubModule() + + +class TopModule(EvalModelTemplate): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.acc = Accuracy() + self.sub_module = SubModule() -def test_submodules_device(tmpdir): +@pytest.mark.parametrize(['dst_dtype'], [ + pytest.param(torch.float), + pytest.param(torch.double), + pytest.param(torch.half), +]) +@pytest.mark.parametrize(['dst_device'], [ + pytest.param(torch.device('cpu')), + pytest.param(torch.device('cuda')), + pytest.param(torch.device('cuda', 0)), +]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_submodules_device(dst_device, dst_dtype): + """ + Test that the device and dtype property updates propagate through mixed nesting of regular + nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule). + """ - model = Model() + model = TopModule() assert model.device == torch.device('cpu') - model = model.to('cuda') - assert model.device == model.acc.device == torch.device('cuda') \ No newline at end of file + model = model.to(device=dst_device, dtype=dst_dtype) + assert model.device == model.sub_module.sub_sub_module.device == dst_device + assert model.dtype == model.sub_module.sub_sub_module.dtype == dst_dtype From 27e15da7ef2171194c5ec95de515c08ba21cf39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 17:24:02 +0200 Subject: [PATCH 05/16] rename --- tests/utilities/test_dtype_device_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index bb595c5867b81..24faf1ba80477 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -35,7 +35,7 @@ def __init__(self, *args, **kwargs): pytest.param(torch.device('cuda', 0)), ]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_submodules_device(dst_device, dst_dtype): +def test_submodules_device_and_dtype(dst_device, dst_dtype): """ Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule). 
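
The mechanism the first two patches converge on is nn.Module.apply(fn): it calls fn on the module itself and then recursively on every submodule, which is what lets the property update cross plain nn.Module layers and reach nested mixins. Below is a minimal standalone sketch of the helper, mirroring __update_properties from patch 02; it illustrates the idea and is not the shipped code (the hasattr probe stands in for the real isinstance check so the sketch needs no pytorch_lightning import):

from typing import Optional

import torch
from torch.nn import Module


def update_properties(root: Module,
                      device: Optional[torch.device] = None,
                      dtype: Optional[torch.dtype] = None) -> None:
    # Module.apply(fn) visits root and then every descendant module,
    # so mixin instances nested below plain nn.Modules are reached too.
    def apply_fn(module: Module) -> None:
        if not hasattr(module, '_device'):
            # stand-in for: isinstance(module, DeviceDtypeModuleMixin)
            return
        if device is not None:
            module._device = device
        if dtype is not None:
            module._dtype = dtype

    root.apply(apply_fn)

Note that apply() only updates the bookkeeping attributes; the actual tensor moves are still delegated to the torch.nn.Module methods via super(), as in the diffs above.
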
From db8cb0a789e97d914b608541044c3e4b2b33e4b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 17:42:37 +0200 Subject: [PATCH 06/16] explicit --- tests/utilities/test_dtype_device_mixin.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index 24faf1ba80477..a39ccb0bbc9b5 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -6,7 +6,7 @@ from tests.base import EvalModelTemplate -class SubSubModule(DeviceDtypeModuleMixin, nn.Module): +class SubSubModule(DeviceDtypeModuleMixin): pass @@ -14,14 +14,14 @@ class SubModule(nn.Module): def __init__(self): super().__init__() - self.sub_sub_module = SubSubModule() + self.module = SubSubModule() class TopModule(EvalModelTemplate): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.sub_module = SubModule() + self.module = SubModule() @pytest.mark.parametrize(['dst_dtype'], [ @@ -44,5 +44,9 @@ def test_submodules_device_and_dtype(dst_device, dst_dtype): model = TopModule() assert model.device == torch.device('cpu') model = model.to(device=dst_device, dtype=dst_dtype) - assert model.device == model.sub_module.sub_sub_module.device == dst_device - assert model.dtype == model.sub_module.sub_sub_module.dtype == dst_dtype + # nn.Module does not have these attributes + assert not hasattr(model.module, '_device') + assert not hasattr(model.module, '_dtype') + # device and dtype change should propagate down into all children + assert model.device == model.module.module.device == dst_device + assert model.dtype == model.module.module.dtype == dst_dtype From 38bdd2327b7a7c322bab8d6c9355761ec2e92886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 17:42:48 +0200 Subject: [PATCH 07/16] type hints --- pytorch_lightning/utilities/device_dtype_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 28c6428df8d37..c5bb8ae5d9a33 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Any, Callable +from typing import Union, Optional import torch from torch.nn import Module @@ -148,7 +148,7 @@ def half(self) -> Module: self.__update_properties(dtype=torch.half) return super().half() - def __update_properties(self, device=None, dtype=None): + def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): def apply_fn(module): if not isinstance(module, DeviceDtypeModuleMixin): From 2f510dcda42de2dad98c6d39d174b990ce3213e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 17:55:47 +0200 Subject: [PATCH 08/16] test for dp backend --- tests/utilities/test_dtype_device_mixin.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index a39ccb0bbc9b5..0ba0ba347b5d4 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from pytorch_lightning import Trainer, Callback from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from tests.base import 
EvalModelTemplate @@ -50,3 +51,24 @@ def test_submodules_device_and_dtype(dst_device, dst_dtype): # device and dtype change should propagate down into all children assert model.device == model.module.module.device == dst_device assert model.dtype == model.module.module.dtype == dst_dtype + + +@pytest.mark.skipif(not torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_submodules_multi_gpu_dp(tmpdir): + + class DeviceCallback(Callback): + + def on_batch_start(self, trainer, model): + assert isinstance(model, TopModule) + assert model.device.index == trainer.local_rank + assert model.device == model.module.module.device + + model = TopModule() + trainer = Trainer( + default_root_dir=tmpdir, + distributed_backend='dp', + gpus=2, + callbacks=[DeviceCallback()], + max_steps=1, + ) + trainer.fit(model) From b955923e728d056e2d7fd09f8eadb53a99f39a74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 18:07:51 +0200 Subject: [PATCH 09/16] fix test skip --- tests/utilities/test_dtype_device_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index 0ba0ba347b5d4..fb7076186e58a 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -53,7 +53,7 @@ def test_submodules_device_and_dtype(dst_device, dst_dtype): assert model.dtype == model.module.module.dtype == dst_dtype -@pytest.mark.skipif(not torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_submodules_multi_gpu_dp(tmpdir): class DeviceCallback(Callback): From a49452355adb4b597b78e7e7ac5127b460614923 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 18:08:19 +0200 Subject: [PATCH 10/16] rename --- tests/utilities/test_dtype_device_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index fb7076186e58a..773f25baa5dda 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -56,7 +56,7 @@ def test_submodules_device_and_dtype(dst_device, dst_dtype): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_submodules_multi_gpu_dp(tmpdir): - class DeviceCallback(Callback): + class DeviceAssertCallback(Callback): def on_batch_start(self, trainer, model): assert isinstance(model, TopModule) @@ -68,7 +68,7 @@ def on_batch_start(self, trainer, model): default_root_dir=tmpdir, distributed_backend='dp', gpus=2, - callbacks=[DeviceCallback()], + callbacks=[DeviceAssertCallback()], max_steps=1, ) trainer.fit(model) From 1c56e1389cd1d373e13686de6f403f70b805c787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 18:17:31 +0200 Subject: [PATCH 11/16] add ddp_spawn test --- tests/utilities/test_dtype_device_mixin.py | 27 ++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index 773f25baa5dda..23f15a7753db9 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -25,6 +25,14 @@ def __init__(self, *args, **kwargs): self.module = SubModule() +class 
DeviceAssertCallback(Callback):
+
+    def on_batch_start(self, trainer, model):
+        assert isinstance(model, TopModule)
+        assert model.device.index == trainer.local_rank
+        assert model.device == model.module.module.device
+
+
 @pytest.mark.parametrize(['dst_dtype'], [
     pytest.param(torch.float),
     pytest.param(torch.double),
@@ -55,18 +63,23 @@ def test_submodules_device_and_dtype(dst_device, dst_dtype):
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_submodules_multi_gpu_dp(tmpdir):
+    model = TopModule()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        distributed_backend='dp',
+        gpus=2,
+        callbacks=[DeviceAssertCallback()],
+        max_steps=1,
+    )
+    trainer.fit(model)
 
-    class DeviceAssertCallback(Callback):
-
-        def on_batch_start(self, trainer, model):
-            assert isinstance(model, TopModule)
-            assert model.device.index == trainer.local_rank
-            assert model.device == model.module.module.device
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     model = TopModule()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        distributed_backend='dp',
+        distributed_backend='ddp_spawn',
         gpus=2,
         callbacks=[DeviceAssertCallback()],
         max_steps=1,

From 1eb776a49f7ee3706881a72903585d52d4bfc75a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 21 Jul 2020 18:34:25 +0200
Subject: [PATCH 12/16] fix None index in test

---
 tests/utilities/test_dtype_device_mixin.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
index 23f15a7753db9..f755cf5c634ed 100644
--- a/tests/utilities/test_dtype_device_mixin.py
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -28,8 +28,10 @@ def __init__(self, *args, **kwargs):
 class DeviceAssertCallback(Callback):
 
     def on_batch_start(self, trainer, model):
+        rank = trainer.local_rank
         assert isinstance(model, TopModule)
-        assert model.device.index == trainer.local_rank
+        # index = None also means first device
+        assert (model.device.index is None and rank == 0) or model.device.index == rank
         assert model.device == model.module.module.device

From 565e4cb6de1cda9feb39328e88631dd3fea37409 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 21 Jul 2020 18:55:01 +0200
Subject: [PATCH 13/16] try fix ddp_spawn test

---
 tests/utilities/test_dtype_device_mixin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
index f755cf5c634ed..f935ba4e8ec70 100644
--- a/tests/utilities/test_dtype_device_mixin.py
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -7,7 +7,7 @@
 from tests.base import EvalModelTemplate
 
 
-class SubSubModule(DeviceDtypeModuleMixin):
+class SubSubModule(DeviceDtypeModuleMixin, nn.Module):
     pass

From aa9306a7af5f52e77ed1fca3fdae94b162a06705 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 21 Jul 2020 18:56:43 +0200
Subject: [PATCH 14/16] changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf200ea15f007..6a1e23adaa369 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed - +- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657)) ## [0.8.5] - 2020-07-09 From 34a68053fba42c3cdd64cf9355ca37f9a571db97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 19:15:33 +0200 Subject: [PATCH 15/16] move _dtype and _device to mixin --- pytorch_lightning/core/lightning.py | 6 ------ pytorch_lightning/metrics/metric.py | 2 -- pytorch_lightning/utilities/device_dtype_mixin.py | 7 +++++-- tests/utilities/test_dtype_device_mixin.py | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f2b591cbacfbe..1739133edabf3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -68,12 +68,6 @@ def __init__(self, *args, **kwargs): #: True if using amp self.use_amp = False - #: Current dtype - self._dtype = torch.float - - #: device reference - self._device = torch.device('cpu') - # optionally can be set by user self._example_input_array = None diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 349a6ecfa2f82..94e8a0ea4e442 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -27,8 +27,6 @@ def __init__(self, name: str): """ super().__init__() self.name = name - self._dtype = torch.get_default_dtype() - self._device = torch.device('cpu') @abstractmethod def forward(self, *args, **kwargs) -> torch.Tensor: diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index c5bb8ae5d9a33..afb281e0f4efa 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -5,8 +5,11 @@ class DeviceDtypeModuleMixin(Module): - _device: ... - _dtype: Union[str, torch.dtype] + + def __init__(self): + super().__init__() + self._dtype = torch.get_default_dtype() + self._device = torch.device('cpu') @property def dtype(self) -> Union[str, torch.dtype]: diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index f935ba4e8ec70..f755cf5c634ed 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -7,7 +7,7 @@ from tests.base import EvalModelTemplate -class SubSubModule(DeviceDtypeModuleMixin, nn.Module): +class SubSubModule(DeviceDtypeModuleMixin): pass From 9bf20a02d350d431fe5fe0faa572bd2d5c57c3c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Jul 2020 19:24:06 +0200 Subject: [PATCH 16/16] additional doctest --- pytorch_lightning/utilities/device_dtype_mixin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index afb281e0f4efa..bea3df3e5ced9 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -82,6 +82,10 @@ def to(self, *args, **kwargs) -> Module: ExampleModule() >>> module.weight #doctest: +ELLIPSIS tensor([[...]], dtype=torch.float16) + >>> module.device + device(type='cpu') + >>> module.dtype + torch.float16 """ # there is diff nb vars in PT 1.5 out = torch._C._nn._parse_to(*args, **kwargs)
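
Taken together, the series lands on the shape sketched below: the _device/_dtype defaults live in the mixin's __init__ (patch 15), to() parses its overloaded arguments with torch._C._nn._parse_to exactly as the diffs do, and the conversion methods all funnel into one apply-based update (cuda, cpu, type, float, double and half do the same in the real mixin but are omitted here for brevity). The nesting mirrors the tests above. This is a self-contained illustration under those assumptions, not the installed pytorch_lightning module:

from typing import Optional, Union

import torch
from torch.nn import Module


class DeviceDtypeModuleMixin(Module):

    def __init__(self):
        super().__init__()
        # Patch 15 moved these defaults here, out of LightningModule and Metric.
        self._dtype = torch.get_default_dtype()
        self._device = torch.device('cpu')

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> Union[str, torch.dtype]:
        return self._dtype

    def to(self, *args, **kwargs) -> Module:
        # _parse_to normalizes the overloaded .to() signatures;
        # out[0] is the device and out[1] the dtype, either may be None.
        out = torch._C._nn._parse_to(*args, **kwargs)
        self.__update_properties(device=out[0], dtype=out[1])
        return super().to(*args, **kwargs)

    def __update_properties(self,
                            device: Optional[torch.device] = None,
                            dtype: Optional[torch.dtype] = None) -> None:
        def apply_fn(module):
            if not isinstance(module, DeviceDtypeModuleMixin):
                return
            if device is not None:
                module._device = device
            if dtype is not None:
                module._dtype = dtype

        self.apply(apply_fn)


# Nesting as in the tests: a plain nn.Module sits between the two mixins.
class SubSubModule(DeviceDtypeModuleMixin):
    pass


class SubModule(Module):
    def __init__(self):
        super().__init__()
        self.module = SubSubModule()


class TopModule(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.module = SubModule()


if __name__ == '__main__':
    model = TopModule()
    assert model.device == torch.device('cpu')
    model = model.to(dtype=torch.half)          # dtype-only move, runs on CPU
    assert not hasattr(model.module, '_dtype')  # plain nn.Module stays untouched
    assert model.dtype == model.module.module.dtype == torch.half
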