diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py index 793bfc540d3..3f124ad5762 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py @@ -157,7 +157,6 @@ def swap_nodes_for_input( """ graph = ONNXGraph(model) - orphaned_nodes = [] for node in nodes: child_node = graph.get_node_children(node)[0] @@ -172,11 +171,7 @@ def swap_nodes_for_input( if input_name_child_node == output_to_replace: graph.update_node_input(child_node, input_name, idx) - orphaned_nodes.extend(graph.find_orphaned_nodes(node)) - - graph.delete_nodes(orphaned_nodes) - graph.update() - graph.delete_unused_initializers() + graph.delete_orphaned_node_branches() _LOGGER.info( f"Successfully swapped {len(nodes)} nodes for input '{input_name}'" diff --git a/src/sparseml/onnx/utils/graph_editor.py b/src/sparseml/onnx/utils/graph_editor.py index 1b0c2d23a12..54a8eb55f07 100644 --- a/src/sparseml/onnx/utils/graph_editor.py +++ b/src/sparseml/onnx/utils/graph_editor.py @@ -328,6 +328,48 @@ def sort_nodes_topologically(self): self._model.graph.ClearField("node") self._model.graph.node.extend(updated_node_list) + def get_orphaned_nodes( + self, + graph_output_ids: Optional[Iterable[str]] = None, + ) -> List[NodeProto]: + """ + :param graph_output_ids: iterable of output ids in graph. if not supplied, + will be read from the model + :return: list of all nodes in the graph that are not inputs to + other nodes or outputs of the graph + """ + if graph_output_ids is None: + graph_output_ids = {output.name for output in self._model.graph.output} + + orphaned_nodes = [] + for node in self.nodes: + node_is_orphaned = True + # iterate over possible output ids, in practice, there is almost + # always only 1 + for out_id in node.output: + if out_id in self._input_id_to_nodes or out_id in graph_output_ids: + node_is_orphaned = False + if node_is_orphaned: + orphaned_nodes.append(node) + return orphaned_nodes + + def delete_orphaned_node_branches(self): + """ + Deletes all nodes in the graph that are not inputs to other nodes or outputs of + the graph. Additionally deletes all nodes that would become orphaned + after the node deletion until the graph contains no orphaned nodes + """ + graph_output_ids = {output.name for output in self._model.graph.output} + orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids) + + while orphaned_nodes: + # no need to refresh self, delete nodes should update internal graph edges + self.delete_nodes(orphaned_nodes) + self.update() + self.delete_unused_initializers() + # update now orphaned nodes, can only run up to len(nodes) times + orphaned_nodes = self.get_orphaned_nodes(graph_output_ids=graph_output_ids) + def _store_node_edges(self, node: NodeProto): for output_id in node.output: self._output_id_to_node[output_id] = node