[Fix] Fix the KV Cache insertion logic for quantized OPT #1648

Merged: 13 commits, Jul 19, 2023
211 changes: 186 additions & 25 deletions src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py
@@ -29,7 +29,8 @@

_LOGGER = logging.getLogger(__name__)

ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape"]
ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape", "QuantizeLinear"]
ALLOWED_NODES_FOLLOWING_CONCAT = ["Transpose", "QuantizeLinear"]
OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""

@@ -172,15 +173,19 @@ def transform(self, model: ModelProto) -> ModelProto:
# Inject kv cache concatenated with the current keys/values as the output
inputs_to_add = []
outputs_to_add = []
nodes_to_remove = []

# get default int8 type to use if graph is quantized
use_uint8_if_quantized = _use_uint8_if_quantized(graph)

for idx, (key_matmul, value_matmul) in enumerate(key_value_matmul_pairs):

value_input_idx = _value_input_idx(value_matmul, model)

key_concat_node, key_input_tensor, key_output_tensor = create_cache(
(
key_concat_node,
key_input_tensor,
key_output_tensor,
key_nodes_to_remove,
) = create_cache(
model=model,
node=key_matmul,
cache_input_idx=1,
@@ -196,7 +201,12 @@
transpose_input=self.transpose_key_input,
multiply_batch_by_num_att_heads=self.multiply_batch_by_num_att_heads, # noqa E501
)
value_concat_node, value_input_tensor, value_output_tensor = create_cache(
(
value_concat_node,
value_input_tensor,
value_output_tensor,
value_nodes_to_remove,
) = create_cache(
model=model,
node=value_matmul,
cache_input_idx=value_input_idx,
@@ -215,6 +225,7 @@

inputs_to_add.extend([key_input_tensor, value_input_tensor])
outputs_to_add.extend([key_output_tensor, value_output_tensor])
nodes_to_remove.extend(key_nodes_to_remove + value_nodes_to_remove)

self.log_match(key_matmul)
self.log_match(value_matmul)
@@ -223,6 +234,11 @@
model.graph.input.extend(inputs_to_add)
model.graph.output.extend(outputs_to_add)

# update the graph
graph.update()
# delete the nodes that were collected for removal
graph.delete_nodes(nodes_to_remove)

_set_attention_mask_to_dynamic(model)

return model
@@ -240,7 +256,7 @@ def create_cache(
batch_size: int = 1,
multiply_batch_by_num_att_heads: bool = True,
transpose_input: Optional[Tuple[int, int, int, int]] = None,
) -> Tuple[NodeProto, ValueInfoProto, ValueInfoProto]:
) -> Tuple[NodeProto, ValueInfoProto, ValueInfoProto, List[NodeProto]]:
"""
Injects a cache (value or key) into the graph for a given Matmul node.

@@ -262,8 +278,15 @@
before the concat node. If `multiply_batch_by_num_att_heads` is True,
the transpose is applied after the batch size is multiplied by the
number of attention heads.
:return: tuple of concat node to add, cache input to add, and cache output to add,
updates existing nodes in-place
:return: tuple of:
- concat node to add
- cache input to add
- cache output to add
- list of obsolete nodes to remove (while nodes can be added to the
graph on the fly, nodes cannot be removed from the model until the
graph is updated. Since the update may take some time, we aggregate
the nodes to remove and return them so the caller can delete them
later)
"""
CACHE_INPUT_DIMS = [
batch_size,
@@ -305,13 +328,15 @@ def create_cache(
cache_input_idx = 3 # QLinearMatMul B matrix is at idx 3, not 1

cache_parent = graph.get_node_single_parent(node, index=cache_input_idx)
if isinstance(cache_parent, NodeProto) and cache_parent.op_type == "Transpose":
# move cache to before a transpose if applicable
# this is due to pytorch operations potentially extracting shape values
# from the key tensor before the transpose is applied
pre_cache_input_id = cache_parent.input[0]
# update concat axis
node = cache_parent

if (
isinstance(cache_parent, NodeProto)
and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT
):
while cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT:
pre_cache_input_id = cache_parent.input[0]
cache_parent = graph.get_node_single_parent(cache_parent, index=0)

else:
pre_cache_input_id = node.input[cache_input_idx]

@@ -362,14 +387,36 @@ def create_cache(
name=f"concat.{cache_input_name_concat}",
)

for node in model.graph.node:
for input_idx, input_id in enumerate(node.input):
if input_id == pre_cache_input_id and node.name != concat_node.name:
node.input[input_idx] = cache_output_name_concat
for _node in model.graph.node:
for input_idx, input_id in enumerate(_node.input):
if input_id == pre_cache_input_id and _node.name != concat_node.name:
_node.input[input_idx] = cache_output_name_concat

nodes_to_remove = []

if node.op_type == "MatMulInteger":
quantize_linear = graph.get_node_single_parent(node, cache_input_idx)
quantize_linear_parent = graph.get_node_single_parent(quantize_linear, 0)
if quantize_linear_parent is None:
quantize_linear_parent = concat_node

concat_node, old_quantize_linear_node = move_quantize_linear_node(
quantize_linear=quantize_linear,
quantize_linear_parent=quantize_linear_parent,
concat=concat_node,
cache_input_idx=cache_input_idx,
graph=graph,
)
nodes_to_remove.append(old_quantize_linear_node)

graph.add_node(concat_node)

return concat_node, cache_input_info, cache_output_info
return (
concat_node,
cache_input_info,
cache_output_info,
nodes_to_remove,
)
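
For reference, a minimal, self-contained sketch of the "collect now, delete after the graph update" pattern that the fourth return value enables, mirroring the `graph.update()` and `graph.delete_nodes(...)` calls in the `transform` hunk above. The toy Identity graph and the `from sparseml.onnx.utils import ONNXGraph` import path are assumptions, not part of the PR.

```
# Hedged sketch (not part of the PR): nodes whose consumers have been rewired
# are collected first and only deleted after the graph view is refreshed.
from onnx import TensorProto, helper

from sparseml.onnx.utils import ONNXGraph  # assumed import path

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
# "stale" stands in for an old QuantizeLinear whose consumers were rewired away
stale = helper.make_node("Identity", ["x"], ["x_unused"], name="stale")
keep = helper.make_node("Identity", ["x"], ["y"], name="keep")
model = helper.make_model(helper.make_graph([stale, keep], "toy", [x], [y]))

graph = ONNXGraph(model)
nodes_to_remove = [stale]  # aggregated across create_cache() calls
graph.update()  # refresh the graph view first...
graph.delete_nodes(nodes_to_remove)  # ...then drop the stale nodes in one pass
```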


def reshape_kv_cache_inputs_outputs(
@@ -551,6 +598,109 @@ def transpose_kv_cache_inputs_outputs(
)


def move_quantize_linear_node(
quantize_linear: NodeProto,
quantize_linear_parent: NodeProto,
concat: NodeProto,
cache_input_idx: int,
graph: ONNXGraph,
) -> Tuple[NodeProto, NodeProto]:
"""
Moves a QuantizeLinear node before the `concat` node, so
that the data that arrives from `ConcatNodeParent` to `concat`
is already quantized (see the diagram below). This is required
so that the `concat` node joins the quantized data from the
`ConcatNodeParent` with the quantized kv cache input.

Transforms
```
| ConcatNodeParent
| |
| | Key/Value Cache(uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| ...
| |
| QuantizeLinear
| |
| | ...
| | |
| |
| QLinearMatMul
```
to

```
| ConcatNodeParent
| |
| QuantizeLinear
| |
| | Key/Value Cache (uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| |
| | ...
| | |
| |
| QLinearMatMul
```
:param quantize_linear: The QuantizeLinear node to move.
In reality, this node will be removed and a new node
that inherits its attributes will be created
in the proper place.
:param quantize_linear_parent: The parent of the QuantizeLinear node.
:param concat: The concat node to move the QuantizeLinear node before.
:param cache_input_idx: The index of the cache input in the concat node.
:param graph: The graph to update.
:return: The updated Concat and the old QuantizeLinear node.
The old QuantizeLinear node should be removed from the graph (the
deletion is done outside of this function for efficiency).
"""
if quantize_linear.op_type != "QuantizeLinear":
raise ValueError(
f"It is expected that the node: {quantize_linear.name} "
f"has opset: QuantizeLinear, but it has op_type: {quantize_linear.ops_type}"
)
quantize_linear_child = graph.get_node_single_child(quantize_linear)
if quantize_linear_child.op_type != "MatMulInteger":
raise ValueError(
f"It is expected that the node: {quantize_linear.name} "
"has opset: MatMulInteger, but it has "
f"op_type: {quantize_linear_child.op_type}"
)

ql_input, scale, zero_point = quantize_linear.input
new_ql_name = f"{quantize_linear.name}.moved"

# remove the dependency of the model graph on the QuantizeLinear node
# by connecting the output of its parent to its child
quantize_linear_child.input[cache_input_idx] = quantize_linear_parent.output[0]

# get the node that precedes the concat node and does not come from
# the kv cache input
concat_node_parent = graph.get_node_parents(concat)[1]

new_quantize_linear = onnx.helper.make_node(
op_type="QuantizeLinear",
inputs=[concat_node_parent.output[0], scale, zero_point],
outputs=[new_ql_name],
name=new_ql_name,
)
concat.input[1] = new_ql_name
graph.add_node(new_quantize_linear)
return concat, quantize_linear
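
To make the "after" diagram above concrete, here is a minimal sketch (not part of the PR) of the node pattern the transform is meant to produce: the moved QuantizeLinear sits on the `ConcatNodeParent` branch, so `Concat` joins two uint8 tensors and its output can feed `MatMulInteger` directly. All tensor and node names below are illustrative assumptions.

```
# Hedged illustration of the post-transform pattern:
#   ConcatNodeParent -> QuantizeLinear -> Concat <- uint8 kv cache input
#                                           |
#                                      MatMulInteger
from onnx import helper

moved_quantize = helper.make_node(
    "QuantizeLinear",
    inputs=["hidden_states", "key_scale", "key_zero_point"],
    outputs=["hidden_states_quant"],
    name="key.quantize_linear.moved",
)
concat = helper.make_node(
    "Concat",
    inputs=["past_key_values.0.key", "hidden_states_quant"],  # both uint8
    outputs=["present.0.key"],
    axis=2,  # concat along the sequence-length axis (illustrative)
    name="concat.past_key_values.0.key",
)
qk_matmul = helper.make_node(
    "MatMulInteger",  # inputs: A, B, a_zero_point, b_zero_point
    inputs=["query_quant", "present.0.key", "query_zero_point", "key_zero_point"],
    outputs=["attention_scores"],
    name="attention.qk_matmul",
)
```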


def _find_key_value_matmul_pairs(
graph: ONNXGraph,
) -> List[Tuple[NodeProto, NodeProto]]:
@@ -685,15 +835,26 @@ def _find_key_matmul_from_value_matmul(
def _value_input_idx(value_matmul: NodeProto, model: ModelProto) -> int:
graph = ONNXGraph(model)
# get idx of matmul that the value node is an input of
if len(value_matmul.input) != 2:
expected_num_inputs = 4 if value_matmul.op_type == "MatMulInteger" else 2

if len(value_matmul.input) != expected_num_inputs:
raise ValueError(
f"Expected value matmul to have 2 inputs, got {len(value_matmul.input)}"
f"Expected value matmul to have {expected_num_inputs} "
f"inputs, got {len(value_matmul.input)}"
)

softmax_input_idx = 0 # default to softmax being on left hand side
for idx, parent in enumerate(graph.get_node_parents(value_matmul)):
if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
softmax_input_idx = idx
break
if isinstance(parent, NodeProto):
# if the parent is a Softmax, or the parent is a direct child of a
# Softmax (quantized scenario), then the softmax output feeds the
# value matmul at this input index
if (
parent.op_type == "Softmax"
or graph.get_node_single_parent(parent, 0).op_type == "Softmax"
):
softmax_input_idx = idx
break
return 1 - softmax_input_idx # return index that isn't the softmax
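
For illustration, a minimal sketch (not from the PR) of the quantized attention pattern the updated check guards for: the Softmax output is re-quantized before the value matmul, so the value MatMulInteger's direct parent is a QuantizeLinear whose own parent is the Softmax, and the MatMulInteger itself has four inputs (hence `expected_num_inputs == 4`). All names are assumptions.

```
# Hedged sketch of Softmax -> QuantizeLinear -> MatMulInteger (names made up)
from onnx import helper

softmax = helper.make_node(
    "Softmax", inputs=["attention_scores"], outputs=["attention_probs"],
    name="attention.softmax",
)
requantize = helper.make_node(
    "QuantizeLinear",
    inputs=["attention_probs", "probs_scale", "probs_zero_point"],
    outputs=["attention_probs_quant"],
    name="attention.softmax.quantize",
)
value_matmul = helper.make_node(
    "MatMulInteger",  # 4 inputs: A, B, a_zero_point, b_zero_point
    inputs=["attention_probs_quant", "present.0.value",
            "probs_zero_point", "value_zero_point"],
    outputs=["attention_output"],
    name="attention.value_matmul",
)
```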

