diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
index a96c84c55de..fc67b88d2df 100644
--- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
@@ -568,10 +568,195 @@ def _convert_quantizable_gemm(
     remove_node_and_params_from_graph(model, gemm_node)
 
 
-def _convert_quantizable_matmul(model: ModelProto):
+def _convert_quantizable_matmuls_with_nonquantized_outputs(model: ModelProto):
     """
-    A pass for converting a MatMul into a quantized representation
-    This MatMul is the result of quantizing native torch.matmul using QATMatMul
+    A pass for converting a MatMul with quantized inputs into
+    a MatMulInteger
+
+    | Starting with:
+    |          INPUT_0           INPUT_1
+    |             |                  |
+    |     QuantizeLinear     QuantizeLinear
+    |             |                  |
+    |     DequantizeLinear   DequantizeLinear
+    |                  |      |
+    |                   MatMul
+    |                     |
+    |                Add (optional)
+    |                     |
+    |                  OUTPUT
+    | We end up converting to:
+    |          INPUT_0           INPUT_1
+    |             |                  |
+    |     QuantizeLinear     QuantizeLinear
+    |             |                  |
+    |             |                  |
+    |              MatMulInteger
+    |                     |
+    |                Add (Optional)
+    |                     |
+    |          Cast (Int32 --> FP32)
+    |                     |
+    |                    Mul
+    |                     |
+    |                  OUTPUT
+    """
+
+    conversion_count = 0
+    matmul_nodes = [n for n in model.graph.node if n.op_type in ["MatMul"]]
+    graph = ONNXGraph(model)
+    for matmul_node in matmul_nodes:
+        #############
+        # Matching
+        #############
+
+        input_dequantize_nodes = [
+            graph.get_node_single_parent(matmul_node, i) for i in range(2)
+        ]
+
+        # Make sure these input nodes are DequantizeLinear
+        if numpy.any(
+            [
+                (node is None or node.op_type != "DequantizeLinear")
+                for node in input_dequantize_nodes
+            ]
+        ):
+            continue
+
+        # Make sure their parents are QuantizeLinear
+        parents = [
+            graph.get_node_single_parent(node, 0) for node in input_dequantize_nodes
+        ]
+        if numpy.any(
+            [
+                (parent is None or parent.op_type != "QuantizeLinear")
+                for parent in parents
+            ]
+        ):
+            continue
+
+        _LOGGER.debug(f"Matched quantizable MatMul: {matmul_node.name}")
+
+        # Create MatMulInteger node
+        node_0, node_1 = input_dequantize_nodes
+
+        input_nodes = [
+            node_0.input[0],  # a
+            node_1.input[0],  # b
+            node_0.input[2],  # a_zero_point
+            node_1.input[2],  # b_zero_point
+        ]
+
+        matmul_int_op_node = onnx.helper.make_node(
+            "MatMulInteger",
+            input_nodes,
+            [f"{matmul_node.name}_quant_out"],
+            f"{matmul_node.name}_quant",
+        )
+        model.graph.node.append(matmul_int_op_node)
+
+        node_0_parameters = get_quantization_params(model, node_0)
+        node_1_parameters = get_quantization_params(model, node_1)
+
+        output_scale = node_0_parameters.scale * node_1_parameters.scale
+
+        has_bias = False
+
+        # Check if the MatMul is followed by an Add node (bias)
+        bias_add_node = graph.get_node_single_child(matmul_node)
+        if bias_add_node is not None and bias_add_node.op_type == "Add":
+            bias_initializer = get_init_by_name(
+                model, bias_add_node.input[1]
+            ) or get_init_by_name(model, bias_add_node.input[0])
+            if bias_initializer is not None:
+                # name the quantized bias after the original initializer before
+                # converting the proto to a numpy array (arrays have no .name)
+                quantized_bias_initializer_name = f"{bias_initializer.name}_quant"
+                # check that the bias values are finite before quantizing
+                bias_initializer = numpy_helper.to_array(bias_initializer)
+                if numpy.all(numpy.isfinite(bias_initializer)):
+                    has_bias = True
+
+                    bias_zero_point = 0
+                    quantized_bias = _quantize_array(
+                        bias_initializer,
+                        output_scale,
+                        bias_zero_point,
+                        dtype=numpy.int32,
+                    )
+                    quantized_bias_initializer = numpy_helper.from_array(
+                        quantized_bias,
+                        name=quantized_bias_initializer_name,
+                    )
+                    model.graph.initializer.append(quantized_bias_initializer)
+
+                    # Create new Add node for quantized bias
+                    quantized_add_node_name = f"{bias_add_node.name}_quant"
+                    quantized_add_node = onnx.helper.make_node(
+                        "Add",
+                        [matmul_int_op_node.output[0], quantized_bias_initializer_name],
+                        [f"{quantized_add_node_name}_output"],
+                        quantized_add_node_name,
+                    )
+                    model.graph.node.append(quantized_add_node)
+
+        # Casting MatMulInteger INT32 output to FP32
+
+        cast_node_name = f"{matmul_node.name}_cast"
+        cast_node_input = (
+            quantized_add_node.output if has_bias else matmul_int_op_node.output
+        )
+        cast_node = onnx.helper.make_node(
+            "Cast",
+            cast_node_input,
+            [f"{cast_node_name}_output"],
+            cast_node_name,
+            to=getattr(onnx.TensorProto, "FLOAT"),  # get Float32 enum id
+        )
+        model.graph.node.append(cast_node)
+
+        output_scale_initializer_name = f"{matmul_node.name}.output_scale"
+        model.graph.initializer.append(
+            numpy_helper.from_array(
+                numpy.asarray(output_scale),
+                name=output_scale_initializer_name,
+            )
+        )
+
+        mul_node_output = bias_add_node.output if has_bias else matmul_node.output
+        mul_node = onnx.helper.make_node(
+            "Mul",
+            [cast_node.output[0], output_scale_initializer_name],
+            mul_node_output,
+            f"{matmul_node.name}_scale",
+        )
+        model.graph.node.append(mul_node)
+
+        for node in input_dequantize_nodes:
+            delete_quant_node(model, node)
+
+        # delete original MatMul node
+        remove_node_and_params_from_graph(model, matmul_node)
+
+        # delete original Add node
+        if has_bias:
+            remove_node_and_params_from_graph(model, bias_add_node)
+
+        conversion_count += 1
+
+    if matmul_nodes:
+        _LOGGER.info(
+            f"Converted {conversion_count} quantizable MatMul "
+            "(A8A8 inputs, FP output) ops to MatMulInteger"
+        )
+        graph = ONNXGraph(model)
+        graph.delete_unused_initializers()
+
+
+def _convert_quantizable_matmul_with_quantized_outputs(model: ModelProto):
+    """
+    A pass for converting a MatMul with quantized inputs and outputs into
+    a QLinearMatMul. This MatMul is the result of quantizing native
+    torch.matmul using QATMatMul
 
     | Starting with:
     |          INPUT_0           INPUT_1
@@ -732,9 +917,17 @@ def _convert_quantizable_matmul(model: ModelProto):
 
     if matmul_nodes:
         _LOGGER.info(
-            f"Converted {conversion_count} quantizable MatMul ops " "to QLinearMatMul"
+            f"Converted {conversion_count} quantizable MatMul with quantized outputs "
+            "to QLinearMatMul"
         )
 
+    return conversion_count
+
+
+def _convert_quantizable_matmul(model: ModelProto):
+    _convert_quantizable_matmul_with_quantized_outputs(model)
+    _convert_quantizable_matmuls_with_nonquantized_outputs(model)
+
 
 def _add_quantized_conv_matmul_add_ops(
     model: ModelProto,
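
The Cast and Mul appended by the new pass follow directly from linear quantization arithmetic: with real ~= scale * (quant - zero_point), the product of the two dequantized inputs equals scale_a * scale_b times the integer product that MatMulInteger computes, and a bias can join the int32 accumulator once it is re-quantized with that same output scale and a zero point of 0. A minimal NumPy sketch of that equivalence, using made-up quantization parameters rather than anything from the patch:

import numpy as np

# made-up (hypothetical) quantization parameters for the two MatMul inputs
s_a, zp_a = 0.02, 3      # scale / zero point of INPUT_0
s_b, zp_b = 0.05, 128    # scale / zero point of INPUT_1

rng = np.random.default_rng(0)
a_q = rng.integers(0, 256, size=(4, 8), dtype=np.uint8)
b_q = rng.integers(0, 256, size=(8, 5), dtype=np.uint8)
bias = rng.standard_normal(5).astype(np.float32)

# float reference: DequantizeLinear -> MatMul -> Add
a_fp = s_a * (a_q.astype(np.float32) - zp_a)
b_fp = s_b * (b_q.astype(np.float32) - zp_b)
reference = a_fp @ b_fp + bias

# integer path mirroring the converted graph:
# MatMulInteger -> Add (int32 bias) -> Cast -> Mul (output_scale)
output_scale = s_a * s_b
int32_out = (a_q.astype(np.int32) - zp_a) @ (b_q.astype(np.int32) - zp_b)
bias_q = np.round(bias / output_scale).astype(np.int32)  # zero point of 0
converted = (int32_out + bias_q).astype(np.float32) * output_scale

# identical up to the rounding error introduced by quantizing the bias
assert np.allclose(reference, converted, atol=output_scale)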
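
Since _convert_quantizable_matmul now just dispatches to the two passes, existing callers of the QAT export conversion keep working unchanged. As a rough sketch only (the function is module-private, and the file names below are hypothetical), the combined pass can be exercised on a QAT-exported ONNX model like this:

import onnx

from sparseml.pytorch.sparsification.quantization.quantize_qat_export import (
    _convert_quantizable_matmul,
)

model = onnx.load("model_qat.onnx")  # hypothetical QAT-exported model
_convert_quantizable_matmul(model)   # QLinearMatMul pass, then MatMulInteger pass
onnx.save(model, "model_qat_converted.onnx")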