Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNXGraph] delete_orphaned_node_branches utility function #1694

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def swap_nodes_for_input(
"""

graph = ONNXGraph(model)
orphaned_nodes = []
for node in nodes:
child_node = graph.get_node_children(node)[0]

Expand All @@ -172,11 +171,7 @@ def swap_nodes_for_input(
if input_name_child_node == output_to_replace:
graph.update_node_input(child_node, input_name, idx)

orphaned_nodes.extend(graph.find_orphaned_nodes(node))

graph.delete_nodes(orphaned_nodes)
graph.update()
graph.delete_unused_initializers()
graph.delete_orphaned_node_branches()

_LOGGER.info(
f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
Expand Down
42 changes: 42 additions & 0 deletions src/sparseml/onnx/utils/graph_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,48 @@ def sort_nodes_topologically(self):
self._model.graph.ClearField("node")
self._model.graph.node.extend(updated_node_list)

def get_orphaned_nodes(
self,
graph_output_ids: Optional[Iterable[str]] = None,
) -> List[NodeProto]:
"""
:param graph_output_ids: iterable of output ids in graph. if not supplied,
will be read from the model
:return: list of all nodes in the graph that are not inputs to
other nodes or outputs of the graph
"""
if graph_output_ids is None:
graph_output_ids = {output.name for output in self._model.graph.output}

orphaned_nodes = []
for node in self.nodes:
node_is_orphaned = True
# iterate over possible output ids, in practice, there is almost
# always only 1
for out_id in node.output:
if out_id in self._input_id_to_nodes or out_id in graph_output_ids:
node_is_orphaned = False
if node_is_orphaned:
orphaned_nodes.append(node)
return orphaned_nodes

def delete_orphaned_node_branches(self):
"""
Deletes all nodes in the graph that are not inputs to other nodes or outputs of
the graph. Additionally deletes all nodes that would become orphaned
after the node deletion until the graph contains no orphaned nodes
"""
graph_output_ids = {output.name for output in self._model.graph.output}
orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids)

while orphaned_nodes:
# no need to refresh self, delete nodes should update internal graph edges
self.delete_nodes(orphaned_nodes)
self.update()
self.delete_unused_initializers()
# update now orphaned nodes, can only run up to len(nodes) times
orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids)

def _store_node_edges(self, node: NodeProto):
for output_id in node.output:
self._output_id_to_node[output_id] = node
Expand Down
Loading