Skip to content

Commit

Permalink
swapped transpose and quantizelienar
Browse files Browse the repository at this point in the history
  • Loading branch information
bogunowicz@arrival.com committed Jul 10, 2023
1 parent d3c6038 commit 87d03f9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ def create_cache(
cache_input_idx = 3 # QLinearMatMul B matrix is at idx 3, not 1

cache_parent = graph.get_node_single_parent(node, index=cache_input_idx)
if cache_parent.op_type == "QuantizeLinear":
cache_grandparents = [_node for _node in graph.get_node_parents(cache_parent) if isinstance(_node, NodeProto)]
if len(cache_grandparents) == 1 and cache_grandparents[0].op_type == "Transpose":
graph.swap_nodes(child=cache_parent, parent=cache_grandparents[0])
cache_parent = cache_grandparents[0]

if isinstance(cache_parent, NodeProto) and cache_parent.op_type == "Transpose":
# move cache to before a transpose if applicable
# this is due to pytorch operations potentially extracting shape values
Expand Down
18 changes: 17 additions & 1 deletion src/sparseml/onnx/utils/graph_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
Helper functions to edit ONNX Graphs.
"""

import copy
from collections import defaultdict
from typing import Iterable, List, Optional, Union

Expand Down Expand Up @@ -242,6 +242,22 @@ def delete_unused_initializers(self):
]
) # delete inits that have no edge

def swap_nodes(self, child: NodeProto, parent: NodeProto):
"""
Given a child and parent node, swap the position of the child node
with the parent node. It is assumed that the child node has only one
parent node and the parent node has only one child node.
:param child: One of the nodes to swap
:param parent: Second node to swap
"""

# swap inputs
child.input[0], parent.input[0] = parent.input[0], child.input[0]
# swap outputs
child.output[0], parent.output[0] = parent.output[0], child.output[0]


def find_orphaned_nodes(self, node: NodeProto) -> List[NodeProto]:
"""
Given a node, that is to be removed from the graph, find all nodes that
Expand Down

0 comments on commit 87d03f9

Please sign in to comment.