Skip to content

Commit

Permalink
Fix ONNX export for MobileBERT (#1539)
Browse files Browse the repository at this point in the history
* Account for the possibility of the quantized embedddings to be in int8 format (conversion to uint8 occurs later)

* Set the padding value to match to the zero point accordingly.

* Style and quality fixes
  • Loading branch information
anmarques committed Apr 26, 2023
1 parent b282b80 commit d360f12
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,7 @@ def _propagate_mobilebert_embedding_quantization(model: ModelProto):
continue

embedding_array = numpy_helper.to_array(embedding_initializer)
if embedding_array.dtype != numpy.uint8:
if embedding_array.dtype not in [numpy.uint8, numpy.int8]:
continue

dequant_node = graph.get_node_single_child(gather_node)
Expand Down Expand Up @@ -1805,12 +1805,15 @@ def _propagate_mobilebert_embedding_quantization(model: ModelProto):
# switch position of dequantize node
for branch_node in graph.get_node_children(dequant_node):
if branch_node.op_type == "Slice":
zero_point = graph.get_init_by_name(dequant_node.input[2])
zero_point_array = numpy_helper.to_array(zero_point)
branch_node.input[0] = gather_node.output[0]
pad_node = graph.get_node_single_child(branch_node)
pad_value = graph.get_init_by_name(pad_node.input[2])
pad_value_array = numpy_helper.to_array(pad_value)
pad_value_array = pad_value_array + 128
pad_value_array = pad_value_array.astype(numpy.uint8)
pad_value_array = (
pad_value_array.astype(zero_point_array.dtype) + zero_point_array
)
model.graph.initializer.remove(pad_value)
pad_value = numpy_helper.from_array(
pad_value_array, name=pad_value.name
Expand Down

0 comments on commit d360f12

Please sign in to comment.