[Fix] Fix the KV Cache insertion logic for quantized OPT (#1648)
* initial commit

* swapped transpose and QuantizeLinear

* tiptoeing towards the fix

* cleanup, came up with a better idea for a fix

* revert a mistake

* Delete hehe2.py

* producing good-looking graph; let's test in deepsparse

* clean implementation, working in opt

* simplify the PR

* ready for rereview

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
Co-authored-by: Alexandre Marques <alexandre@neuralmagic.com>
3 people committed Jul 19, 2023
1 parent a4cb0c3 commit 6782b03
Showing 1 changed file with 138 additions and 17 deletions.
155 changes: 138 additions & 17 deletions src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py
@@ -29,7 +29,8 @@

_LOGGER = logging.getLogger(__name__)

ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape"]
ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape", "QuantizeLinear"]
ALLOWED_NODES_FOLLOWING_CONCAT = ["Transpose", "QuantizeLinear"]
OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""
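For reference, the two cache name templates above are plain Python format strings; the short illustration below shows how they would expand (the layer index and cache type values are only examples, not taken from the PR):

```python
# Illustration of how the cache I/O name templates expand via str.format
INPUT_CACHE_NAME = "past_key_values.{attention_layer_idx}.{cache_type}"
OUTPUT_CACHE_NAME = "present.{attention_layer_idx}.{cache_type}"

print(INPUT_CACHE_NAME.format(attention_layer_idx=0, cache_type="key"))     # past_key_values.0.key
print(OUTPUT_CACHE_NAME.format(attention_layer_idx=0, cache_type="value"))  # present.0.value
```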

@@ -196,6 +197,7 @@ def transform(self, model: ModelProto) -> ModelProto:
transpose_input=self.transpose_key_input,
multiply_batch_by_num_att_heads=self.multiply_batch_by_num_att_heads, # noqa E501
)

value_concat_node, value_input_tensor, value_output_tensor = create_cache(
model=model,
node=value_matmul,
@@ -220,6 +222,7 @@ def transform(self, model: ModelProto) -> ModelProto:
self.log_match(value_matmul)

# update model with cache inputs and outputs

model.graph.input.extend(inputs_to_add)
model.graph.output.extend(outputs_to_add)

@@ -305,13 +308,15 @@ def create_cache(
cache_input_idx = 3 # QLinearMatMul B matrix is at idx 3, not 1

cache_parent = graph.get_node_single_parent(node, index=cache_input_idx)
if isinstance(cache_parent, NodeProto) and cache_parent.op_type == "Transpose":
# move cache to before a transpose if applicable
# this is due to pytorch operations potentially extracting shape values
# from the key tensor before the transpose is applied
pre_cache_input_id = cache_parent.input[0]
# update concat axis
node = cache_parent

if (
isinstance(cache_parent, NodeProto)
and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT
):
while cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT:
pre_cache_input_id = cache_parent.input[0]
cache_parent = graph.get_node_single_parent(cache_parent, index=0)

else:
pre_cache_input_id = node.input[cache_input_idx]
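The hunk above replaces the old Transpose-only special case: the cache Concat insertion point is now found by walking up through any chain of nodes listed in ALLOWED_NODES_FOLLOWING_CONCAT (Transpose and/or QuantizeLinear). A minimal, standalone sketch of that parent-walk idea follows; it is not code from the PR and uses a plain dict from tensor name to producer node in place of sparseml's ONNXGraph helper:

```python
# Illustrative sketch only: climb upward past allowed pass-through producers
# (Transpose, QuantizeLinear) to find the tensor the cache Concat should attach to.
from typing import Dict, Optional
from onnx import NodeProto

ALLOWED_NODES_FOLLOWING_CONCAT = ["Transpose", "QuantizeLinear"]

def find_pre_cache_tensor(matmul_input_id: str, producers: Dict[str, NodeProto]) -> str:
    tensor_id = matmul_input_id
    parent: Optional[NodeProto] = producers.get(tensor_id)
    while parent is not None and parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT:
        tensor_id = parent.input[0]
        parent = producers.get(tensor_id)
    return tensor_id
```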

@@ -362,10 +367,24 @@ def create_cache(
name=f"concat.{cache_input_name_concat}",
)

for node in model.graph.node:
for input_idx, input_id in enumerate(node.input):
if input_id == pre_cache_input_id and node.name != concat_node.name:
node.input[input_idx] = cache_output_name_concat
for _node in model.graph.node:
for input_idx, input_id in enumerate(_node.input):
if input_id == pre_cache_input_id and _node.name != concat_node.name:
_node.input[input_idx] = cache_output_name_concat

if node.op_type == "MatMulInteger":
quantize_linear = graph.get_node_single_parent(node, cache_input_idx)
quantize_linear_parent = graph.get_node_single_parent(quantize_linear, 0)
if quantize_linear_parent is None:
quantize_linear_parent = concat_node

concat_node = move_quantize_linear_node(
quantize_linear=quantize_linear,
quantize_linear_parent=quantize_linear_parent,
concat=concat_node,
cache_input_idx=cache_input_idx,
graph=graph,
)

graph.add_node(concat_node)

@@ -551,6 +570,97 @@ def transpose_kv_cache_inputs_outputs(
)


def move_quantize_linear_node(
quantize_linear: NodeProto,
quantize_linear_parent: NodeProto,
concat: NodeProto,
cache_input_idx: int,
graph: ONNXGraph,
) -> NodeProto:
"""
Moves a QuantizeLinear node before the `concat` node, so
that the data that arrives from `ConcatNodeParent` to `concat`
is already quantized (see the diagram below). This is required
so that the `concat` node joins the quantized data from the
`ConcatNodeParent` with the quantized kv cache input.
Transforms
```
| ConcatNodeParent
| |
| | Key/Value Cache(uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| ...
| |
| QuantizeLinear
| |
| | ...
| | |
| |
| QLinearMatMul
```
to
```
| ConcatNodeParent
| |
| QuantizeLinear
| |
| | Key/Value Cache (uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| |
| | ...
| | |
| |
| QLinearMatMul
```
:param quantize_linear: The QuantizeLinear node to move.
In reality, this node will be removed and a new node,
that inherits attributes from this node, will be created
in the proper place.
:param quantize_linear_parent: The parent of the QuantizeLinear node.
:param concat: The concat node to move the QuantizeLinear node before.
:param cache_input_idx: The index of the cache input in the concat node.
:param graph: The graph to update.
:return: The updated Concat node.
"""
if quantize_linear.op_type != "QuantizeLinear":
raise ValueError(
f"Expected node: {quantize_linear.name} to have "
f"op_type: QuantizeLinear, but it has op_type: {quantize_linear.op_type}"
)
quantize_linear_child = graph.get_node_single_child(quantize_linear)
if quantize_linear_child.op_type != "MatMulInteger":
raise ValueError(
f"Expected the child of node: {quantize_linear.name} to have "
"op_type: MatMulInteger, but it has "
f"op_type: {quantize_linear_child.op_type}"
)

# remove the dependency on the QuantizeLinear node from its
# neighbouring nodes by connecting output of its parent to its child
quantize_linear_child.input[cache_input_idx] = quantize_linear_parent.output[0]

# get the node that precedes the concat node and does not come from
# the kv cache input, then place the QuantizeLinear node after it
concat_node_parent = graph.get_node_parents(concat)[1]
quantize_linear.input[0] = concat_node_parent.output[0]
concat.input[1] = quantize_linear.output[0]
return concat
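To make the rewiring performed by move_quantize_linear_node concrete, here is a small standalone sketch (not part of the changed file) that builds the three affected nodes with onnx.helper and re-points their inputs the same way; every tensor and node name below is made up for illustration:

```python
# Illustrative sketch only: reproduce the QuantizeLinear re-wiring on toy nodes.
from onnx import helper

concat_parent = helper.make_node("Identity", ["hidden"], ["parent_out"], name="ConcatNodeParent")
concat = helper.make_node("Concat", ["cache_in", "parent_out"], ["concat_out"], name="Concat", axis=2)
quant = helper.make_node(
    "QuantizeLinear", ["concat_out", "scale", "zero_point"], ["quant_out"], name="QuantizeLinear"
)
matmul = helper.make_node(
    "MatMulInteger", ["probs_q", "quant_out", "probs_zp", "zero_point"], ["mm_out"], name="MatMulInteger"
)

# 1) the MatMulInteger now reads the Concat output directly,
#    since the concatenated data will already be quantized
matmul.input[1] = concat.output[0]
# 2) the QuantizeLinear is moved above the Concat: it now quantizes
#    the ConcatNodeParent output instead of the Concat output
quant.input[0] = concat_parent.output[0]
# 3) the Concat's non-cache input becomes the quantized tensor
concat.input[1] = quant.output[0]

print(matmul.input[1], quant.input[0], concat.input[1])
# -> concat_out parent_out quant_out
```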


def _find_key_value_matmul_pairs(
graph: ONNXGraph,
) -> List[Tuple[NodeProto, NodeProto]]:
@@ -685,15 +795,26 @@ def _find_key_matmul_from_value_matmul(
def _value_input_idx(value_matmul: NodeProto, model: ModelProto) -> int:
graph = ONNXGraph(model)
# get idx of matmul that the value node is an input of
if len(value_matmul.input) != 2:
expected_num_inputs = 4 if value_matmul.op_type == "MatMulInteger" else 2

if len(value_matmul.input) != expected_num_inputs:
raise ValueError(
f"Expected value matmul to have 2 inputs, got {len(value_matmul.input)}"
f"Expected value matmul to have {expected_num_inputs} "
f"inputs, got {len(value_matmul.input)}"
)

softmax_input_idx = 0 # default to softmax being on left hand side
for idx, parent in enumerate(graph.get_node_parents(value_matmul)):
if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
softmax_input_idx = idx
break
if isinstance(parent, NodeProto):
# if a parent is a softmax or the parent of value matmul is a direct
# child of a softmax (quantized scenario), then the softmax is the
# idx'th input to the value matmul
if (
parent.op_type == "Softmax"
or graph.get_node_single_parent(parent, 0).op_type == "Softmax"
):
softmax_input_idx = idx
break
return 1 - softmax_input_idx # return index that isn't the softmax
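The updated check above accepts either a direct Softmax parent (float graph) or a parent whose own parent is the Softmax (quantized graph, where a QuantizeLinear sits between the Softmax and the value MatMulInteger). A hedged, standalone sketch of that predicate, again using a plain producer lookup instead of sparseml's ONNXGraph:

```python
# Illustrative sketch only: is this parent of the value matmul on the softmax side,
# either directly or through one intermediate node (e.g. a QuantizeLinear)?
from typing import Dict, Optional
from onnx import NodeProto

def is_on_softmax_side(parent: NodeProto, producers: Dict[str, NodeProto]) -> bool:
    if parent.op_type == "Softmax":
        return True
    grandparent: Optional[NodeProto] = producers.get(parent.input[0]) if parent.input else None
    return grandparent is not None and grandparent.op_type == "Softmax"
```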


