Merge remote-tracking branch 'origin/main' into feature/damian/refactor_injection
bogunowicz@arrival.com committed Jul 27, 2023
2 parents 6d3b2b9 + ebc4ac6 commit 224412f
Showing 8 changed files with 7,628 additions and 29 deletions.

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/sparseml/exporters/transforms/conv_to_qlinearconv.py
@@ -124,7 +124,10 @@ def _transform_match(self, model: ModelProto, match: MatchResult):
model, input_quant, include_target=False
)
bias_scale = input_quantize_params.scale * weight_quantize_params.scale
quantized_bias = quantize_array(bias, bias_scale, 0, numpy.int32)
bias_zero_point = numpy.zeros(bias_scale.shape, dtype=numpy.int32)
quantized_bias = quantize_array(
bias, bias_scale, bias_zero_point, numpy.int32
)
quantized_bias_name = f"{conv_node.name}.bias_quantized"
quantized_bias_initializer = numpy_helper.from_array(
quantized_bias, name=quantized_bias_name
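For reference, the change above makes the bias zero point a vector matching the per-channel bias scale (input scale times per-channel weight scale). A minimal NumPy sketch of that idea, standalone rather than the sparseml helper, with illustrative shapes and values:

import numpy

# Illustrative Conv with 16 output channels (shapes/values are assumptions, not from this diff)
input_scale = numpy.float32(0.02)                                   # per-tensor activation scale
weight_scale = numpy.linspace(0.01, 0.05, 16, dtype=numpy.float32)  # per-channel weight scales
bias = numpy.random.randn(16).astype(numpy.float32)

bias_scale = input_scale * weight_scale                             # one scale per output channel
bias_zero_point = numpy.zeros(bias_scale.shape, dtype=numpy.int32)  # one zero point per channel

# int32 bias, quantized channel by channel
quantized_bias = (numpy.round(bias / bias_scale) + bias_zero_point).astype(numpy.int32)
assert quantized_bias.shape == (16,)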
@@ -97,10 +97,17 @@ def add_quantized_conv_matmul_add_ops(
output_quantize_node=output_quantize_node,
)
model.graph.node.append(qadd_node)

# bias has same scale as future rescale op
rescale_scale = quantized_bias_scale
mul_input_node_name = qadd_node.name

# bias has same scale as future rescale op, unless doing channel-wise Conv
if weight_quantize_params.scale.size > 1 and node.op_type == "Conv":
# channel-wise Conv
rescale_scale = _create_rescale_init(
node, input_quantize_params, weight_quantize_params, reshape=True
)
model.graph.initializer.append(rescale_scale)
else:
rescale_scale = quantized_bias_scale
else:
rescale_scale = _create_rescale_init(
node, input_quantize_params, weight_quantize_params
@@ -261,7 +268,7 @@ def _quantize_bias(
) -> Tuple[TensorProto, TensorProto, TensorProto]:
bias_initializer = numpy_helper.to_array(bias_initializer)
bias_scale = input_quantize_params.scale * weight_quantize_params.scale
bias_zero_point = 0
bias_zero_point = numpy.zeros(bias_scale.shape, dtype=numpy.int32)
quantized_bias = quantize_array(
bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
)
@@ -290,9 +297,11 @@ def _quantize_bias(


def _create_rescale_init(
node, input_quantize_params, weight_quantize_params
node, input_quantize_params, weight_quantize_params, reshape=False
) -> TensorProto:
output_scale = input_quantize_params.scale * weight_quantize_params.scale
if reshape: # for channel-wise Conv
output_scale = output_scale.reshape(1, output_scale.shape[0], 1, 1)
return numpy_helper.from_array(
numpy.asarray(output_scale), name=f"{node.name}_quant.rescale.scale"
)
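The reshape=True path added above exists so that the per-channel rescale initializer broadcasts against the Conv output, which is laid out NCHW. A standalone sketch of that broadcasting, with an assumed channel count of 16:

import numpy

output_scale = numpy.linspace(0.01, 0.05, 16, dtype=numpy.float32)  # assumed per-channel scale

# (16,) -> (1, 16, 1, 1): broadcastable over batch, height, and width in the later Mul node
rescale = output_scale.reshape(1, output_scale.shape[0], 1, 1)

conv_output_int32 = numpy.ones((1, 16, 28, 28), dtype=numpy.int32)  # stand-in integer Conv output
rescaled = conv_output_int32 * rescale                              # channel-wise rescale to float
assert rescaled.shape == (1, 16, 28, 28)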
27 changes: 18 additions & 9 deletions src/sparseml/exporters/transforms/utils/helpers.py
@@ -126,7 +126,10 @@ def assert_node_type(node: NodeProto, op: Union[List[str], Set[str], str]) -> bo


def quantize_array(
array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
array: numpy.ndarray,
scale: numpy.ndarray,
zero_point: numpy.ndarray,
dtype: Any = numpy.uint8,
) -> numpy.ndarray:
try:
import torch # noqa: F401
@@ -139,14 +142,20 @@ def quantize_array(
tensor_dtype = torch.qint32

tensor = torch.Tensor(array.copy()).to(torch.float32)
if isinstance(scale, numpy.ndarray):
scale = scale.item()
if isinstance(zero_point, numpy.ndarray):
zero_point = zero_point.item()

quant_tensor = torch.quantize_per_tensor(
tensor, scale, zero_point, tensor_dtype
)
if scale.size > 1 and zero_point.size > 1: # per-channel quantization
scale = torch.Tensor(scale.copy()).to(torch.float32)
zero_point = torch.Tensor(zero_point.copy()).to(torch.int32)
quant_tensor = torch.quantize_per_channel(
tensor,
scale,
zero_point,
0, # channel axis
tensor_dtype,
)
else: # per-tensor quantization
quant_tensor = torch.quantize_per_tensor(
tensor, scale.item(), zero_point.item(), tensor_dtype
)
return quant_tensor.int_repr().numpy()
except ModuleNotFoundError as err:
_LOGGER.debug(f"Error: {err}. Defaulting to numpy implementation.")
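The updated quantize_array dispatches to torch.quantize_per_channel when the scale and zero point are vectors, and keeps torch.quantize_per_tensor otherwise. A self-contained sketch of the same dispatch; the weight shape and scale values here are made up for illustration:

import numpy
import torch

array = numpy.random.randn(16, 3, 3, 3).astype(numpy.float32)      # e.g. a Conv weight, 16 out-channels
scale = numpy.linspace(0.01, 0.05, 16, dtype=numpy.float32)
zero_point = numpy.zeros(16, dtype=numpy.int32)

tensor = torch.Tensor(array.copy()).to(torch.float32)
if scale.size > 1 and zero_point.size > 1:                          # per-channel quantization
    quant = torch.quantize_per_channel(
        tensor,
        torch.Tensor(scale.copy()).to(torch.float32),
        torch.Tensor(zero_point.copy()).to(torch.int32),
        0,                                                           # channel axis
        torch.qint8,
    )
else:                                                                # per-tensor quantization
    quant = torch.quantize_per_tensor(tensor, scale.item(), zero_point.item(), torch.qint8)

assert quant.int_repr().shape == array.shape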
8 changes: 8 additions & 0 deletions src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -139,6 +139,10 @@ class QConfigProperties:
Default is torch.qint8.
:param activation_bits: number of bits for activations. Default is 8.
:param weight_bits: number of bits for weights. Default is 8.
:param activation_strategy: "tensor" to quantize over the whole activation tensor,
or "channel" to quantize per channel. Default is "tensor".
:param weight_strategy: "tensor" to quantize over the whole weight tensor, or
"channel" to quantize per channel. Default is "tensor".
:param tensorrt: if True sets quantization configuration for compatibility with
explicit quantization as supported by TensorRT 8.2.
"""
@@ -150,6 +154,8 @@ class QConfigProperties:
weight_dtype: torch.dtype = torch.qint8
activation_bits: int = 8
weight_bits: int = 8
activation_strategy: str = "tensor"
weight_strategy: str = "tensor"
activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
tensorrt: bool = False
@@ -552,6 +558,7 @@ def get_qat_qconfig(qproperties: QConfigProperties) -> "torch.quantization.QConf
"""
activation_observer = get_observer(
qproperties.symmetric_activations,
qproperties.activation_strategy,
qproperties.activation_dtype,
qproperties.activation_bits,
qproperties.reduce_range,
@@ -560,6 +567,7 @@ def get_qat_qconfig(qproperties: QConfigProperties) -> "torch.quantization.QConf

weight_observer = get_observer(
qproperties.symmetric_weights,
qproperties.weight_strategy,
qproperties.weight_dtype,
qproperties.weight_bits,
False,
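With the two new dataclass fields, get_qat_qconfig now forwards a separate strategy for activations and weights into get_observer. A hedged usage sketch, assuming the remaining QConfigProperties fields keep their defaults:

from sparseml.pytorch.sparsification.quantization.helpers import (
    QConfigProperties,
    get_qat_qconfig,
)

props = QConfigProperties()
props.weight_strategy = "channel"      # per-channel fake-quant for weights
props.activation_strategy = "tensor"   # activations stay per-tensor

qconfig = get_qat_qconfig(props)       # observers are now built with the chosen strategies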
Expand Up @@ -21,7 +21,7 @@

import torch
from packaging import version
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from torch.nn import Identity


@@ -78,6 +78,12 @@ class QuantizationArgs(BaseModel):
default=False,
description="set True to use symmetric quantization. Default False",
)
strategy: str = Field(
default="tensor",
description=(
"scope of the quantization to be applied. Can be 'tensor' or 'channel'"
),
)
kwargs: Dict[str, Any] = Field(
default_factory=dict,
description=(
@@ -106,12 +112,20 @@ def get_observer(self) -> "torch.quantization.FakeQuantize":
"""
return get_observer(
symmetric=self.symmetric,
strategy=self.strategy,
dtype=torch.qint8,
bits=self.num_bits,
reduce_range=self.kwargs.get("reduce_range", False),
qconfig_kwargs=self.kwargs,
)

@validator("strategy")
def validate_strategy(cls, value):
valid_scopes = ["tensor", "channel"]
if value not in valid_scopes:
raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}")
return value
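A hedged sketch of how the new field behaves, using the QuantizationArgs model defined above (pydantic surfaces the validator's ValueError as a ValidationError, which subclasses ValueError):

args = QuantizationArgs(num_bits=4, symmetric=True, strategy="channel")
observer = args.get_observer()          # observer configured for per-channel fake quantization

try:
    QuantizationArgs(strategy="block")  # anything outside ["tensor", "channel"] is rejected
except ValueError as err:
    print(err)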


class QuantizationScheme(BaseModel):
"""
@@ -276,21 +290,31 @@ def compute_range(dtype: torch.dtype, bits: int):

def get_observer(
symmetric: bool,
strategy: str,
dtype: torch.dtype,
bits: int,
reduce_range: bool,
qconfig_kwargs: Dict[str, Any],
):
qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine
quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits)

observer_cls = torch_quantization.MovingAverageMinMaxObserver
observer_kwargs = dict(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
)

if strategy == "channel":
qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine
observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver
observer_kwargs = dict(
ch_axis=0,
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
)
else: # default to tensor strategy
qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine
observer_cls = torch_quantization.MovingAverageMinMaxObserver
observer_kwargs = dict(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
)
"""
in torch 1.9.1, quant_min and quant_max are not passed to observer:
https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109
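In the "channel" branch above, the observer tracks one min/max pair per output channel, so calculate_qparams returns vectors rather than scalars. A small illustrative sketch (the weight shape is an assumption):

import torch
import torch.quantization as torch_quantization

observer = torch_quantization.MovingAveragePerChannelMinMaxObserver(
    ch_axis=0,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
)
weight = torch.randn(16, 3, 3, 3)                 # e.g. a Conv weight with 16 output channels
observer(weight)                                  # update per-channel running min/max
scale, zero_point = observer.calculate_qparams()
assert scale.shape == (16,) and zero_point.shape == (16,)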
@@ -305,14 +305,14 @@ def test_standard_qrange_zero_points():
def test_standard_qrange_zero_points():
bits = 8

fake_quantize = get_observer(True, torch.qint8, bits, False, {})()
fake_quantize = get_observer(True, "tensor", torch.qint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == -128
assert fake_quantize.quant_max == 127
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 0

fake_quantize = get_observer(True, torch.quint8, bits, False, {})()
fake_quantize = get_observer(True, "tensor", torch.quint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == 0
assert fake_quantize.quant_max == 255
@@ -324,7 +324,7 @@ def test_custom_qrange_zero_points():
# non 8 bits is what makes it a custom qrange
bits = 4

fake_quantize = get_observer(True, torch.qint8, bits, False, {})()
fake_quantize = get_observer(True, "tensor", torch.qint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == -8
assert fake_quantize.quant_max == 7
Expand All @@ -333,7 +333,7 @@ def test_custom_qrange_zero_points():
_, zero_point = fake_quantize.calculate_qparams()
assert zero_point[0] == 0

fake_quantize = get_observer(True, torch.quint8, bits, False, {})()
fake_quantize = get_observer(True, "tensor", torch.quint8, bits, False, {})()
fake_quantize(torch.randn(10, 10))
assert fake_quantize.quant_min == 0
assert fake_quantize.quant_max == 15
95 changes: 94 additions & 1 deletion tests/sparseml/pytorch/test_torch_to_onnx_exporter.py
@@ -21,9 +21,10 @@
import onnxruntime as ort
import pytest
import torch
from packaging import version

from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.onnx.utils.helpers import get_init_by_name
from sparseml.onnx.utils.helpers import get_init_by_name, get_node_by_id
from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.sparsification.quantization import QuantizationModifier
@@ -45,6 +46,19 @@
symmetric: True
"""

CHANNEL_QUANT_RECIPE = """
!QuantizationModifier
start_epoch: 0.0
scheme:
input_activations:
num_bits: 8
symmetric: False
weights:
num_bits: 4
symmetric: True
strategy: "channel"
"""


def _get_4bit_modules(model):
fake_quant_modules = [
@@ -108,6 +122,7 @@ def test_export_4bit_model(tmp_path, model, sample_batch):
ONNXToDeepsparse(use_qlinear_conv=True).export(
onnx_model_new, new_dir / "model.onnx"
)
onnx_model_new = onnx.load(new_dir / "model.onnx")
validate_onnx(str(new_dir / "model.onnx"))

# ensure export didn't modify original model
@@ -150,6 +165,84 @@ def test_export_4bit_model_range(tmp_path):
assert all(conv_range <= 16 for name, conv_range in conv_quant_ranges.items())


@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.0"),
reason="Channel-wise quantization only supported for ONNX opset version 13+",
)
def test_export_per_channel_conv_4bit_model(tmp_path):
model, sample_batch = ConvNet(), torch.randn(1, 3, 28, 28)
new_dir = tmp_path / "new_exporter"
new_dir.mkdir()

manager = ScheduledModifierManager.from_yaml(CHANNEL_QUANT_RECIPE)
manager.apply(model)

new_exporter = TorchToONNX(sample_batch)
new_exporter.export(model, new_dir / "model.onnx")
onnx_model = onnx.load(new_dir / "model.onnx")
ONNXToDeepsparse(use_qlinear_conv=False).export(onnx_model, new_dir / "model.onnx")
onnx_model = onnx.load(new_dir / "model.onnx")
validate_onnx(onnx_model)

add_value = get_init_by_name(
onnx_model, "/seq/conv1/module/Conv_bias_add.bias_quantized"
)
bias = onnx.numpy_helper.to_array(add_value)
mul_value = get_init_by_name(
onnx_model, "/seq/conv1/module/Conv_quant.rescale.scale"
)
rescale = onnx.numpy_helper.to_array(mul_value)
assert bias.shape == rescale.shape == (1, 16, 1, 1)

conv_int_node = get_node_by_id(onnx_model, "/seq/conv1/module/Conv_output_0_quant")
_, _, _, w_zero_point = conv_int_node.input
zero_value = get_init_by_name(onnx_model, w_zero_point)
zero_point = onnx.numpy_helper.to_array(zero_value)
assert zero_point.size == 16 and zero_point.ndim == 1

# this checks all the I/O shapes check out
# don't call session.run() b/c ort doesn't support channel-wise for ConvInteger
ort.InferenceSession(new_dir / "model.onnx")


@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("2.0"),
reason="Channel-wise quantization only supported for ONNX opset version 13+",
)
@pytest.mark.parametrize(
"model,sample_batch",
[
(MLPNet(), torch.randn(8)),
(MLPNet(), torch.randn(10, 8)),
(LinearNet(), torch.randn(8)),
(LinearNet(), torch.randn(10, 8)),
],
)
def test_export_and_load_per_channel_model(tmp_path, model, sample_batch):
new_dir = tmp_path / "new_exporter"
new_dir.mkdir()

manager = ScheduledModifierManager.from_yaml(CHANNEL_QUANT_RECIPE)
manager.apply(model)

new_exporter = TorchToONNX(sample_batch)
new_exporter.export(model, new_dir / "model.onnx")
onnx_model = onnx.load(new_dir / "model.onnx")
ONNXToDeepsparse(use_qlinear_conv=False).export(onnx_model, new_dir / "model.onnx")
onnx_model = onnx.load(new_dir / "model.onnx")
validate_onnx(onnx_model)

session = ort.InferenceSession(new_dir / "model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
session.run(
[output_name],
sample_batch
if isinstance(sample_batch, dict)
else {input_name: sample_batch.numpy()},
)


@pytest.mark.parametrize(
"model,sample_batch",
[
