Skip to content

Commit

Permalink
use set for unreachable nodes to further cut runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Jun 25, 2024
1 parent 60613fe commit 4e40524
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/deepsparse/utils/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,39 @@ def _dfs_search_reachable_nodes(
node_output_name: str,
graph_input_names: Set[str],
reachable_nodes: Set[NodeProto],
unreachable_nodes: Set[NodeProto]
) -> None:
if node_output_name in graph_input_names:
return
for node in self.graph.node:
# check output_name first to reduce run time
if node_output_name not in node.output:
continue
if node in reachable_nodes:
continue

nodes_to_search = [
node
for node in unreachable_nodes
if node_output_name in node.output and node not in reachable_nodes
]

for node in nodes_to_search:
reachable_nodes.add(node)
unreachable_nodes.remove(node)

for node in nodes_to_search:
for name in node.input:
self._dfs_search_reachable_nodes(
name, graph_input_names, reachable_nodes
name, graph_input_names, reachable_nodes, unreachable_nodes
)

def _collect_reachable_nodes(
self,
input_names: List[str],
output_names: List[str],
) -> List[NodeProto]:
input_names = set(input_names)
reachable_nodes = set() # type: ignore
unreachable_nodes = set(self.graph.node) # type: ignore
for name in output_names:
self._dfs_search_reachable_nodes(name, set(input_names), reachable_nodes)
self._dfs_search_reachable_nodes(
name, input_names, reachable_nodes, unreachable_nodes
)
# needs to be topology sorted.
nodes = [n for n in self.graph.node if n in reachable_nodes]
return nodes
Expand Down

0 comments on commit 4e40524

Please sign in to comment.