From 6fcd3f2a787702b070cf01ad2e19d73fcdb89cfd Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Fri, 14 Jul 2023 18:12:22 +0200 Subject: [PATCH] cleanup, came up with a better idea for a fix --- .../kv_cache/cache_keys_and_values.py | 18 +++++------------- src/sparseml/onnx/utils/graph_editor.py | 18 +----------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py b/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py index 6788723688f..955fd7a71a2 100644 --- a/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py +++ b/src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py @@ -307,10 +307,13 @@ def create_cache( cache_parent = graph.get_node_single_parent(node, index=cache_input_idx) - if isinstance(cache_parent, NodeProto) and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT: + if ( + isinstance(cache_parent, NodeProto) + and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT + ): while cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT: pre_cache_input_id = cache_parent.input[0] - cache_parent = graph.get_node_single_parent(cache_parent, index= 0) + cache_parent = graph.get_node_single_parent(cache_parent, index=0) else: pre_cache_input_id = node.input[cache_input_idx] @@ -362,19 +365,8 @@ def create_cache( name=f"concat.{cache_input_name_concat}", ) - for _node in model.graph.node: - for input_idx, input_id in enumerate(node.input): - if input_id == pre_cache_input_id and _node.name != concat_node.name: - _node.input[input_idx] = cache_output_name_concat - graph.add_node(concat_node) - if node.op_type == "MatMulInteger": - node_parent = graph.get_node_single_parent(node, index=cache_input_idx) - node_grandparent = graph.get_node_single_parent(node_parent, index=0) - if node_parent.op_type == "QuantizeLinear" and node_grandparent.op_type == "Transpose": - pass - return concat_node, cache_input_info, cache_output_info diff --git a/src/sparseml/onnx/utils/graph_editor.py b/src/sparseml/onnx/utils/graph_editor.py index 3c586a90e50..1b0c2d23a12 100644 --- a/src/sparseml/onnx/utils/graph_editor.py +++ b/src/sparseml/onnx/utils/graph_editor.py @@ -15,7 +15,7 @@ """ Helper functions to edit ONNX Graphs. """ -import copy + from collections import defaultdict from typing import Iterable, List, Optional, Union @@ -242,22 +242,6 @@ def delete_unused_initializers(self): ] ) # delete inits that have no edge - def swap_nodes(self, child: NodeProto, parent: NodeProto): - """ - Given a child and parent node, swap the position of the child node - with the parent node. It is assumed that the child node has only one - parent node and the parent node has only one child node. - - :param child: One of the nodes to swap - :param parent: Second node to swap - """ - - # swap inputs - child.input[0], parent.input[0] = parent.input[0], child.input[0] - # swap outputs - child.output[0], parent.output[0] = parent.output[0], child.output[0] - - def find_orphaned_nodes(self, node: NodeProto) -> List[NodeProto]: """ Given a node, that is to be removed from the graph, find all nodes that