Add more quantized modules (#3156)
* Enable more modules to support true quantization
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
Co-authored-by: Daemyung Jang <quic_daemyung@quicinc.com>
quic-kyunggeu committed Jul 10, 2024 · 1 parent d8dae7d · commit 606dd84
Showing 4 changed files with 239 additions and 28 deletions.
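For orientation before the diff: every quantized module in true_quant.py is registered against its float counterpart with @QuantizationMixin.implements(...) and describes how module state maps onto the underlying functional call via get_functional_args. The sketch below illustrates that pattern under stated assumptions — QuantizedLeakyReLU is hypothetical and not part of this commit, and the import path is inferred from the package layout visible in the diff:

import torch.nn as nn

# Assumed import path, inferred from this repository's layout.
from aimet_torch.v2.nn.true_quant import QuantizationMixin, _QuantizedUnaryOpMixin


# Hypothetical module, shown only to illustrate the registration pattern
# this commit extends to ReLU, PReLU, Hardtanh, MaxPool2d, and others.
@QuantizationMixin.implements(nn.LeakyReLU)
class QuantizedLeakyReLU(_QuantizedUnaryOpMixin, nn.LeakyReLU):
    """ Quantized LeakyReLU (illustrative sketch) """

    def get_functional_args(self, x):
        # Mirror torch.nn.functional.leaky_relu(input, negative_slope, inplace)
        return (x, self.negative_slope), {"inplace": self.inplace}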
143 changes: 119 additions & 24 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/nn/true_quant.py
@@ -37,25 +37,25 @@
""" Quantized modules"""

import contextlib
-from functools import partial
import itertools
from abc import abstractmethod
from collections import OrderedDict
+from functools import partial
from typing import Type, Any, Tuple, Dict, Optional, Callable
from weakref import WeakKeyDictionary

import torch
import torch.nn as nn
from torch import Tensor

-from aimet_torch.v2.quantization.base import QuantizerBase
+import aimet_torch.elementwise_ops as aimet_ops
from aimet_torch.v2.quantization import affine
+from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.float import FloatQuantizeDequantize
from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
from aimet_torch.v2.utils import patch_attr, _ContextManager, allow_recompute
-import aimet_torch.elementwise_ops as aimet_ops

-from .base import BaseQuantizationMixin, _BaseQuantizedUnaryOpMixin, _BaseQuantizedBinaryOpMixin # pylint: disable=import-error
+from .base import BaseQuantizationMixin, _BaseQuantizedUnaryOpMixin, \
+    _BaseQuantizedBinaryOpMixin # pylint: disable=import-error


def _quantize_if_applicable(data: Any, quantizer: Optional[QuantizerBase]):
@@ -72,6 +72,7 @@ def _quantize_if_applicable(data: Any, quantizer: Optional[QuantizerBase]):

return data


def _dequantize_if_applicable(data: torch.Tensor):
return data.dequantize() if isinstance(data, QuantizedTensorBase) else data

@@ -153,7 +154,7 @@ class QuantizationMixin(BaseQuantizationMixin): # pylint: disable=abstract-method
qcls_to_cls = OrderedDict() # original class -> quantized class

_default_kernel: Optional[Callable] = None
-_kernels = WeakKeyDictionary() # instance -> instance_kernel
+_kernels = WeakKeyDictionary()  # instance -> instance_kernel

@abstractmethod
def forward(self, *args, **kwargs):
@@ -265,7 +266,7 @@ def get_kernel(self) -> Optional[Callable]:
return self.get_default_kernel()

@contextlib.contextmanager
-def compute_encodings(self): # pylint: disable=missing-function-docstring
+def compute_encodings(self):  # pylint: disable=missing-function-docstring
ctx = _ContextManager(action=lambda: _enter_computing_encodings(self),
cleanup=lambda: _exit_compute_encodings(self))
with super().compute_encodings(), ctx:
@@ -343,7 +344,7 @@ def _view_as_qdq(quantizer):
# pylint: disable=arguments-differ, abstract-method, too-many-ancestors

class _QuantizedUnaryOpMixin(QuantizationMixin, _BaseQuantizedUnaryOpMixin):
-def forward(self, *args, **kwargs): # pylint: disable=missing-function-docstring
+def forward(self, *args, **kwargs):  # pylint: disable=missing-function-docstring
kernel = self.get_kernel()

if not kernel or _is_computing_encodings(self):
@@ -382,7 +383,7 @@ def __quant_init__(self):
super().__quant_init__()
self.input_quantizers = nn.ModuleList([None, None])

-def forward(self, *args, **kwargs): # pylint: disable=missing-function-docstring
+def forward(self, *args, **kwargs):  # pylint: disable=missing-function-docstring
kernel = self.get_kernel()

if not kernel or _is_computing_encodings(self):
@@ -420,8 +421,9 @@ def get_functional_args(self, x, y, *args, **kwargs) -> Tuple[Tuple, Dict]:
"""


-class _QuantizedConvNdMixin(_QuantizedUnaryOpMixin): # pylint: disable=too-many-ancestors
+class _QuantizedConvNdMixin(_QuantizedUnaryOpMixin):  # pylint: disable=too-many-ancestors
""" Quantized ConvNd """

def __quant_init__(self):
if self.padding_mode != 'zeros':
msg = f'padding_mode other than "zeros" is currently not supported. (got {self.padding_mode})'
@@ -445,17 +447,17 @@ def get_functional_args(self, x):


@QuantizationMixin.implements(nn.Conv1d)
-class QuantizedConv1d(_QuantizedConvNdMixin, nn.Conv1d): # pylint: disable=too-many-ancestors
+class QuantizedConv1d(_QuantizedConvNdMixin, nn.Conv1d):  # pylint: disable=too-many-ancestors
""" Quantized Conv1d """


@QuantizationMixin.implements(nn.Conv2d)
-class QuantizedConv2d(_QuantizedConvNdMixin, nn.Conv2d): # pylint: disable=too-many-ancestors
+class QuantizedConv2d(_QuantizedConvNdMixin, nn.Conv2d):  # pylint: disable=too-many-ancestors
""" Quantized Conv2d """


@QuantizationMixin.implements(nn.Conv3d)
-class QuantizedConv3d(_QuantizedConvNdMixin, nn.Conv3d): # pylint: disable=too-many-ancestors
+class QuantizedConv3d(_QuantizedConvNdMixin, nn.Conv3d):  # pylint: disable=too-many-ancestors
""" Quantized Conv3d """


@@ -478,15 +480,16 @@ class QuantizedGELU(_QuantizedUnaryOpMixin, nn.GELU):
""" Quantized GELU """

def get_functional_args(self, x):
return (x, ), {"approximate": self.approximate}
return (x,), {"approximate": self.approximate}


@QuantizationMixin.implements(nn.LayerNorm)
class QuantizedLayerNorm(_QuantizedUnaryOpMixin, nn.LayerNorm):
""" Quantized LayerNorm """

def get_functional_args(self, x):
-return (x, self.normalized_shape), {"weight": self.weight, "bias": self.bias, "eps": self.eps}
+return (x, self.normalized_shape,), {"weight": self.weight, "bias": self.bias, "eps": self.eps}


@QuantizationMixin.implements(nn.Softmax)
class QuantizedSoftmax(_QuantizedUnaryOpMixin, nn.Softmax):
@@ -495,12 +498,13 @@ class QuantizedSoftmax(_QuantizedUnaryOpMixin, nn.Softmax):
def get_functional_args(self, x):
return (x, self.dim), {}


@QuantizationMixin.implements(nn.Sigmoid)
class QuantizedSigmoid(_QuantizedUnaryOpMixin, nn.Sigmoid):
""" Quantized Sigmoid """

def get_functional_args(self, x):
-return (x, ), {}
+return (x,), {}


@QuantizationMixin.implements(nn.Tanh)
@@ -511,25 +515,116 @@ def get_functional_args(self, x):
return (x,), {}


@QuantizationMixin.implements(nn.ReLU)
class QuantizedReLU(_QuantizedUnaryOpMixin, nn.ReLU):
""" Quantized ReLU """

def get_functional_args(self, x):
return (x,), {"inplace": self.inplace}


@QuantizationMixin.implements(nn.PReLU)
class QuantizedPReLU(_QuantizedUnaryOpMixin, nn.PReLU):
""" Quantized PReLU """

def get_functional_args(self, x):
return (x, self.weight), {}


@QuantizationMixin.implements(nn.ConstantPad2d)
class QuantizedConstantPad2d(_QuantizedUnaryOpMixin, nn.ConstantPad2d):
""" Quantized ConstantPad2d """

def get_functional_args(self, x):
return (x, self.padding, "constant", self.value,), {}


@QuantizationMixin.implements(nn.Hardtanh)
class QuantizedHardtanh(_QuantizedUnaryOpMixin, nn.Hardtanh):
""" Quantized Hardtanh """

def get_functional_args(self, x):
return (x, self.min_val, self.max_val, self.inplace), {}


@QuantizationMixin.implements(nn.MaxPool2d)
class QuantizedMaxPool2d(_QuantizedUnaryOpMixin, nn.MaxPool2d):
""" Quantized MaxPool2d """

def get_functional_args(self, x):
return (x, self.kernel_size, self.stride, self.padding, self.dilation,), \
{"ceil_mode": self.ceil_mode, "return_indices": self.return_indices}


@QuantizationMixin.implements(nn.UpsamplingBilinear2d)
class QuantizedUpsamplingBilinear2d(_QuantizedUnaryOpMixin, nn.UpsamplingBilinear2d):
""" Quantized UpsamplingBilinear2d """

def get_functional_args(self, x):
return (x, self.size, self.scale_factor, self.mode, self.align_corners,), \
{"recompute_scale_factor": self.recompute_scale_factor}


@QuantizationMixin.implements(nn.PixelShuffle)
class QuantizedPixelShuffle(_QuantizedUnaryOpMixin, nn.PixelShuffle):
""" Quantized PixelShuffle """

def get_functional_args(self, x):
return (x, self.upscale_factor,), {}


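# Pass-through helper: the wrapper classes below bind this directly as
# `get_functional_args`, since their inputs already match the underlying
# functional op's signature and can be forwarded unchanged.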
def _as_is(self, *args, **kwargs): # pylint: disable=unused-argument
return args, kwargs

@QuantizationMixin.implements(aimet_ops.Sin)
class QuantizedSin(_QuantizedUnaryOpMixin, aimet_ops.Sin):
""" Quantized Sin """
get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Cos)
class QuantizedCos(_QuantizedUnaryOpMixin, aimet_ops.Cos):
""" Quantized Cos """
get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.AvgPool2d)
class QuantizedAvgPool2d(_QuantizedUnaryOpMixin, aimet_ops.AvgPool2d):
""" Quantized AvgPool2d """
get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Reshape)
class QuantizedReshape(_QuantizedUnaryOpMixin, aimet_ops.Reshape):
""" Quantized Reshape """
get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.RSqrt)
class QuantizedRSqrt(_QuantizedUnaryOpMixin, aimet_ops.RSqrt):
""" Quantized RSqrt """
get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Add)
class QuantizedAdd(_QuantizedBinaryOpMixin, aimet_ops.Add):
""" Quantized Add """
-def get_functional_args(self, x, y):
-    return (x, y), {}
+get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Multiply)
class QuantizedMultiply(_QuantizedBinaryOpMixin, aimet_ops.Multiply):
""" Quantized Multiply """
-def get_functional_args(self, x, y):
-    return (x, y), {}
+get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Subtract)
class QuantizedSubtract(_QuantizedBinaryOpMixin, aimet_ops.Subtract):
""" Quantized Subtract """
-def get_functional_args(self, x, y):
-    return (x, y), {}
+get_functional_args = _as_is


@QuantizationMixin.implements(aimet_ops.Divide)
class QuantizedDivide(_QuantizedBinaryOpMixin, aimet_ops.Divide):
""" Quantized Divide """
get_functional_args = _as_is
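With these registrations in place, the new wrappers participate in the same compute_encodings / kernel-dispatch flow shown above. A minimal end-to-end sketch follows; note that from_module and the affine.QuantizeDequantize constructor arguments are assumptions based on aimet_torch.v2 conventions, not verified against this exact revision.

import torch
import torch.nn as nn

from aimet_torch.v2.nn import QuantizationMixin
from aimet_torch.v2.quantization import affine

# Convert a float module into its registered quantized counterpart
# (nn.ReLU -> QuantizedReLU after this commit).
qrelu = QuantizationMixin.from_module(nn.ReLU())

# Attach a per-tensor 8-bit affine quantizer to the output, then calibrate
# it on representative data inside compute_encodings().
qrelu.output_quantizers[0] = affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
with qrelu.compute_encodings():
    _ = qrelu(torch.randn(16, 32))

# Without a registered kernel, forward falls back to the quantize-dequantize
# path; with one, the output is computed by the true-quant kernel.
out = qrelu(torch.randn(16, 32))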
@@ -710,7 +710,7 @@ def test_add_quantization_wrappers_with_modulelist_with_layers_to_ignore(self):
assert isinstance(sim.model.layers[2], aimet_nn.QuantizedConv2d)

assert isinstance(sim.model.layers_deep[0][0], aimet_nn.FakeQuantizedBatchNorm2d)
-assert isinstance(sim.model.layers_deep[0][1], aimet_nn.FakeQuantizedReLU)
+assert isinstance(sim.model.layers_deep[0][1], aimet_nn.QuantizedReLU)

assert type(sim.model.layers_deep[1]) == nn.Linear # layer ignored, so no QcQuantizeWrapper wrapper
assert isinstance(sim.model.layers_deep[2], aimet_nn.QuantizedLinear)
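The updated assertion follows from the change in true_quant.py above: with nn.ReLU now registered to QuantizedReLU via @QuantizationMixin.implements(nn.ReLU), quantsim wraps ReLU layers in the true-quant class rather than FakeQuantizedReLU.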
