Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Jul 5, 2024
1 parent 45e2238 commit 1d12193
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/deepsparse/utils/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,30 @@ def _dfs_search_reachable_nodes(
reachable: Set[int],
unreachable: Set[int],
) -> None:
"""
Helper function to find nodes which are connected to an output
:param node_output_name: The name of the output
:param graph_input_names: The names of all inputs of the graph
:param nodes: The list of all nodes of the graph
:param reachable: The set of indexes to reachable nodes in `nodes`
:param unreachable: The set of indexes to unreachable nodes in `nodes`
"""
# finish search at inputs
if node_output_name in graph_input_names:
return

# find nodes connected to this output
nodes_to_search = [
index for index in unreachable if node_output_name in nodes[index].output
]

# add nodes connected to this output to sets
for node_index in nodes_to_search:
reachable.add(node_index)
unreachable.remove(node_index)

# recurse on inputs
for node_index in nodes_to_search:
for name in nodes[node_index].input:
self._dfs_search_reachable_nodes(
Expand All @@ -110,18 +123,17 @@ def _collect_reachable_nodes(
self,
input_names: List[str],
output_names: List[str],
) -> List[NodeProto]:
input_names = set(input_names)
nodes = [node for node in self.graph.node]
reachable = set()
unreachable = set(range(len(nodes)))
) -> list[NodeProto]:
_input_names = set(input_names)
nodes = list(self.graph.node)
reachable: Set[int] = set()
unreachable: Set[int] = set(range(len(nodes)))
for name in output_names:
self._dfs_search_reachable_nodes(
name, input_names, nodes, reachable, unreachable
name, _input_names, nodes, reachable, unreachable
)
# needs to be topologically sorted
reachable = sorted(list(reachable))
nodes = [nodes[node_index] for node_index in reachable]
nodes = [nodes[node_index] for node_index in sorted(reachable)]
return nodes

def _collect_referred_local_functions(
Expand Down

0 comments on commit 1d12193

Please sign in to comment.