Skip to content

Commit

Permalink
cleanup, came up with a better idea for a fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bogunowicz@arrival.com committed Jul 14, 2023
1 parent d2fffbd commit 6fcd3f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,13 @@ def create_cache(

cache_parent = graph.get_node_single_parent(node, index=cache_input_idx)

if isinstance(cache_parent, NodeProto) and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT:
if (
isinstance(cache_parent, NodeProto)
and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT
):
while cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT:
pre_cache_input_id = cache_parent.input[0]
cache_parent = graph.get_node_single_parent(cache_parent, index= 0)
cache_parent = graph.get_node_single_parent(cache_parent, index=0)

else:
pre_cache_input_id = node.input[cache_input_idx]
Expand Down Expand Up @@ -362,19 +365,8 @@ def create_cache(
name=f"concat.{cache_input_name_concat}",
)

for _node in model.graph.node:
for input_idx, input_id in enumerate(node.input):
if input_id == pre_cache_input_id and _node.name != concat_node.name:
_node.input[input_idx] = cache_output_name_concat

graph.add_node(concat_node)

if node.op_type == "MatMulInteger":
node_parent = graph.get_node_single_parent(node, index=cache_input_idx)
node_grandparent = graph.get_node_single_parent(node_parent, index=0)
if node_parent.op_type == "QuantizeLinear" and node_grandparent.op_type == "Transpose":
pass

return concat_node, cache_input_info, cache_output_info


Expand Down
18 changes: 1 addition & 17 deletions src/sparseml/onnx/utils/graph_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
Helper functions to edit ONNX Graphs.
"""
import copy

from collections import defaultdict
from typing import Iterable, List, Optional, Union

Expand Down Expand Up @@ -242,22 +242,6 @@ def delete_unused_initializers(self):
]
) # delete inits that have no edge

def swap_nodes(self, child: NodeProto, parent: NodeProto):
"""
Given a child and parent node, swap the position of the child node
with the parent node. It is assumed that the child node has only one
parent node and the parent node has only one child node.
:param child: One of the nodes to swap
:param parent: Second node to swap
"""

# swap inputs
child.input[0], parent.input[0] = parent.input[0], child.input[0]
# swap outputs
child.output[0], parent.output[0] = parent.output[0], child.output[0]


def find_orphaned_nodes(self, node: NodeProto) -> List[NodeProto]:
"""
Given a node, that is to be removed from the graph, find all nodes that
Expand Down

0 comments on commit 6fcd3f2

Please sign in to comment.