-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
device_dtype_mixin.py
156 lines (127 loc) · 5.29 KB
/
device_dtype_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import Union, Optional
import torch
class DeviceDtypeModuleMixin(torch.nn.Module):
_device: ...
_dtype: Union[str, torch.dtype]
@property
def dtype(self) -> Union[str, torch.dtype]:
return self._dtype
@dtype.setter
def dtype(self, new_dtype: Union[str, torch.dtype]):
# necessary to avoid infinite recursion
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')
@property
def device(self) -> Union[str, torch.device]:
return self._device
@device.setter
def device(self, new_device: Union[str, torch.device]):
# Necessary to avoid infinite recursion
raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')
def to(self, *args, **kwargs) -> torch.nn.Module:
"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note:
This method modifies the module in-place.
Args:
device: the desired device of the parameters
and buffers in this module
dtype: the desired floating point type of
the floating point parameters and buffers in this module
tensor: Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
Returns:
Module: self
Example::
>>> class ExampleModule(DeviceDtypeModuleMixin):
... def __init__(self, weight: torch.Tensor):
... super().__init__()
... self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
"""
# there is diff nb vars in PT 1.5
out = torch._C._nn._parse_to(*args, **kwargs)
device = out[0]
dtype = out[1]
if device is not None:
self._device = device
if dtype is not None:
self._dtype = dtype
return super().to(*args, **kwargs)
def cuda(self, device: Optional[int] = None) -> torch.nn.Module:
"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on GPU while being optimized.
Arguments:
device: if specified, all parameters will be
copied to that device
Returns:
Module: self
"""
self._device = torch.device('cuda', index=device)
return super().cuda(device=device)
def cpu(self) -> torch.nn.Module:
"""Moves all model parameters and buffers to the CPU.
Returns:
Module: self
"""
self._device = torch.device('cpu')
return super().cpu()
def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module:
"""Casts all parameters and buffers to :attr:`dst_type`.
Arguments:
dst_type (type or string): the desired type
Returns:
Module: self
"""
self._dtype = dst_type
return super().type(dst_type=dst_type)
def float(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to float datatype.
Returns:
Module: self
"""
self._dtype = torch.float
return super().float()
def double(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to ``double`` datatype.
Returns:
Module: self
"""
self._dtype = torch.double
return super().double()
def half(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to ``half`` datatype.
Returns:
Module: self
"""
self._dtype = torch.half
return super().half()