Convert OPT MatMuls with quantized inputs to MatMulInteger #1585

Merged · 10 commits · Jun 8, 2023
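For context, a minimal usage sketch (not part of this PR) of how the conversion pass could be applied to an exported ONNX model; the file paths are placeholders and the direct call to the private helper is illustrative only:

import onnx

# Hypothetical usage: load a quantized OPT export, run the MatMul conversion
# pass added in this PR, and save the result; paths are placeholders
model = onnx.load("opt_quantized.onnx")
_convert_quantizable_matmul(model)  # rewrites quantizable MatMuls (QLinearMatMul, or MatMulInteger + Cast)
onnx.save(model, "opt_quantized_converted.onnx")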
@@ -568,10 +568,119 @@ def _convert_quantizable_gemm(
remove_node_and_params_from_graph(model, gemm_node)


def _convert_quantizable_matmuls_with_nonquantized_outputs(model: ModelProto):
"""
A pass for converting a MatMul with quantized inputs into
a MatMulInteger

| Starting with:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| DequantizeLinear DequantizeLinear
| | |
| MatMul
| |
| OUTPUT
| We end up converting to:
| INPUT_0 INPUT_1
| | |
| QuantizeLinear QuantizeLinear
| | |
| | |
| MatMulInteger
| |
| Cast (Int32 --> FP32)
| |
| OUTPUT
"""

conversion_count = 0
matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
graph = ONNXGraph(model)
for matmul_node in matmul_nodes:
#############
# Matching
#############

input_dequantize_nodes = [
graph.get_node_single_parent(matmul_node, i) for i in range(2)
]

# Make sure these input nodes are DequantizeLinear
if numpy.any(
[
(node is None or node.op_type != "DequantizeLinear")
for node in input_dequantize_nodes
]
):
continue

# Make sure their parents are QuantizeLinear
parents = [
graph.get_node_single_parent(node, 0) for node in input_dequantize_nodes
]
if numpy.any(
[
(parent is None or parent.op_type != "QuantizeLinear")
for parent in parents
]
):
continue

_LOGGER.debug(f"Matched quantizable MatMul: {matmul_node.name}")

# Create MatMulInteger node
node_0, node_1 = input_dequantize_nodes
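# The DequantizeLinear inputs provide what MatMulInteger needs: input 0 is the
# quantized integer tensor produced by the QuantizeLinear parent and input 2 is
# its zero point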

input_nodes = [
node_0.input[0], # a
node_1.input[0], # b
node_0.input[2], # a_zero_point
node_1.input[2], # b_zero_point
]
cast_node_name = f"{matmul_node.name}_cast"
matmul_int_op_node = onnx.helper.make_node(
"MatMulInteger",
input_nodes,
[cast_node_name],
f"{matmul_node.name}_quant",
)
model.graph.node.append(matmul_int_op_node)

# Casting MatMulInteger INT32 output to FP32
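# (MatMulInteger always emits an INT32 tensor; downstream consumers expect FP32)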

cast_node = onnx.helper.make_node(
"Cast",
[matmul_int_op_node.output[0]],
[matmul_node.output[0]],
cast_node_name,
to=getattr(onnx.TensorProto, "FLOAT"), # get Float32 enum id
)
model.graph.node.append(cast_node)

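# Drop the now-bypassed DequantizeLinear nodes; their QuantizeLinear parents
# feed MatMulInteger directly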
for node in input_dequantize_nodes:
delete_quant_node(model, node)

remove_node_and_params_from_graph(model, matmul_node)

conversion_count += 1

if matmul_nodes:
_LOGGER.info(
f"Converted {conversion_count} quantizable MatMul "
"(A8A8 inputs, FP output) ops to MatMulInteger"
)
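# Rebuild the graph view and drop initializers (e.g. scales) orphaned by the
# removed DequantizeLinear nodes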
graph = ONNXGraph(model)
graph.delete_unused_initializers()


def _convert_quantizable_matmul_with_quantized_outputs(model: ModelProto):
"""
A pass for converting a MatMul with quantized inputs and outputs into
a QLinearMatMul. This MatMul is the result of quantizing native
torch.matmul using QATMatMul

| Starting with:
| INPUT_0 INPUT_1
@@ -736,6 +845,12 @@ def _convert_quantizable_matmul(model: ModelProto):
)


def _convert_quantizable_matmul(model: ModelProto):
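"""
Convert quantizable MatMuls, preferring the fully quantized path
(QLinearMatMul); if no such conversions are found, fall back to the
MatMulInteger conversion for MatMuls whose outputs are not quantized
"""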
conversion_count = _convert_quantizable_matmul_with_quantized_outputs(model)
if conversion_count == 0:
_convert_quantizable_matmuls_with_nonquantized_outputs(model)


def _add_quantized_conv_matmul_add_ops(
model: ModelProto,
node: NodeProto,