
Docs for Pruning, Quantization, and SWA #6041

Merged
merged 22 commits on Feb 18, 2021
Changes from 6 commits
115 changes: 115 additions & 0 deletions docs/source/advanced/pruning_quantization.rst
@@ -0,0 +1,115 @@
.. testsetup:: *

    import os
    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule

.. _pruning_quantization:

########################
Pruning and Quantization
########################

Pruning and Quantization are techniques to compress model size for deployment, allowing inference speed-up and energy savings without significant accuracy loss.

*******
Pruning
*******

.. warning::

Pruning is in beta and subject to change.

Pruning is a technique which focuses on eliminating some of the model weights to reduce the model size and decrease inference requirements.

Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, deploying models on edge devices, or mobile inference (among others).

To enable pruning during training in Lightning, simply pass in the :class:`~pytorch_lightning.callbacks.ModelPruning` callback to the Lightning Trainer. PyTorch's native pruning implementation is used under the hood.

This callback supports multiple pruning functions: pass any `torch.nn.utils.prune <https://pytorch.org/docs/stable/nn.html#utilities>`_ function as a string to select which weights to prune (`random_unstructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.random_unstructured.html#torch.nn.utils.prune.random_unstructured>`_, `RandomStructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.RandomStructured.html#torch.nn.utils.prune.RandomStructured>`_, etc) or implement your own by subclassing `BasePruningMethod <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#extending-torch-nn-utils-prune-with-custom-pruning-functions>`_.
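
A minimal sketch of such a custom pruning function, assuming a hypothetical ``ThresholdPruning`` class that zeroes weights whose magnitude falls below a given threshold (following the pattern from the PyTorch pruning tutorial); a subclass like this could then be handed to the callback in place of a string name:

.. code-block:: python

    from torch.nn.utils import prune

    class ThresholdPruning(prune.BasePruningMethod):
        # operate on individual weights rather than on whole channels
        PRUNING_TYPE = "unstructured"

        def __init__(self, threshold):
            super().__init__()
            self.threshold = threshold

        def compute_mask(self, tensor, default_mask):
            # keep only the weights whose absolute value exceeds the threshold
            return default_mask * (tensor.abs() > self.threshold)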

TODO: what do you have to set?

You can also set the pruning percentage, perform iterative pruning, apply the `lottery ticket hypothesis <https://arxiv.org/pdf/1803.03635.pdf>`_ and more!

.. code-block:: python

    from pytorch_lightning.callbacks import ModelPruning

    def compute_amount(epoch):
        if epoch == 10:
            return 0.5

        elif epoch == 50:
            return 0.25

        elif 75 < epoch < 99:
            return 0.01

    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)])
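
The same callback can also drive the lottery ticket hypothesis mentioned above. A minimal sketch, assuming your version of :class:`~pytorch_lightning.callbacks.ModelPruning` exposes ``amount`` and ``use_lottery_ticket_hypothesis`` arguments (check the callback signature of your release):

.. code-block:: python

    from pytorch_lightning.callbacks import ModelPruning

    # prune 10% of the remaining weights each time pruning is applied and
    # reset the surviving weights to their original initialization values
    pruning = ModelPruning("l1_unstructured", amount=0.1, use_lottery_ticket_hypothesis=True)
    trainer = Trainer(callbacks=[pruning])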


************
Quantization
************

.. warning::

    Quantization is in beta and subject to change.

Model quantization is another performance optimization technique that speeds up inference and decreases memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. Moreover, smaller models also speed up model loading.

Quantization Aware Training (QAT) mimics the effects of quantization during training: all computations are carried out in floating point while training, simulating the effects of lower-precision integers, and weights and activations are quantized to the lower precision only once training is completed.

Quantization is useful when serving large models on machines with limited memory, or when you need to switch between multiple models, each of which has to be loaded from disk.

Lightning includes the :class:`~pytorch_lightning.callbacks.QuantizationAwareTraining` callback (using PyTorch's native quantization, read more `here <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_), which allows creating fully quantized models (compatible with torchscript).

To quantize your model, specify TODO(borda).

.. code-block:: python

    import torch
    from torch import nn

    from pytorch_lightning.callbacks import QuantizationAwareTraining

    class RegressionModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_0 = nn.Linear(16, 64)
            self.layer_0a = torch.nn.ReLU()
            self.layer_1 = nn.Linear(64, 64)
            self.layer_1a = torch.nn.ReLU()
            self.layer_end = nn.Linear(64, 1)

        def forward(self, x):
            x = self.layer_0(x)
            x = self.layer_0a(x)
            x = self.layer_1(x)
            x = self.layer_1a(x)
            x = self.layer_end(x)
            return x

    qcb = QuantizationAwareTraining(
        # specification of quantization estimation quality
        observer_type='histogram',
        # specify which layers shall be merged together to increase efficiency
        modules_to_fuse=[(f'layer_{i}', f'layer_{i}a') for i in range(2)],
        # make the model torchscript-compatible
        input_compatible=False,
    )

    trainer = Trainer(callbacks=[qcb])
    qmodel = RegressionModel()
    trainer.fit(qmodel, ...)

    batch = next(iter(my_dataloader()))
    qmodel(qmodel.quant(batch[0]))

    tsmodel = qmodel.to_torchscript()
    tsmodel(tsmodel.quant(batch[0]))
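
A variant of the callback configuration could target mobile inference instead; a sketch, assuming ``qconfig`` accepts the pre-defined ``'qnnpack'`` string and ``observer_type`` accepts ``'average'`` as described in the callback's documentation:

.. code-block:: python

    qcb_mobile = QuantizationAwareTraining(qconfig="qnnpack", observer_type="average")
    trainer = Trainer(callbacks=[qcb_mobile])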

You can also set ``input_compatible=True`` to make your model compatible with its original inputs and outputs; in that case, the model is wrapped in a shell with entry/exit (quant/dequant) layers.

.. code-block:: python

    batch = next(iter(my_dataloader()))
    qmodel(batch[0])
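
If the model was exported with ``input_compatible=False`` (as in the earlier snippet), the resulting TorchScript module can also be persisted and reloaded for serving. A minimal sketch, using the illustrative file name ``quantized_model.pt``:

.. code-block:: python

    import torch

    # save the TorchScript export produced above and load it back for inference
    torch.jit.save(tsmodel, "quantized_model.pt")
    restored = torch.jit.load("quantized_model.pt")

    # the reloaded module is called the same way as ``tsmodel`` above
    restored(restored.quant(batch[0]))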
1 change: 1 addition & 0 deletions docs/source/extensions/callbacks.rst
@@ -106,6 +106,7 @@ Lightning has a few built-in callbacks.
ModelPruning
ProgressBar
ProgressBarBase
QuantizationAwareTraining
StochasticWeightAveraging

----------
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -111,6 +111,7 @@ PyTorch Lightning Documentation
common/single_gpu
advanced/sequences
advanced/training_tricks
advanced/pruning_quantization
advanced/transfer_learning
advanced/tpu
advanced/cluster
55 changes: 29 additions & 26 deletions pytorch_lightning/callbacks/quantization.py
@@ -83,15 +83,6 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:


class QuantizationAwareTraining(Callback):
"""
Quantization allows speeding up inference and decreasing memory requirements by performing computations
and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.
We use native PyTorch API so for more information see
`Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>_`

.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
"""

OBSERVER_TYPES = ('histogram', 'average')

def __init__(
@@ -103,30 +94,42 @@ def __init__(
input_compatible: bool = True,
) -> None:
"""
Quantization allows speeding up inference and decreasing memory requirements by performing computations
and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.
We use native PyTorch API so for more information see
`Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.


Args:

    qconfig: quantization configuration:

        - 'fbgemm' for server inference.
        - 'qnnpack' for mobile inference.
        - a custom quantization configuration (see `torch.quantization.QConfig
          <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_).

    observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
        and ``HistogramObserver`` as "histogram" which is more computationally expensive.

    collect_quantization: count or custom function to collect quantization statistics:

        - ``None`` (default). The quantization observer is called in each module forward
          (useful for collecting extended statistics when using image/data augmentation).
        - ``int``. Use to set a fixed number of calls, starting from the beginning.
        - ``Callable``. Custom function with a single trainer argument.
          See this example to trigger only in the last epoch:

          .. code-block:: python

              def custom_trigger_last(trainer):
                  return trainer.current_epoch == (trainer.max_epochs - 1)

              QuantizationAwareTraining(collect_quantization=custom_trigger_last)

    modules_to_fuse: allows you to fuse a few layers together as shown in the
        `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.
        To find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

    input_compatible: preserve quant/dequant layers. This allows feeding any input to the quantized model,
        just as to the original model, but breaks compatibility with torchscript.

"""
if not isinstance(qconfig, (str, QConfig)):
raise MisconfigurationException(f"Unsupported qconfig: f{qconfig}.")