Convert OPT MatMuls with quantized inputs to MatMulInteger #1585

Merged · 10 commits · Jun 8, 2023
@@ -568,10 +568,195 @@ def _convert_quantizable_gemm(
remove_node_and_params_from_graph(model, gemm_node)


-def _convert_quantizable_matmul(model: ModelProto):
+def _convert_quantizable_matmuls_with_nonquantized_outputs(model: ModelProto):
"""
-    A pass for converting a MatMul into a quantized representation
-    This MatMul is the result of quantizing native torch.matmul using QATMatMul
+    A pass for converting a MatMul with quantized inputs into
+    a MatMulInteger

| Starting with:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| DequantizeLinear DequantizeLinear
| | |
| MatMul
| |
| Add (optional)
| |
| OUTPUT
| We end up converting to:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| | |
| MatMulInteger
| |
    | Add (optional)
| |
| Cast (Int32 --> FP32)
| |
| Mul
| |
| OUTPUT
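    |
    | The final Mul rescales the Cast output by a_scale * b_scale (the
    | product of the two input quantization scales) to recover FP32 values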
"""

conversion_count = 0
matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
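    # collect candidate nodes up front: nodes are appended to and removed
    # from model.graph.node during conversion, so we avoid iterating over a
    # container that is being mutated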
graph = ONNXGraph(model)
for matmul_node in matmul_nodes:
#############
# Matching
#############

input_dequantize_nodes = [
graph.get_node_single_parent(matmul_node, i) for i in range(2)
]

# Make sure these input nodes are DequantizeLinear
if numpy.any(
[
(node is None or node.op_type != "DequantizeLinear")
for node in input_dequantize_nodes
]
):
continue

# Make sure their parents are QuantizeLinear
parents = [
graph.get_node_single_parent(node, 0) for node in input_dequantize_nodes
]
if numpy.any(
[
(parent is None or parent.op_type != "QuantizeLinear")
for parent in parents
]
):
continue

_LOGGER.debug(f"Matched quantizable MatMul: {matmul_node.name}")

# Create MatMulInteger node
node_0, node_1 = input_dequantize_nodes
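        # DequantizeLinear takes (x, x_scale, x_zero_point), so input[0] is
        # the still-quantized tensor and input[2] is its zero point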

input_nodes = [
node_0.input[0], # a
node_1.input[0], # b
node_0.input[2], # a_zero_point
node_1.input[2], # b_zero_point
]
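        # MatMulInteger consumes the raw (u)int8 tensors produced by the
        # QuantizeLinear parents plus their zero points and emits an INT32
        # result; the input scales are applied later by the Cast + Mul below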

matmul_int_op_node = onnx.helper.make_node(
"MatMulInteger",
input_nodes,
[f"{matmul_node.name}_quant_out"],
f"{matmul_node.name}_quant",
)
model.graph.node.append(matmul_int_op_node)

node_0_parameters = get_quantization_params(model, node_0)
node_1_parameters = get_quantization_params(model, node_1)

output_scale = node_0_parameters.scale * node_1_parameters.scale
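        # a_fp @ b_fp == (a_scale * (a_q - a_zp)) @ (b_scale * (b_q - b_zp))
        #             == (a_scale * b_scale) * ((a_q - a_zp) @ (b_q - b_zp))
        # so rescaling the INT32 MatMulInteger output by the product of the
        # two input scales recovers the FP32 result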

has_bias = False

# Check if is followed by Add node (bias)
bias_add_node = graph.get_node_single_child(matmul_node)
if bias_add_node is not None and bias_add_node.op_type == "Add":
bias_initializer = get_init_by_name(
model, bias_add_node.input[1]
) or get_init_by_name(model, bias_add_node.input[0])
if bias_initializer is not None:
                # grab the initializer name before converting to an array,
                # then check that the bias values are finite
                bias_name = bias_initializer.name
                bias_array = numpy_helper.to_array(bias_initializer)
                if numpy.all(numpy.isfinite(bias_array)):
# Create initializer for quantized bias
                    quantized_bias_initializer_name = f"{bias_name}_quant"
has_bias = True

bias_zero_point = 0
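                    # quantize the bias with the accumulator's scale
                    # (a_scale * b_scale) and a zero point of 0 so it can be
                    # added directly to the INT32 output before the
                    # Cast / Mul rescale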
quantized_bias = _quantize_array(
                        bias_array,
output_scale,
bias_zero_point,
dtype=numpy.int32,
)
quantized_bias_initializer = numpy_helper.from_array(
quantized_bias,
name=quantized_bias_initializer_name,
)
model.graph.initializer.append(quantized_bias_initializer)

# Create new Add node for quantized bias
quantized_add_node_name = f"{bias_add_node.name}_quant"
quantized_add_node = onnx.helper.make_node(
"Add",
[matmul_int_op_node.output[0], quantized_bias_initializer_name],
[f"{quantized_add_node_name}_output"],
quantized_add_node_name,
)
model.graph.node.append(quantized_add_node)

# Casting MatMulInteger INT32 output to FP32
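        # (ONNX Mul requires matching input dtypes, so the INT32 accumulator
        # must be cast to FP32 before being multiplied by the FP32 scale)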

cast_node_name = f"{matmul_node.name}_cast"
cast_node_input = (
quantized_add_node.output if has_bias else matmul_int_op_node.output
)
cast_node = onnx.helper.make_node(
"Cast",
cast_node_input,
[f"{cast_node_name}_output"],
cast_node_name,
to=getattr(onnx.TensorProto, "FLOAT"), # get Float32 enum id
)
model.graph.node.append(cast_node)

output_scale_initializer_name = f"{matmul_node.name}.output_scale"
model.graph.initializer.append(
numpy_helper.from_array(
numpy.asarray(output_scale),
name=output_scale_initializer_name,
)
)

mul_node_output = bias_add_node.output if has_bias else matmul_node.output
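        # reuse the original MatMul / Add output name so downstream consumers
        # stay connected to the rescaled result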
mul_node = onnx.helper.make_node(
"Mul",
[cast_node.output[0], output_scale_initializer_name],
mul_node_output,
f"{matmul_node.name}_scale",
)
model.graph.node.append(mul_node)

for node in input_dequantize_nodes:
delete_quant_node(model, node)

# delete original MatMul node
remove_node_and_params_from_graph(model, matmul_node)

# delete original Add node
if has_bias:
remove_node_and_params_from_graph(model, bias_add_node)

conversion_count += 1

if matmul_nodes:
_LOGGER.info(
f"Converted {conversion_count} quantizable MatMul "
"(A8A8 inputs, FP output) ops to MatMulInteger"
)
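    # rebuild the graph view and drop any initializers orphaned by the
    # removed DequantizeLinear / MatMul / Add nodes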
graph = ONNXGraph(model)
graph.delete_unused_initializers()


def _convert_quantizable_matmul_with_quantized_outputs(model: ModelProto):
"""
A pass for converting a MatMul with quantized inputs and outputs into
a QLinearMatMul. This MatMul is the result of quantizing native
torch.matmul using QATMatMul

| Starting with:
| INPUT_0 INPUT_1
@@ -732,9 +917,17 @@ def _convert_quantizable_matmul(model: ModelProto):

if matmul_nodes:
_LOGGER.info(
f"Converted {conversion_count} quantizable MatMul ops " "to QLinearMatMul"
f"Converted {conversion_count} quantizable MatMul with quantized outputs "
"to QLinearMatMul"
)

return conversion_count


def _convert_quantizable_matmul(model: ModelProto):
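    """
    Runs both quantized MatMul conversion passes: MatMuls whose outputs are
    re-quantized become QLinearMatMul; remaining MatMuls with quantized
    inputs but FP32 outputs become MatMulInteger
    """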
_convert_quantizable_matmul_with_quantized_outputs(model)
_convert_quantizable_matmuls_with_nonquantized_outputs(model)


def _add_quantized_conv_matmul_add_ops(
model: ModelProto,