Commit: Merge branch 'master' into feature/replace-mata-tags-with-yaml
Showing 6 changed files with 170 additions and 12 deletions.
@@ -0,0 +1,156 @@
from typing import Union, Optional

import torch

class DeviceDtypeModuleMixin(torch.nn.Module):
    _device: Union[str, torch.device]
    _dtype: Union[str, torch.dtype]

    @property
    def dtype(self) -> Union[str, torch.dtype]:
        return self._dtype

    @dtype.setter
    def dtype(self, new_dtype: Union[str, torch.dtype]):
        # necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')

    @property
    def device(self) -> Union[str, torch.device]:
        return self._device

    @device.setter
    def device(self, new_device: Union[str, torch.device]):
        # necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

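    # Illustration (hypothetical `module` instance, not part of the class):
    # direct assignment to the properties above fails loudly, and `.to()` is
    # the supported path:
    #   module.dtype = torch.float16       # raises RuntimeError
    #   module = module.to(torch.float16)  # ok, also updates _dtype
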
    def to(self, *args, **kwargs) -> torch.nn.Module:
        """Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)
        .. function:: to(dtype, non_blocking=False)
        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        desired floating point :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        Note:
            This method modifies the module in-place.

        Args:
            device: the desired device of the parameters
                and buffers in this module
            dtype: the desired floating point type of
                the floating point parameters and buffers in this module
            tensor: Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module

        Returns:
            Module: self

        Example::

            >>> class ExampleModule(DeviceDtypeModuleMixin):
            ...     def __init__(self, weight: torch.Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
            >>> _ = torch.manual_seed(0)
            >>> module = ExampleModule(torch.rand(3, 4))
            >>> module.weight  # doctest: +ELLIPSIS
            tensor([[...]])
            >>> module.to(torch.double)
            ExampleModule()
            >>> module.weight  # doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float64)
            >>> cpu = torch.device('cpu')
            >>> module.to(cpu, dtype=torch.half, non_blocking=True)
            ExampleModule()
            >>> module.weight  # doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float16)
            >>> module.to(cpu)
            ExampleModule()
            >>> module.weight  # doctest: +ELLIPSIS
            tensor([[...]], dtype=torch.float16)
        """
        # `torch._C._nn._parse_to` returns a different number of values
        # depending on the PyTorch version (it changed in 1.5), so unpack
        # the parsed device and dtype by index
        out = torch._C._nn._parse_to(*args, **kwargs)
        device = out[0]
        dtype = out[1]
        if device is not None:
            self._device = device

        if dtype is not None:
            self._dtype = dtype

        return super().to(*args, **kwargs)

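    # Illustration of the `_parse_to` call above (private API; the exact
    # return shape shown here is an assumption and varies by PyTorch version):
    #   torch._C._nn._parse_to('cpu', torch.float16)
    #   -> (device(type='cpu'), torch.float16, False, None)  # PyTorch 1.5
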
    def cuda(self, device: Optional[int] = None) -> torch.nn.Module:
        """Moves all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects, so
        it should be called before constructing the optimizer if the module will
        live on GPU while being optimized.

        Arguments:
            device: if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        self._device = torch.device('cuda', index=device)
        return super().cuda(device=device)

    def cpu(self) -> torch.nn.Module:
        """Moves all model parameters and buffers to the CPU.

        Returns:
            Module: self
        """
        self._device = torch.device('cpu')
        return super().cpu()

    def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module:
        """Casts all parameters and buffers to :attr:`dst_type`.

        Arguments:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        self._dtype = dst_type
        return super().type(dst_type=dst_type)

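    # Note the asymmetry with the casts below: `type()` casts *all* parameters
    # and buffers, including integral ones, while `float()`, `double()` and
    # `half()` only cast the floating point tensors, e.g.
    #   module.type(torch.float16)  # casts every tensor to fp16
    #   module.half()               # casts only floating point tensors
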
    def float(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to ``float`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.float
        return super().float()

    def double(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to ``double`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.double
        return super().double()

    def half(self) -> torch.nn.Module:
        """Casts all floating point parameters and buffers to ``half`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.half
        return super().half()
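
A minimal usage sketch of the mixin (the subclass name, layer sizes, and the
manual seeding of _device/_dtype are illustrative assumptions; the mixin
declares those attributes but never initializes them itself):

    import torch

    class ToyModule(DeviceDtypeModuleMixin):
        def __init__(self):
            super().__init__()
            # seed the tracked state, since the mixin does not initialize it
            self._device = torch.device('cpu')
            self._dtype = torch.float32
            self.layer = torch.nn.Linear(4, 2)

    module = ToyModule()
    module = module.to(torch.float64)
    print(module.dtype)   # torch.float64
    module = module.half()
    print(module.dtype)   # torch.float16
    print(module.device)  # cpu

    # direct assignment is rejected by the property setters
    try:
        module.device = torch.device('cuda')
    except RuntimeError as err:
        print(err)  # Cannot set the device explicitly. ...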