ONNXToDeepSparse matmul integer conversion patches for whisper support (#1616)

* ONNXToDeepSparse matmul integer conversion patches for whisper support

* fix dangling graph output issue

* unit tests
bfineran committed Jun 13, 2023
1 parent 6c33599 commit ce0aaea
Showing 4 changed files with 105 additions and 54 deletions.
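
These patches extend the existing MatMul-with-bias conversion so it also fires when the trailing bias Add is missing. As orientation for the diffs below, here is a minimal numpy sketch of what the converted block computes: an integer matmul accumulated in INT32, an optional INT32 bias add, a Cast to FP32, and a Mul by `input_scale * weight_scale` (the factor the new `_create_rescale_init` helper builds). All tensors, scales, and zero points are illustrative rather than taken from the commit, and the sketch subtracts zero points inside the integer matmul for simplicity, whereas the transform's docstring folds a zero-point correction into the Add.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, (4, 8)).astype(np.float32)   # activation
w = rng.uniform(-1, 1, (8, 16)).astype(np.float32)  # weight (initializer)
b = rng.uniform(-1, 1, 16).astype(np.float32)       # optional bias

x_scale, x_zp = np.float32(1 / 127), 128             # illustrative quantization params
w_scale, w_zp = np.float32(1 / 127), 127

x_q = np.round(x / x_scale + x_zp).astype(np.uint8)  # quantized activation
w_q = np.round(w / w_scale + w_zp).astype(np.uint8)  # constant uint8 kernel

# MatMulInteger: integer matmul accumulated in INT32, zero points removed
acc = (x_q.astype(np.int32) - x_zp) @ (w_q.astype(np.int32) - w_zp)

# Optional Add: constant bias quantized at the same scale as the later rescale
rescale = x_scale * w_scale
acc = acc + np.round(b / rescale).astype(np.int32)

# Cast (INT32 -> FP32) then Mul by the rescale factor recovers the float result
out = acc.astype(np.float32) * rescale
assert np.allclose(out, x @ w + b, atol=0.15)
```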
@@ -35,6 +35,8 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):
A transform for converting a MatMul with kernel and bias into a
quantized representation
If add or bias initializer does not exist, the bias is skipped
```
| weight (initializer)
| |
@@ -44,9 +46,9 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):
| | |
| Q/Dq optional Transpose
| | |
| MatMul bias (initializer)
| MatMul bias (initializer) (optional)
| | |
| Add
| Add (optional)
```
(where `Q` is QuantizeLinear, and `Dq` is DequantizeLinear)
into
@@ -55,7 +57,7 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):
| |
| MatMulInteger (with constant uint8 kernel)
| |
| Add (constant bias + zero point correction)
| Add (constant bias + zero point correction) (optional)
| |
| Cast (INT32 -> FP32)
| |
@@ -78,16 +80,18 @@ def transform(self, model: ModelProto) -> ModelProto:
optional_node("Transpose"),
],
],
children_ops=[["Add"]],
children_ops=[[optional_node("Add")]],
)
for match in matches:
# NOTE: bias could be either input 0 or 1 of add node
bias_init = graph.get_init_by_name(match.children[0][0].input[1])
if bias_init is None:
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
add_node = match.children[0][0]
bias_init = None
if add_node:
# NOTE: bias could be either input 0 or 1 of add node
# if add does not have a bias initializer,
# still do conversion, but do not fold the bias add to rescale
bias_init = graph.get_init_by_name(match.children[0][0].input[1])
if bias_init is None:
# bias initializer for add not present
continue
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
self.log_match(match)
self._transform_match(graph, model, match, bias_init)
return model
@@ -121,8 +125,8 @@ def _transform_match(
input_quantize_params=input_quantize_params,
weight_quantize_params=weight_quantize_params,
bias_initializer=bias_init,
bias_add_name=add.name,
target_output=add.output[0],
bias_add_name=add.name if add else None,
target_output=add.output[0] if add and bias_init else None,
transpose_weight=opt_transpose is not None,
)

@@ -134,4 +138,6 @@ def _transform_match(
if len(graph.get_node_children(input_quant)) == 1:
self.delete_node_deferred(input_quant)
self.delete_node_deferred(matmul)
self.delete_node_deferred(add)
if bias_init is not None:
# add converted to quantized - delete previous add node
self.delete_node_deferred(add)
@@ -76,9 +76,12 @@ def transform(self, model: ModelProto) -> ModelProto:
op_type="MatMul",
)
for match in matches:
is_parameterized = False
for quantize_linear_parent in [match.parents[0][0], match.parents[1][0]]:
if graph.get_init_by_name(quantize_linear_parent.input[0]):
continue
is_parameterized = True
if is_parameterized:
continue
self.log_match(match)
self._do_transform(model, match)
return model
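
The hunk above replaces a bare `continue` inside the parent scan with an `is_parameterized` flag checked after the loop; the bare `continue` only advanced the inner loop, so a match whose QuantizeLinear parent fed from an initializer could still be transformed. A toy sketch (not repository code) of the corrected control flow:

```python
# Toy stand-in: True marks a QuantizeLinear parent whose input is an initializer
parents_by_match = {"attention_matmul": [False, False], "projection_matmul": [False, True]}

converted = []
for match, parents in parents_by_match.items():
    is_parameterized = False
    for parent_is_initializer in parents:
        if parent_is_initializer:
            is_parameterized = True
    if is_parameterized:
        continue  # skips the whole match, not just one parent
    converted.append(match)

assert converted == ["attention_matmul"]
```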
@@ -35,7 +35,7 @@ def add_quantized_conv_matmul_add_ops(
weight_quantize_node: NodeProto,
input_quantize_params: QuantizationParams,
weight_quantize_params: QuantizationParams,
bias_initializer: TensorProto,
bias_initializer: Optional[TensorProto],
bias_add_name: str,
target_output: str,
transpose_weight: bool,
@@ -49,6 +49,13 @@
Adds new quantized ops to graph, does not perform any checks or deletions
(should be called by the operator main conversion function)
"""
node_output_orig = node.output[0]
if not target_output and (
any(output.name == node_output_orig for output in model.graph.output)
):
# original node output is a graph output, make that the quant block
# output target id
target_output = node_output_orig

# Quantize weights and add to graph
quantized_weight_initializer = _quantize_weight_initializer(
@@ -65,30 +72,43 @@
)
model.graph.node.append(integer_op_node)

# Add bias + zero point correction; quantize bias and add it to graph
(
quantized_bias_initializer,
quantized_bias_scale,
quantize_bias_zero_point,
) = _quantize_bias(
node,
bias_initializer,
input_quantize_params,
weight_quantize_params,
bias_add_name,
)
model.graph.initializer.append(quantized_bias_initializer)
model.graph.initializer.append(quantized_bias_scale)
model.graph.initializer.append(quantize_bias_zero_point)
if bias_initializer is not None:
# Add bias + zero point correction; quantize bias and add it to graph
(
quantized_bias_initializer,
quantized_bias_scale,
quantize_bias_zero_point,
) = _quantize_bias(
node,
bias_initializer,
input_quantize_params,
weight_quantize_params,
bias_add_name,
)
model.graph.initializer.append(quantized_bias_initializer)
model.graph.initializer.append(quantized_bias_scale)
model.graph.initializer.append(quantize_bias_zero_point)

# Create Quantized Bias Add node and add it to graph
qadd_node = _create_qadd_node(
node,
integer_op_output="{}_quant".format(node.output[0]),
quantized_bias_name=quantized_bias_initializer.name,
output_quantize_node=output_quantize_node,
)
model.graph.node.append(qadd_node)

# Create Quantized Bias Add node and add it to graph
qadd_node = _create_qadd_node(
node,
integer_op_output="{}_quant".format(node.output[0]),
quantized_bias_name=quantized_bias_initializer.name,
output_quantize_node=output_quantize_node,
)
model.graph.node.append(qadd_node)
# bias has same scale as future rescale op
rescale_scale = quantized_bias_scale
mul_input_node_name = qadd_node.name
else:
rescale_scale = _create_rescale_init(
node, input_quantize_params, weight_quantize_params
)
model.graph.initializer.append(rescale_scale)
# cast node should come directly after quantize op output instead of add
output_quantize_node = output_quantize_node or integer_op_node
mul_input_node_name = output_quantize_node.name

# create Cast node and add it to graph
cast_node = _create_cast_node(
@@ -100,9 +120,11 @@
# create Mul node for rescale
mul_node = _create_mul_node(
cast_node_output=cast_node.output[0],
quantized_bias_scale_name=quantized_bias_scale.name,
quant_add_name=qadd_node.name,
rescale_scale_name=rescale_scale.name,
input_node_name=mul_input_node_name,
target_output=target_output,
model=model,
node_output_orig=node_output_orig,
)
model.graph.node.append(mul_node)

@@ -111,15 +133,22 @@ def add_quantized_conv_matmul_add_ops(

def _create_mul_node(
cast_node_output: str,
quantized_bias_scale_name: str,
quant_add_name: str,
rescale_scale_name: str,
input_node_name: str,
target_output: str,
model: ModelProto,
node_output_orig: str,
) -> NodeProto:
mul_node_inputs = [
cast_node_output, # a
quantized_bias_scale_name, # b -> rescale factor
rescale_scale_name, # b -> rescale factor
]
mul_node_name = "{}_rescale_mul".format(quant_add_name)
mul_node_name = "{}_rescale_mul".format(input_node_name)
if target_output is None:
target_output = mul_node_name
# since we skip the add conversion,
# update model to point all outputs of matmul/conv to the rescale mul
_update_model_input_id(model, node_output_orig, target_output)
mul_node = onnx.helper.make_node(
"Mul",
mul_node_inputs,
@@ -129,6 +158,13 @@ def _create_mul_node(
return mul_node


def _update_model_input_id(model: ModelProto, old_id: str, new_id: str):
for node in model.graph.node:
for idx, input_name in enumerate(node.input):
if input_name == old_id:
node.input[idx] = new_id


def _create_cast_node(
quant_add_name: str, output_quantize_node: Optional[NodeProto] = None
) -> NodeProto:
@@ -253,6 +289,15 @@ def _quantize_bias(
)


def _create_rescale_init(
node, input_quantize_params, weight_quantize_params
) -> TensorProto:
output_scale = input_quantize_params.scale * weight_quantize_params.scale
return numpy_helper.from_array(
numpy.asarray(output_scale), name=f"{node.name}_quant.rescale.scale"
)


def _quantize_weight_initializer(
node: NodeProto,
weight_quantize_params: QuantizationParams,
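
The commit's "fix dangling graph output issue" bullet is handled by the `target_output` fallback added at the top of `add_quantized_conv_matmul_add_ops` together with the new `_update_model_input_id` helper above: when no bias Add is emitted, either the rescale Mul takes over the original MatMul output name (if that name is a graph output) or every consumer of the old name is re-pointed at the Mul. A toy ONNX graph with illustrative names, exercising the same re-pointing loop:

```python
import onnx
from onnx import TensorProto, helper

# Toy graph: Relu consumes "matmul_out", the output name the conversion replaces
matmul = helper.make_node("MatMul", ["x", "w"], ["matmul_out"], name="matmul")
relu = helper.make_node("Relu", ["matmul_out"], ["y"], name="relu")
graph = helper.make_graph(
    [matmul, relu],
    "toy",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]),
        helper.make_tensor_value_info("w", TensorProto.FLOAT, [4, 4]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)
model = helper.make_model(graph)

# Same loop as _update_model_input_id above: re-point consumers of the old id
old_id, new_id = "matmul_out", "matmul_quant_rescale_mul"
for node in model.graph.node:
    for idx, input_name in enumerate(node.input):
        if input_name == old_id:
            node.input[idx] = new_id

assert list(model.graph.node[1].input) == ["matmul_quant_rescale_mul"]
```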
@@ -146,25 +146,22 @@ def test_without_transpose(onnx_model: onnx.ModelProto):
]


def test_no_bias_changes_nothing(onnx_model: onnx.ModelProto):
def test_matmul_no_bias_converts(onnx_model: onnx.ModelProto):
# remove "bias" initializer and "add" node
assert onnx_model.graph.initializer.pop().name == "bias"
assert onnx_model.graph.node.pop().name == "add"
validate_onnx(onnx_model)

onnx_model = MatMulAddToMatMulIntegerAddCastMul().apply(onnx_model)
validate_onnx(onnx_model)
# NOTE: nothing changes
# converted model should have matmulinteger + rescale mul without bias add
assert [i.name for i in onnx_model.graph.initializer] == [
"x_scale",
"y_scale",
"zero_point",
"weight",
"matmul.weight_quantized",
"matmul_quant.rescale.scale",
]
assert [n.name for n in onnx_model.graph.node] == [
"input_dequant",
"weight_quant",
"weight_dequant",
"transpose",
"matmul",
"matmul_quant",
"matmul_bias_add_quant_cast",
"matmul_quant_rescale_mul",
]
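
The updated test applies the transform in place with `.apply()` and then re-validates the model. A sketch of the same call on a model loaded from disk; the import path and the `check_model` call are assumptions for illustration (the test uses its own `validate_onnx` helper):

```python
import onnx

# Assumed import location; the diff shows only the class name
from sparseml.exporters.transforms import MatMulAddToMatMulIntegerAddCastMul

model = onnx.load("model.onnx")
model = MatMulAddToMatMulIntegerAddCastMul().apply(model)  # same call as in the test
onnx.checker.check_model(model)                            # stand-in for validate_onnx
onnx.save(model, "model-quant-matmul.onnx")
```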
