Skip to content

Commit

Permalink
[ONNXGraph] delete_orphaned_node_branches utility function (#1694)
Browse files Browse the repository at this point in the history
* [ONNXGraph] delete_orphaned_nodes utility function

* rename to

* additionally remove appropriate initializers (constants)

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
  • Loading branch information
bfineran and bogunowicz@arrival.com committed Jul 31, 2023
1 parent 3ec2c16 commit d199188
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def swap_nodes_for_input(
"""

graph = ONNXGraph(model)
orphaned_nodes = []
for node in nodes:
child_node = graph.get_node_children(node)[0]

Expand All @@ -172,11 +171,7 @@ def swap_nodes_for_input(
if input_name_child_node == output_to_replace:
graph.update_node_input(child_node, input_name, idx)

orphaned_nodes.extend(graph.find_orphaned_nodes(node))

graph.delete_nodes(orphaned_nodes)
graph.update()
graph.delete_unused_initializers()
graph.delete_orphaned_node_branches()

_LOGGER.info(
f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
Expand Down
42 changes: 42 additions & 0 deletions src/sparseml/onnx/utils/graph_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,48 @@ def sort_nodes_topologically(self):
self._model.graph.ClearField("node")
self._model.graph.node.extend(updated_node_list)

def get_orphaned_nodes(
self,
graph_output_ids: Optional[Iterable[str]] = None,
) -> List[NodeProto]:
"""
:param graph_output_ids: iterable of output ids in graph. if not supplied,
will be read from the model
:return: list of all nodes in the graph that are not inputs to
other nodes or outputs of the graph
"""
if graph_output_ids is None:
graph_output_ids = {output.name for output in self._model.graph.output}

orphaned_nodes = []
for node in self.nodes:
node_is_orphaned = True
# iterate over possible output ids, in practice, there is almost
# always only 1
for out_id in node.output:
if out_id in self._input_id_to_nodes or out_id in graph_output_ids:
node_is_orphaned = False
if node_is_orphaned:
orphaned_nodes.append(node)
return orphaned_nodes

def delete_orphaned_node_branches(self):
"""
Deletes all nodes in the graph that are not inputs to other nodes or outputs of
the graph. Additionally deletes all nodes that would become orphaned
after the node deletion until the graph contains no orphaned nodes
"""
graph_output_ids = {output.name for output in self._model.graph.output}
orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids)

while orphaned_nodes:
# no need to refresh self, delete nodes should update internal graph edges
self.delete_nodes(orphaned_nodes)
self.update()
self.delete_unused_initializers()
# update now orphaned nodes, can only run up to len(nodes) times
orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids)

def _store_node_edges(self, node: NodeProto):
for output_id in node.output:
self._output_id_to_node[output_id] = node
Expand Down

0 comments on commit d199188

Please sign in to comment.