Add option for weight tying on TPU's (#5441)
* added on_post_move_to_device

* added tests

* docs and refactors

* Update tests/backends/test_tpu_backend.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/tpu.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/tpu.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/tpu.rst

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/decorators.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* moved weight sharing module back to test

updated tpu available

* add count to warning

* fix doctest

* import trainer in doctest

* import trainer in doctest

* do not test code as no TPU device

* param count to layer count

* formatting

* update docs

* update import

* update

* resolve tests

* remove legacy accelerator

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Your Name <you@example.com>
5 people authored Feb 18, 2021
1 parent bac617f commit d2cd7cb
Showing 6 changed files with 180 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))

- Added support for tying weights after moving the model to the TPU via the `on_post_move_to_device` hook

- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))

57 changes: 56 additions & 1 deletion docs/source/advanced/tpu.rst
@@ -197,7 +197,62 @@ set the 16-bit flag.
Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

----------------

-----------------

Weight Sharing/Tying
--------------------
Weight tying/sharing is a technique in which module weights are shared between two or more layers.
It is a common way to reduce memory consumption and is used in many state-of-the-art
architectures today.

PyTorch XLA requires these weights to be tied/shared after moving the model
to the TPU device. To support this requirement, Lightning provides a model hook that is
called after the model has been moved to the device. Any weights that need to be tied should
be tied inside the ``on_post_move_to_device`` model hook. This ensures that the weights
are shared among the modules rather than copied.

PyTorch Lightning has a built-in check that verifies that the number of module parameters
matches once the model has been moved to the device. If the counts do not match, Lightning
emits a warning.

Example:

.. code-block:: python

    from pytorch_lightning.core.lightning import LightningModule
    from torch import nn
    from pytorch_lightning.trainer.trainer import Trainer


    class WeightSharingModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # TPU shared weights are copied independently
            # on the XLA device and this line won't have any effect.
            # However, it works fine for CPU and GPU.
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

        def on_post_move_to_device(self):
            # Weights shared after the model has been moved to TPU Device
            self.layer_3.weight = self.layer_1.weight


    model = WeightSharingModule()
    trainer = Trainer(max_epochs=1, tpu_cores=8)

See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
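
As a quick sanity check on CPU or GPU (a minimal sketch, not part of the official example above), you can
confirm that the assignment in ``__init__`` really ties the two layers to a single parameter:

.. code-block:: python

    model = WeightSharingModule()

    # both attributes point at the very same parameter object ...
    assert model.layer_3.weight is model.layer_1.weight
    # ... backed by the same storage
    assert model.layer_3.weight.data_ptr() == model.layer_1.weight.data_ptr()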

-----------------------

Performance considerations
--------------------------
42 changes: 41 additions & 1 deletion pytorch_lightning/core/decorators.py
@@ -16,7 +16,7 @@
from functools import wraps
from typing import Callable

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


def auto_move_data(fn: Callable) -> Callable:
@@ -54,6 +54,7 @@ def forward(self, x):

    @wraps(fn)
    def auto_transfer_args(self, *args, **kwargs):
        from pytorch_lightning.core.lightning import LightningModule
        if not isinstance(self, LightningModule):
            return fn(self, *args, **kwargs)

@@ -62,3 +63,42 @@ def auto_transfer_args(self, *args, **kwargs):
        return fn(self, *args, **kwargs)

    return auto_transfer_args


def parameter_validation(fn: Callable) -> Callable:
    """
    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
    Validates that the module parameter lengths match after moving to the device. It is useful
    when tying weights on TPUs.

    Args:
        fn: ``.to`` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failure to do this results in the initialization of new weights which are not tied.
        To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook
        which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.parameters()))
        module = fn(self, *args, **kwargs)
        self.on_post_move_to_device()
        post_layer_count = len(list(self.parameters()))

        if not pre_layer_count == post_layer_count:
            rank_zero_warn(
                'The model layers do not match after moving to the target device.'
                ' If your model employs weight sharing on TPU,'
                ' please tie your weights using the `on_post_move_to_device` model hook.\n'
                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
            )

        return module

    return inner_fn
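
For intuition about the check above (a small CPU-only sketch, not part of this diff): ``nn.Module.parameters()``
yields a shared tensor only once, so the count compared by ``inner_fn`` drops when weights are tied and rises
again if the tie is lost on the XLA device.

    from torch import nn

    layer_1 = nn.Linear(32, 10, bias=False)
    layer_2 = nn.Linear(10, 32, bias=False)
    layer_3 = nn.Linear(32, 10, bias=False)
    model = nn.Sequential(layer_1, layer_2, layer_3)

    print(len(list(model.parameters())))  # 3 -- three independent weight tensors

    layer_3.weight = layer_1.weight       # tie two of them
    print(len(list(model.parameters())))  # 2 -- the shared weight is reported only once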
16 changes: 16 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -318,6 +318,22 @@ def on_after_backward(self):
"""

def on_post_move_to_device(self) -> None:
"""
Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to`
is called. This is a good place to tie weights between modules after moving them to a device. Can be
used when training models with weight sharing properties on TPU.
Addresses the handling of shared weights on TPU:
https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
Example::
def on_post_move_to_device(self):
self.decoder.weight = self.encoder.weight
"""


class DataHooks:
    """Hooks to be used with LightningDataModule."""
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,6 +17,8 @@
import torch
from torch.nn import Module

from pytorch_lightning.core.decorators import parameter_validation


class DeviceDtypeModuleMixin(Module):
    __jit_unused_properties__ = ['device', 'dtype']
@@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]):
        # Necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

    @parameter_validation
    def to(self, *args, **kwargs) -> Module:
        """Moves and/or casts the parameters and buffers.
@@ -86,6 +89,9 @@ def to(self, *args, **kwargs) -> Module:
            ...     def __init__(self, weight: torch.Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
            ...
            ...     def on_post_move_to_device(self):
            ...         pass
            >>> _ = torch.manual_seed(0)
            >>> module = ExampleModule(torch.rand(3, 4))
            >>> module.weight #doctest: +ELLIPSIS
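
Side note on the doctest change above (an illustrative sketch, not part of the diff): because the decorated
``to()`` unconditionally calls ``self.on_post_move_to_device()``, a class that inherits
``DeviceDtypeModuleMixin`` directly must provide the hook itself, even as a no-op; ``LightningModule`` already
inherits a default from ``ModelHooks``.

    import torch

    from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


    class BufferOnlyModule(DeviceDtypeModuleMixin):

        def __init__(self):
            super().__init__()
            self.register_buffer('scale', torch.ones(1))

        def on_post_move_to_device(self):
            pass  # nothing to re-tie in this toy module


    module = BufferOnlyModule().to('cpu')  # works; without the hook this raises AttributeError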
63 changes: 60 additions & 3 deletions tests/accelerators/test_tpu_backend.py
@@ -14,15 +14,32 @@

import pytest
import torch
from torch import nn

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.xla_device import XLADeviceUtils
from pytorch_lightning.utilities import _TPU_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
class WeightSharingModule(BoringModel):

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu(tmpdir):
""" Checks if training can be resumed from a saved checkpoint on CPU"""
@@ -53,7 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_if_test_works_after_train(tmpdir):
""" Ensure that .test() works after .fit() """
@@ -63,3 +80,43 @@
    trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    assert trainer.test(model) == 1


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_weight_tying_warning(tmpdir, capsys=None):
    """
    Ensure a warning is raised if the model parameter counts do not match
    after moving to the device.
    """

    model = WeightSharingModule()
    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

    with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
        result = trainer.fit(model)
        assert result


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_if_weights_tied(tmpdir, capsys=None):
    """
    Test that weights are properly tied in `on_post_move_to_device`.
    Ensure no warning for a parameter mismatch is raised.
    """

    class Model(WeightSharingModule):

        def on_post_move_to_device(self):
            self.layer_3.weight = self.layer_1.weight

    model = Model()
    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

    with pytest.warns(UserWarning) as warnings:
        result = trainer.fit(model)
        assert result

    assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
    assert trainer.test(model) == 1
