Fix bug in onnx perchannel mode with ConvTranspose (#3149)
* Fix bug in onnx perchannel mode with ConvTranspose
* Add perchannel depthwise convtranspose tests
* Remove hardcoded path when getting model

---------

Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle committed Jul 7, 2024
1 parent dd3e0b3 commit 8983229
Showing 4 changed files with 84 additions and 17 deletions.
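Background for the fix (an illustrative sketch, not part of the commit): PyTorch and ONNX store ConvTranspose weights as (in_channels, out_channels // groups, kH, kW), so the output channels that per-channel quantization must track live on axis 1, unlike Conv, where they live on axis 0. The snippet below only demonstrates the layouts:

import torch

conv = torch.nn.Conv2d(10, 20, 3)
print(conv.weight.shape)       # torch.Size([20, 10, 3, 3]) -- output channels on axis 0

conv_t = torch.nn.ConvTranspose2d(10, 20, 3)
print(conv_t.weight.shape)     # torch.Size([10, 20, 3, 3]) -- output channels on axis 1

grouped_t = torch.nn.ConvTranspose2d(10, 20, 3, groups=10)
print(grouped_t.weight.shape)  # torch.Size([10, 2, 3, 3]) -- axis 1 holds out_channels // groups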
3 changes: 2 additions and 1 deletion TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -351,7 +351,8 @@ def _create_quant_info_object_for_param(self, param_name: str):
             tensor_quantizer_params.num_output_channels = param_shape[0]
         else:
             tensor_quantizer_params.axis = self._get_quantization_axis(op)
-            tensor_quantizer_params.num_output_channels = param_shape[quant_info.channelAxis]
+            tensor_quantizer_params.num_output_channels = param_shape[tensor_quantizer_params.axis]
+            quant_info.channelAxis = tensor_quantizer_params.axis
 
         return quant_info, tensor_quantizer_params

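What the one-line bug amounted to (a hedged reconstruction from the diff above): the old code indexed param_shape with quant_info.channelAxis before that field had been set to the axis computed for ConvTranspose, so the channel count came from the wrong dimension and channelAxis itself was never corrected. Assuming channelAxis still carried a default of 0 at that point:

# Shapes for ConvTranspose2d(10, 20, 3, groups=10) exported to ONNX
param_shape = [10, 2, 3, 3]       # (in_channels, out_channels // groups, kH, kW)
axis = 1                          # quantization axis computed for ConvTranspose

stale_channel_axis = 0                       # assumed prior value of quant_info.channelAxis
before = param_shape[stale_channel_axis]     # 10: in_channels, the wrong channel count
after = param_shape[axis]                    # 2: the correct per-channel count
# The fix also writes channelAxis = axis, keeping the Python-side
# tensor_quantizer_params and the quant_info object in agreement.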
55 changes: 41 additions & 14 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -41,6 +41,7 @@
 from typing import Dict, List
 
 import os
+import tempfile
 import torch.nn.functional as F
 from torch import nn as nn
 from torchvision.ops import roi_align
@@ -538,8 +539,8 @@ def __init__(self):
         self.bn1 = torch.nn.BatchNorm2d(10)
         self.relu1 = torch.nn.ReLU()
 
-        self.conv2 = torch.nn.ConvTranspose2d(10, 10, 3)
-        self.bn2 = torch.nn.BatchNorm2d(10)
+        self.conv2 = torch.nn.ConvTranspose2d(10, 20, 3)
+        self.bn2 = torch.nn.BatchNorm2d(20)
 
     # pylint: disable=arguments-differ
     def forward(self, x):
@@ -552,6 +553,13 @@ def forward(self, x):
         x = self.bn2(x)
         return x
 
+class DepthwiseTransposedConvModel(TransposedConvModel):
+
+    def __init__(self):
+        super(DepthwiseTransposedConvModel, self).__init__()
+        self.conv1 = torch.nn.ConvTranspose2d(10, 10, 3, groups=10)
+        self.conv2 = torch.nn.ConvTranspose2d(10, 20, 3, groups=10)
+
 
 class TransposedConvModelWithoutBN(torch.nn.Module):
     """
@@ -1105,21 +1113,40 @@ def multi_output_model():
     return model
 
 def transposed_conv_model():
-    x = torch.randn(10, 10, 4, 4, requires_grad=True)
-    model = TransposedConvModel()
-
-    # Export the model
-    torch.onnx.export(model,                          # model being run
-                      x,                              # model input (or a tuple for multiple inputs)
-                      "./model_transposed_conv.onnx", # where to save the model (can be a file or file-like object)
-                      export_params=True,             # store the trained parameter weights inside the model file
-                      opset_version=12,               # the ONNX version to export the model to
-                      do_constant_folding=True,       # whether to execute constant folding for optimization
-                      input_names=['input'],          # the model's input names
-                      output_names=['output'])
-    model = ONNXModel(load_model('./model_transposed_conv.onnx'))
+    with tempfile.TemporaryDirectory() as save_dir:
+        x = torch.randn(10, 10, 4, 4, requires_grad=True)
+        model = TransposedConvModel()
+        save_path = os.path.join(save_dir, "model_transposed_conv.onnx")
+        # Export the model
+        torch.onnx.export(model,                    # model being run
+                          x,                        # model input (or a tuple for multiple inputs)
+                          save_path,                # where to save the model (can be a file or file-like object)
+                          export_params=True,       # store the trained parameter weights inside the model file
+                          opset_version=12,         # the ONNX version to export the model to
+                          do_constant_folding=True, # whether to execute constant folding for optimization
+                          input_names=['input'],    # the model's input names
+                          output_names=['output'])
+        model = ONNXModel(load_model(save_path))
     return model
 
+def depthwise_transposed_conv_model():
+    with tempfile.TemporaryDirectory() as save_dir:
+        x = torch.randn(10, 10, 4, 4, requires_grad=True)
+        model = DepthwiseTransposedConvModel()
+        save_path = os.path.join(save_dir, "model_transposed_conv.onnx")
+        # Export the model
+        torch.onnx.export(model,                    # model being run
+                          x,                        # model input (or a tuple for multiple inputs)
+                          save_path,                # where to save the model (can be a file or file-like object)
+                          export_params=True,       # store the trained parameter weights inside the model file
+                          opset_version=12,         # the ONNX version to export the model to
+                          do_constant_folding=True, # whether to execute constant folding for optimization
+                          input_names=['input'],    # the model's input names
+                          output_names=['output'])
+        model = ONNXModel(load_model(save_path))
+    return model
+
 
 def transposed_conv_model_without_bn():
     x = torch.randn(10, 10, 4, 4, requires_grad=True)
     model = TransposedConvModelWithoutBN()
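A usage note on the refactor above (a sketch using only names defined in this file): load_model reads the exported file into an in-memory ModelProto before the with-block exits, so the returned ONNXModel outlives the temporary directory and no model_transposed_conv.onnx is left behind in the working directory, which is what the "Remove hardcoded path" bullet in the commit message refers to:

model = depthwise_transposed_conv_model()    # the temp dir is already cleaned up here
assert any(node.op_type == "ConvTranspose" for node in model.graph().node)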
9 changes: 8 additions & 1 deletion TrainingExtensions/onnx/test/python/test_connected_graph.py
@@ -86,11 +86,18 @@ def test_multi_inputs_model(self):
     def test_transposed_conv_model(self):
         if version.parse(torch.__version__) >= version.parse("1.13"):
             model = models_for_tests.transposed_conv_model()
+
+            activations = set()
+            for node in model.graph().node:
+                if node.op_type != "Identity":
+                    activations.add(node.input[0])
+                    activations.add(node.output[0])
+
             conn_graph = ConnectedGraph(model)
             assert len(conn_graph.get_all_ops()) == 5
 
             products = conn_graph.get_all_products()
-            assert len(products) == 12
+            assert len(products) == len(activations) + len(model.graph().initializer)
             assert {'bn1.weight',
                     'bn1.bias'}.issubset({product for product in products})

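(A note on the assertion change, inferred from the diff rather than stated in the commit: deriving the expected product count from the graph's own non-Identity activations and initializers keeps the test valid when the widened TransposedConvModel, or a different export, changes how many tensors the ONNX graph contains.)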
34 changes: 33 additions & 1 deletion TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -54,7 +54,8 @@
 from aimet_onnx.qc_quantize_op import OpMode
 from aimet_onnx.utils import make_dummy_input
 from models.models_for_tests import SingleResidual
-from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model, multi_output_model, custom_add_model, build_lstm_gru_dummy_model
+from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model, multi_output_model, custom_add_model, build_lstm_gru_dummy_model, \
+    transposed_conv_model, depthwise_transposed_conv_model
 
 
 class DummyModel(SingleResidual):
@@ -440,6 +441,37 @@ def dummy_callback(session, args):
             assert encoding.bw == 8
             assert encoding.min != encoding.max
 
+    @pytest.mark.parametrize("model_factory", (transposed_conv_model, depthwise_transposed_conv_model))
+    def test_per_channel_quant_conv_transpose(self, model_factory):
+        model = model_factory()
+        conv_transpose_weight_names = []
+        for node in model.graph().node:
+            if node.op_type == "ConvTranspose":
+                conv_transpose_weight_names.append(node.input[1])
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            sim = QuantizationSimModel(model, use_cuda=False, config_file=get_path_for_per_channel_config(),
+                                       path=tempdir)
+
+            def dummy_callback(session, args):
+                in_tensor = {'input': np.random.rand(10, 10, 4, 4).astype(np.float32)}
+                session.run(None, in_tensor)
+
+            for param_name in sim.param_names:
+                if param_name in conv_transpose_weight_names:
+                    for weight in sim.model.graph().initializer:
+                        if weight.name == param_name:
+                            break
+                    else:
+                        raise RuntimeError(f"Param {param_name} not found in model")
+                    qc_op = sim.qc_quantize_op_dict[param_name]
+                    assert qc_op.quant_info.usePerChannelMode
+                    assert qc_op.quant_info.enabled
+                    assert qc_op.quant_info.channelAxis == 1
+                    assert len(qc_op.encodings) == weight.dims[1]
+
+            sim.compute_encodings(dummy_callback, None)
+
     def test_load_encodings_ptq(self):
         model = single_residual_model().model
         with tempfile.TemporaryDirectory() as tempdir:
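Worked check for the depthwise case asserted above (plain arithmetic from the PyTorch weight layout, not AIMET internals):

in_channels, out_channels, groups = 10, 20, 10               # DepthwiseTransposedConvModel.conv2
weight_dims = (in_channels, out_channels // groups, 3, 3)    # (10, 2, 3, 3) in the ONNX initializer

channel_axis = 1                       # what the fixed quantsim code records for ConvTranspose
assert weight_dims[channel_axis] == 2  # hence len(qc_op.encodings) == 2 for this layer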
