Skip to content

Commit

Permalink
handle optional input in quant topo sort (#7223)
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee authored Apr 3, 2021
1 parent 59b57d8 commit 8d737f9
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/quantization/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,14 @@ def topological_sort(self):
deps_count = [0]*len(self.nodes()) # dependency count of each node
deps_to_nodes = {} # input to node indice
for node_idx, node in enumerate(self.nodes()):
deps_count[node_idx] = len(node.input)
# CANNOT use len(node.input) directly because input can be optional
deps_count[node_idx] = sum(1 for _ in node.input if _ )
for input_name in node.input:
if input_name not in deps_to_nodes:
deps_to_nodes[input_name] = [node_idx]
else:
deps_to_nodes[input_name].append(node_idx)


# initialize sorted_nodes
sorted_nodes = []
for input in itertools.chain(self.initializer(), self.model.graph.input):
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/test/python/quantization/op_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import onnx
import numpy as np
from six import string_types
import onnxruntime
from pathlib import Path
from onnxruntime.quantization import CalibrationDataReader
Expand All @@ -20,8 +21,12 @@ def get_next(self):
def rewind(self):
self.iter_next = iter(self.data_feeds)

def check_op_type_order(testcase, model_path, ops):
model = onnx.load(Path(model_path))
def check_op_type_order(testcase, model_to_check, ops):
if isinstance(model_to_check, string_types):
model = onnx.load(model_to_check)
elif isinstance(model_to_check, onnx.ModelProto):
model = model_to_check

testcase.assertEqual(len(ops), len(model.graph.node), 'op count is not same')
for node_idx, node in enumerate(model.graph.node):
testcase.assertEqual(
Expand Down
77 changes: 77 additions & 0 deletions onnxruntime/test/python/quantization/test_onnx_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import unittest
import onnx
import onnxruntime
import numpy as np
from onnx import helper, TensorProto, numpy_helper
from onnxruntime.quantization.onnx_model import ONNXModel
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order


def generate_input_initializer(tensor_shape, tensor_dtype, input_name):
'''
Helper function to generate initializers for test inputs
'''
tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype)
init = numpy_helper.from_array(tensor, input_name)
return init

class TestONNXModel(unittest.TestCase):
def construct_model(self, model_path):
# (input)
# |
# GRU
# / \
# Conv(1) \
# | \
# Relu Conv(2)
# | |
# \ /
# Add
# |
# (output)
initializers = []
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 2, 8, 8])

# make GRU
initializers.append(generate_input_initializer([2, 24, 12], np.float32, 'W_GRU'))
initializers.append(generate_input_initializer([2, 24, 8], np.float32, 'R_GRU'))
initializers.append(generate_input_initializer([2, 8, 8], np.float32, 'H_GRU'))
gru_node = onnx.helper.make_node(
'GRU',
['input', 'W_GRU', 'R_GRU', '', '', 'H_GRU'],
['GRU_O'],
hidden_size = 8,
direction = 'bidirectional')

initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W1'))
initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W2'))
initializers.append(generate_input_initializer([2], np.float32, 'B1'))
initializers.append(generate_input_initializer([2], np.float32, 'B2'))
conv_node_1 = onnx.helper.make_node('Conv', ['GRU_O', 'W1', 'B1'], ['Conv1_O'], name='Conv1')
conv_node_2 = onnx.helper.make_node('Conv', ['GRU_O', 'W2', 'B2'], ['Conv2_O'], name='Conv2')
relu_node = onnx.helper.make_node('Relu', ['Conv1_O'], ['Relu_O'], name='Relu')
add_node = onnx.helper.make_node('Add', ['Relu_O', 'Conv2_O'], ['output'], name='Add')
graph = helper.make_graph([conv_node_1, relu_node, conv_node_2, gru_node, add_node],
'onnx_model_test', [input], [output], initializer=initializers)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.save(model, model_path)

def test_topo_sort(self):
test_model_path = 'onnx_model_topo_sort.onnx'
self.construct_model(test_model_path)
onnx_model = ONNXModel(onnx.load(test_model_path))
check_op_type_order(self, onnx_model.model, ['Conv', 'Relu', 'Conv', 'GRU', 'Add'])
onnx_model.topological_sort()
check_op_type_order(self, onnx_model.model, ['GRU', 'Conv', 'Conv', 'Relu', 'Add'])

if __name__ == '__main__':
unittest.main()

0 comments on commit 8d737f9

Please sign in to comment.