From 61a20977ada83285572d0cafa21520527e9a0daf Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Tue, 6 Jun 2023 07:02:40 -0400
Subject: [PATCH] Set bias as optional for convolution folding. Needed for CLIP (#1581)

* Set bias as optional for convolution folding. Needed for CLIP

* Quality fixes

* Merge
---
 .../quantization/quantize_qat_export.py | 136 +++++++++---------
 1 file changed, 68 insertions(+), 68 deletions(-)

diff --git a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
index 8fa2f22e7b9..a96c84c55de 100644
--- a/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
@@ -743,10 +743,10 @@ def _add_quantized_conv_matmul_add_ops(
     weight_quantize_node: NodeProto,
     input_quantize_params: QuantizationParams,
     weight_quantize_params: QuantizationParams,
-    bias_initializer: onnx.TensorProto,
-    bias_add_name: str,
     target_output: str,
     transpose_weight: bool,
+    bias_add_name: str,
+    bias_initializer: Optional[onnx.TensorProto] = None,
     output_quantize_node: Optional[NodeProto] = None,
     output_dequantize_node: Optional[NodeProto] = None,
 ):
@@ -806,65 +806,62 @@ def _add_quantized_conv_matmul_add_ops(
     )
     model.graph.node.append(integer_op_node)
 
+    output_scale = input_quantize_params.scale * weight_quantize_params.scale
+    output_scale_name = "{}_output.scale".format(node.name)
+    model.graph.initializer.append(
+        numpy_helper.from_array(numpy.asarray(output_scale), name=output_scale_name)
+    )
+
+    last_output = integer_op_output
+
     # Add bias + zero point correction
     # quantize bias
-    bias_initializer = numpy_helper.to_array(bias_initializer)
-    bias_scale = input_quantize_params.scale * weight_quantize_params.scale
-    bias_zero_point = 0
-    quantized_bias = _quantize_array(
-        bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
-    )
-    if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
-        # reshape for bias add broadcasting
-        quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+    if bias_initializer is not None:
+        bias_initializer = numpy_helper.to_array(bias_initializer)
 
-    quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
-    quantized_bias_initializer = numpy_helper.from_array(
-        quantized_bias, name=quantized_bias_name
-    )
-    model.graph.initializer.append(quantized_bias_initializer)
-    quantized_bias_scale_name = "{}.scale".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_scale), name=quantized_bias_scale_name
+        bias_zero_point = 0
+        quantized_bias = _quantize_array(
+            bias_initializer, output_scale, bias_zero_point, dtype=numpy.int32
         )
-    )
-    quantized_bias_zero_point_name = "{}.zero_point".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_zero_point, dtype=numpy.uint8),
-            name=quantized_bias_zero_point_name,
+        if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
+            # reshape for bias add broadcasting
+            quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+
+        quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
+        quantized_bias_initializer = numpy_helper.from_array(
+            quantized_bias, name=quantized_bias_name
        )
-    )
+        model.graph.initializer.append(quantized_bias_initializer)
 
-    # get INT32 Add inputs and outputs
-    quant_add_inputs = [
-        integer_op_output,  # MatMul/Conv integer outputs (INT32)
-        quantized_bias_name,  # Quantized bias (INT32)
-    ]
+        # get INT32 Add inputs and outputs
+        quant_add_inputs = [
+            last_output,  # MatMul/Conv integer outputs (INT32)
+            quantized_bias_name,  # Quantized bias (INT32)
+        ]
 
-    quant_add_name = "{}_bias_add_quant".format(node.name)
-    quant_add_output = (
-        output_quantize_node.output[0]
-        if output_quantize_node
-        else f"{quant_add_name}_output"
-    )
+        quant_add_name = "{}_bias_add_quant".format(node.name)
+        quant_add_output = (
+            output_quantize_node.output[0]
+            if output_quantize_node
+            else f"{quant_add_name}_output"
+        )
 
-    # create Add node and add it to graph
-    qadd_node = onnx.helper.make_node(
-        "Add",
-        quant_add_inputs,
-        [quant_add_output],
-        quant_add_name,
-    )
-    model.graph.node.append(qadd_node)
+        # create Add node and add it to graph
+        qadd_node = onnx.helper.make_node(
+            "Add",
+            quant_add_inputs,
+            [quant_add_output],
+            quant_add_name,
+        )
+        model.graph.node.append(qadd_node)
+        last_output = quant_add_output
 
     # create Cast node and add it to graph
-    cast_node_name = "{}_cast".format(quant_add_name)
-    cast_node_output = "{}_cast".format(quant_add_output)
+    cast_node_name = "{}_cast".format(node.name)
+    cast_node_output = "{}_output".format(cast_node_name)
     cast_node = onnx.helper.make_node(
         "Cast",
-        [quant_add_output],
+        [last_output],
         [cast_node_output],
         cast_node_name,
         to=getattr(onnx.TensorProto, "FLOAT"),  # get Float32 enum id
@@ -874,9 +871,9 @@ def _add_quantized_conv_matmul_add_ops(
     # create Mul node for rescale
     mul_node_inputs = [
         cast_node_output,  # a
-        quantized_bias_scale_name,  # b -> rescale factor
+        output_scale_name,  # b -> rescale factor
     ]
-    mul_node_name = "{}_rescale_mul".format(quant_add_name)
+    mul_node_name = "{}_rescale_mul".format(cast_node_name)
     mul_node = onnx.helper.make_node(
         "Mul",
         mul_node_inputs,
@@ -979,10 +976,10 @@ def _convert_quantizable_gemm_no_activations(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(gemm_node.name),
             target_output=gemm_node.output[0],
             transpose_weight=transpose_weight,
+            bias_add_name="{}_bias_add".format(gemm_node.name),
+            bias_initializer=bias_initializer,
         )
 
         # Cleanup
@@ -1108,14 +1105,14 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name=bias_add_node.name,
             target_output=(
                 output_dequantize_node.output[0]
                 if output_dequantize_node
                 else bias_add_node.output[0]
             ),
             transpose_weight=True,
+            bias_add_name=bias_add_node.name,
+            bias_initializer=bias_initializer,
             output_quantize_node=output_quantize_node,
             output_dequantize_node=output_dequantize_node,
         )
@@ -1164,7 +1161,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     |            |               |
     |     DequantizeLinear       |
     |            |               |
-    |          Conv (with bias)
+    |          Conv (with optional bias)
     |                |
     |             OUTPUT
     |   We end up converting to:
@@ -1174,7 +1171,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     |            |
     |     ConvInteger (with constant uint8 kernel)
     |            |
-    |     Add (constant bias + zero point correction)
+    |     Add (optional, constant bias + zero point correction)
     |            |
     |     Cast (INT32 -> FP32)
     |            |
@@ -1187,10 +1184,10 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     conv_nodes = [n for n in model.graph.node if n.op_type in ["Conv"]]
     orig_conv_weight_name_to_node_ids = defaultdict(list)
     for conv_node in conv_nodes:
-        if len(conv_node.input) != 3:
-            # this function currently only converts Conv nodes with bias param
-            # (i.e. from folded batch norm value)
-            continue
+        # if len(conv_node.input) != 3:
+        #     # this function currently only converts Conv nodes with bias param
+        #     # (i.e. from folded batch norm value)
+        #     continue
 
         graph = ONNXGraph(model)
 
@@ -1226,12 +1223,15 @@ def _convert_quantizable_conv_integer(model: ModelProto):
         if input_quantize_node.op_type != "DequantizeLinear":
             continue
 
-        bias_initializer = graph.get_init_by_name(conv_node.input[2])
-        if bias_initializer is None:
-            _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}")
-            continue
+        if len(conv_node.input) == 3:
+            bias_initializer = graph.get_init_by_name(conv_node.input[2])
+        else:
+            bias_initializer = None
 
-        _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
+        if bias_initializer is None:
+            _LOGGER.debug(f"Matched quantizable Conv weight: {conv_node.name}")
+        else:
+            _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
 
         # Conversion
         _add_quantized_conv_matmul_add_ops(
@@ -1241,10 +1241,10 @@ def _add_quantized_conv_matmul_add_ops(
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(conv_node.name),
             target_output=conv_node.output[0],
             transpose_weight=False,
+            bias_add_name="{}_bias_add".format(conv_node.name),
+            bias_initializer=bias_initializer,
         )
         orig_conv_weight_name_to_node_ids[input_quantize_node.input[0]].append(
             "{}_quant".format(conv_node.output[0])
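
Note (illustrative sketch, not part of the patch): the converted graph produced by
_add_quantized_conv_matmul_add_ops runs the MatMulInteger/ConvInteger in INT32,
optionally adds a bias quantized with scale = input_scale * weight_scale and zero
point 0, then casts the accumulator to FP32 and multiplies it by that same scale.
The numpy sketch below mirrors that arithmetic for a MatMul with the bias optional,
as introduced by this change; every name in it (quantized_matmul_add, _quantize) is
local to the example and not taken from the patched file.

import numpy


def _quantize(values, scale, zero_point, dtype):
    # roughly mirrors QuantizeLinear / the patch's _quantize_array: scale, shift, cast
    return (numpy.round(values / scale) + zero_point).astype(dtype)


def quantized_matmul_add(x_fp32, w_fp32, bias_fp32=None):
    # per-tensor quantization parameters, as a QuantizeLinear node would produce
    x_scale, x_zero_point = numpy.abs(x_fp32).max() / 127.0, 128
    w_scale, w_zero_point = numpy.abs(w_fp32).max() / 127.0, 128
    x_quant = _quantize(x_fp32, x_scale, x_zero_point, numpy.uint8)
    w_quant = _quantize(w_fp32, w_scale, w_zero_point, numpy.uint8)

    # MatMulInteger/ConvInteger: subtract zero points, accumulate in INT32
    accumulator = (x_quant.astype(numpy.int32) - x_zero_point) @ (
        w_quant.astype(numpy.int32) - w_zero_point
    )

    # single rescale factor shared by the output and the optional bias
    output_scale = x_scale * w_scale

    if bias_fp32 is not None:
        # bias quantized with scale = input_scale * weight_scale and zero point 0,
        # so it adds directly onto the INT32 accumulator (the optional Add node)
        accumulator = accumulator + _quantize(bias_fp32, output_scale, 0, numpy.int32)

    # Cast (INT32 -> FP32) followed by Mul with the rescale factor
    return accumulator.astype(numpy.float32) * output_scale


if __name__ == "__main__":
    rng = numpy.random.default_rng(0)
    x = rng.standard_normal((2, 8)).astype(numpy.float32)
    w = rng.standard_normal((8, 4)).astype(numpy.float32)
    b = rng.standard_normal(4).astype(numpy.float32)
    # with and without bias, matching the now-optional bias_initializer
    print(numpy.max(numpy.abs(quantized_matmul_add(x, w, b) - (x @ w + b))))
    print(numpy.max(numpy.abs(quantized_matmul_add(x, w) - x @ w)))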