device property #1791

Merged
merged 9 commits into from May 13, 2020
10 changes: 5 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -16,6 +16,7 @@
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -29,14 +30,11 @@
XLA_AVAILABLE = True


class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

#: Current dtype
self.dtype = torch.FloatTensor

self.exp_save_path = None

#: The current epoch
@@ -72,8 +70,10 @@ def __init__(self, *args, **kwargs):

self.hparams = None

#: Current dtype
self._dtype = torch.FloatTensor
#: device reference
self.device = None
self._device = torch.device('cpu')

def print(self, *args, **kwargs) -> None:
r"""
156 changes: 156 additions & 0 deletions pytorch_lightning/core/properties.py
@@ -0,0 +1,156 @@
from typing import Union, Optional

import torch


class DeviceDtypeModuleMixin(torch.nn.Module):
_device: ...
_dtype: Union[str, torch.dtype]

@property
def dtype(self) -> Union[str, torch.dtype]:
return self._dtype

@dtype.setter
def dtype(self, new_dtype: Union[str, torch.dtype]):
# necessary to avoid infinite recursion
raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).')

@property
def device(self) -> Union[str, torch.device]:
return self._device

@device.setter
def device(self, new_device: Union[str, torch.device]):
# Necessary to avoid infinite recursion
raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

def to(self, *args, **kwargs) -> torch.nn.Module:
"""Moves and/or casts the parameters and buffers.

This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.

Note:
This method modifies the module in-place.

Args:
device: the desired device of the parameters
and buffers in this module
dtype: the desired floating point type of
the floating point parameters and buffers in this module
tensor: Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module

Returns:
Module: self

Example::
>>> class ExampleModule(DeviceDtypeModuleMixin):
... def __init__(self, weight: torch.Tensor):
... super().__init__()
... self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
"""
# torch._C._nn._parse_to returns a different number of values in PyTorch 1.5, hence the indexing below
out = torch._C._nn._parse_to(*args, **kwargs)
device = out[0]
dtype = out[1]
if device is not None:
self._device = device

if dtype is not None:
self._dtype = dtype

return super().to(*args, **kwargs)

def cuda(self, device: Optional[int] = None) -> torch.nn.Module:
"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on GPU while being optimized.

Arguments:
device: if specified, all parameters will be
copied to that device

Returns:
Module: self
"""

self._device = torch.device('cuda', index=device)
return super().cuda(device=device)

def cpu(self) -> torch.nn.Module:
"""Moves all model parameters and buffers to the CPU.
Returns:
Module: self
"""
self._device = torch.device('cpu')
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module:
"""Casts all parameters and buffers to :attr:`dst_type`.

Arguments:
dst_type (type or string): the desired type

Returns:
Module: self
"""
self._dtype = dst_type
return super().type(dst_type=dst_type)

def float(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to float datatype.

Returns:
Module: self
"""
self._dtype = torch.float
return super().float()

def double(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to ``double`` datatype.

Returns:
Module: self
"""
self._dtype = torch.double
return super().double()

def half(self) -> torch.nn.Module:
"""Casts all floating point parameters and buffers to ``half`` datatype.

Returns:
Module: self
"""
self._dtype = torch.half
return super().half()
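
Standalone usage of the new mixin can be sketched as follows (ToyModule is hypothetical; the import path matches the file added in this PR). The cast/move methods keep the device and dtype properties in sync, while assigning to them directly is rejected:

import torch

from pytorch_lightning.core.properties import DeviceDtypeModuleMixin


class ToyModule(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        # LightningModule initialises these backing fields in its own __init__;
        # a bare mixin subclass has to set them before device/dtype are read
        self._device = torch.device('cpu')
        self._dtype = torch.float32
        self.layer = torch.nn.Linear(3, 1)


module = ToyModule()
print(module.device, module.dtype)   # cpu torch.float32

module.double()                      # casts parameters and records the new dtype
print(module.dtype)                  # torch.float64
print(module.layer.weight.dtype)     # torch.float64

try:
    module.device = torch.device('cuda', 0)   # read-only property: direct assignment raises
except RuntimeError as err:
    print(err)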
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -344,7 +344,7 @@ def ddp_train(self, process_idx, model):
# copy model to each gpu
if self.on_gpu:
self.root_gpu = process_idx
self.device = torch.device('cuda', self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)
Contributor
we could remove all of these calls in the trainer by overloading .to() and .cuda() in LightningModule and setting the device there.

Member
You also need to overload .cpu() :)

Contributor

sorry, maybe I'm missing something. The point of self.device is to have a read-only property for creating tensors directly in the right device memory.
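
A minimal sketch of that use case (LitModel and its training_step body are illustrative, not taken from this PR): with a read-only self.device, user code can allocate tensors on whatever device the Trainer has moved the module to, without hard-coding 'cuda' or 'cpu'.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # the noise tensor is created directly on the module's current device
        noise = torch.randn(x.size(0), 4, device=self.device)
        loss = torch.nn.functional.mse_loss(self.layer(x + noise), y)
        return {'loss': loss}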

Contributor
@awaelchli awaelchli May 12, 2020

If we overload the .to() method like this, for example:

def to(self, device):
    self._device = device
    return super().to(device)
    

Then we get the following benefits:

  • the self.device property will not break when LightningModule is used as a plain nn.Module without the Trainer
  • when a LightningModule is nested inside another LightningModule and the user calls .to(), the self.device properties of the submodules also get updated
  • the Trainer code does not need to set the device; it calls .to anyway, so the code is in one place and is easier to maintain.

I see only benefits atm

Contributor

@justusschock also did it like this for the metrics package

Member Author

so we need to override the following methods:

  • .to(...)
  • .cpu()
  • .cuda()

or am I missing any? @awaelchli ^^

Contributor
@awaelchli awaelchli May 12, 2020

yep, exactly, although I suspect cpu and cuda already call .to internally. Not sure, need to check. EDIT: nope, they don't, we need all three :)
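
This is easy to check with a small standalone sketch (the ToOnly class below is hypothetical): a module that overrides only .to() still has its parameters moved by .cpu(), but the override is never invoked, so .cpu() and .cuda() indeed have to be overridden separately.

import torch


class ToOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_called = False
        self.layer = torch.nn.Linear(2, 2)

    def to(self, *args, **kwargs):
        self.to_called = True
        return super().to(*args, **kwargs)


m = ToOnly()
m.cpu()              # parameters are moved, but not by going through .to()
print(m.to_called)   # False -> .cpu() (and .cuda()) need their own overrides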

Member Author

well it seems to me that ideally we want to reuse the whole template from metrics...

Contributor
@awaelchli awaelchli May 12, 2020

probably not all. The device setter and dtype do not apply for LightningModule, I think? I agree we should try to avoid code duplication.
@justusschock what do you think?

Member

I think, while we do this, we should think about introducing the same for dtype, since when I create a tensor in a function, it usually involves a certain dtype as well. Although I'm not sure if this would be reflected by amp as well...

torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)

10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -432,7 +432,7 @@ def copy_trainer_model_properties(self, model):
m.use_tpu = self.use_tpu
m.tpu_local_core_rank = self.tpu_local_core_rank
m.tpu_global_core_rank = self.tpu_global_core_rank
m.device = self.device
m._device = self._device

def transfer_batch_to_tpu(self, batch):
return self.__transfer_data_to_device(batch, device='tpu')
@@ -484,7 +484,7 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

def single_gpu_train(self, model):
model.cuda(self.root_gpu)
self.device = torch.device('cuda', self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -501,7 +501,7 @@ def single_gpu_train(self, model):
def tpu_train(self, tpu_core_idx, model):
# put model on tpu
model.to(xm.xla_device())
self.device = xm.xla_device()
self._device = xm.xla_device()

# get the appropriate tpu ranks
self.tpu_local_core_rank = xm.get_local_ordinal()
@@ -539,7 +539,7 @@ def dp_train(self, model):
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

model.cuda(self.root_gpu)
self.device = torch.device('cuda', self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# hack forward to do autocast for the user
model_autocast_original_forward = model.forward
@@ -579,7 +579,7 @@ def horovod_train(self, model):
assert self.root_gpu == hvd.local_rank()
torch.cuda.set_device(self.root_gpu)
model.cuda(self.root_gpu)
self.device = torch.device('cuda', self.root_gpu)
self._device = torch.device('cuda', self.root_gpu)

# avoid duplicating progress bar
if hvd.rank() != 0 and self.progress_bar_callback is not None:
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -473,7 +473,7 @@ def __init__(
# distributed backend choice
self.distributed_backend = distributed_backend
self.set_distributed_mode(distributed_backend)
self.device = torch.device('cpu')
self._device = torch.device('cpu')

# override dist backend when using tpus
if self.on_tpu:
@@ -635,6 +635,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
2 changes: 2 additions & 0 deletions tests/base/model_template.py
@@ -34,6 +34,8 @@ class EvalModelTemplate(
):
"""
This template houses all combinations of model configurations we want to test

>>> model = EvalModelTemplate()
"""
def __init__(self, hparams: object = None) -> object:
"""Pass in parsed HyperOptArgumentParser to the model."""