Skip to content

Commit

Permalink
Make bias optional for Conv QAT conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques committed Jul 5, 2023
1 parent 8b622a4 commit 7629f5e
Showing 1 changed file with 37 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
add_quantized_conv_matmul_add_ops,
get_quantization_params,
get_structural_matches,
optional_node,
)
from sparseml.onnx.utils import ONNXGraph

Expand Down Expand Up @@ -66,7 +67,9 @@ class ConvToConvIntegerAddCastMul(OnnxTransform):

def transform(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)
matches = get_structural_matches(

# Nodes with bias
matches_bias = get_structural_matches(
graph,
parent_ops=[
["DequantizeLinear"],
Expand All @@ -78,20 +81,50 @@ def transform(self, model: ModelProto) -> ModelProto:
],
[
# bias should be initializer
INITIALIZER_MATCH
INITIALIZER_MATCH,
],
],
op_type="Conv",
)

# Nodes without bias
matches_no_bias = get_structural_matches(
graph,
parent_ops=[
["DequantizeLinear"],
[
# weight should be initializer
INITIALIZER_MATCH,
"QuantizeLinear",
"DequantizeLinear",
],
],
op_type="Conv",
)

matches = matches_bias
matches_names = [m.node.name for m in matches]
for match in matches_no_bias:
if match.node.name not in matches_names:
matches.append(match)

for match in matches:
self.log_match(match)
self._transform_match(graph, model, match)
return model

def _transform_match(self, graph: ONNXGraph, model: ModelProto, match: MatchResult):
def _transform_match(
self,
graph: ONNXGraph,
model: ModelProto,
match: MatchResult,
):
(input_dequant,) = match.parents[0]
weight_init, weight_quantize_node, weight_dequantize_node = match.parents[1]
(bias_init,) = match.parents[2]
if len(match.parents) == 3:
(bias_init,) = match.parents[2]
else:
bias_init = None

model = add_quantized_conv_matmul_add_ops(
model=model,
Expand Down

0 comments on commit 7629f5e

Please sign in to comment.