diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py index 0ccb4b10b3..e3253ac7ce 100644 --- a/src/deepsparse/utils/extractor.py +++ b/src/deepsparse/utils/extractor.py @@ -86,19 +86,25 @@ def _dfs_search_reachable_nodes( node_output_name: str, graph_input_names: Set[str], reachable_nodes: Set[NodeProto], + unreachable_nodes: Set[NodeProto] ) -> None: if node_output_name in graph_input_names: return - for node in self.graph.node: - # check output_name first to reduce run time - if node_output_name not in node.output: - continue - if node in reachable_nodes: - continue + + nodes_to_search = [ + node + for node in unreachable_nodes + if node_output_name in node.output and node not in reachable_nodes + ] + + for node in nodes_to_search: reachable_nodes.add(node) + unreachable_nodes.remove(node) + + for node in nodes_to_search: for name in node.input: self._dfs_search_reachable_nodes( - name, graph_input_names, reachable_nodes + name, graph_input_names, reachable_nodes, unreachable_nodes ) def _collect_reachable_nodes( @@ -106,9 +112,13 @@ def _collect_reachable_nodes( input_names: List[str], output_names: List[str], ) -> List[NodeProto]: + input_names = set(input_names) reachable_nodes = set() # type: ignore + unreachable_nodes = set(self.graph.node) # type: ignore for name in output_names: - self._dfs_search_reachable_nodes(name, set(input_names), reachable_nodes) + self._dfs_search_reachable_nodes( + name, input_names, reachable_nodes, unreachable_nodes + ) # needs to be topology sorted. nodes = [n for n in self.graph.node if n in reachable_nodes] return nodes