[Fix] Fix the KV Cache insertion logic for quantized OPT #1648

Merged
merged 13 commits on Jul 19, 2023
155 changes: 138 additions & 17 deletions src/sparseml/exporters/transforms/kv_cache/cache_keys_and_values.py
@@ -29,7 +29,8 @@

_LOGGER = logging.getLogger(__name__)

ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape"]
ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape", "QuantizeLinear"]
ALLOWED_NODES_FOLLOWING_CONCAT = ["Transpose", "QuantizeLinear"]
OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""
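
For illustration (not part of the diff): the two templates expand to the external cache tensor names, following the `past_key_values`/`present` naming convention used for Hugging Face transformer exports:

```python
# illustrative only: cache tensor names for attention layer 0
INPUT_CACHE_NAME.format(attention_layer_idx=0, cache_type="key")
# -> "past_key_values.0.key"
OUTPUT_CACHE_NAME.format(attention_layer_idx=0, cache_type="value")
# -> "present.0.value"
```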

@@ -196,6 +197,7 @@ def transform(self, model: ModelProto) -> ModelProto:
transpose_input=self.transpose_key_input,
multiply_batch_by_num_att_heads=self.multiply_batch_by_num_att_heads, # noqa E501
)

value_concat_node, value_input_tensor, value_output_tensor = create_cache(
model=model,
node=value_matmul,
@@ -220,6 +222,7 @@ def transform(self, model: ModelProto) -> ModelProto:
self.log_match(value_matmul)

# update model with cache inputs and outputs

model.graph.input.extend(inputs_to_add)
model.graph.output.extend(outputs_to_add)
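
For reference, the entries extended onto `model.graph.input`/`model.graph.output` here are ordinary `ValueInfoProto` objects. A minimal sketch of what one cache input could look like, with the dtype and symbolic dimension names assumed for illustration:

```python
from onnx import TensorProto, helper

# hypothetical past-key cache input for attention layer 0; UINT8 matches
# the quantized cache shown in the diagrams further down, and the string
# dimensions stay symbolic in the exported model
cache_input = helper.make_tensor_value_info(
    "past_key_values.0.key",
    TensorProto.UINT8,
    ["batch", "num_heads", "past_sequence_len", "head_dim"],
)
```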

@@ -305,13 +308,15 @@ def create_cache(
cache_input_idx = 3 # QLinearMatMul B matrix is at idx 3, not 1

cache_parent = graph.get_node_single_parent(node, index=cache_input_idx)
if isinstance(cache_parent, NodeProto) and cache_parent.op_type == "Transpose":
# move cache to before a transpose if applicable
# this is due to pytorch operations potentially extracting shape values
# from the key tensor before the transpose is applied
pre_cache_input_id = cache_parent.input[0]
# update concat axis
node = cache_parent

if (
isinstance(cache_parent, NodeProto)
and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT
):
# walk up over any chain of allowed nodes; guard against the
# parent walk running off the graph (the parent may be None)
while (
isinstance(cache_parent, NodeProto)
and cache_parent.op_type in ALLOWED_NODES_FOLLOWING_CONCAT
):
pre_cache_input_id = cache_parent.input[0]
cache_parent = graph.get_node_single_parent(cache_parent, index=0)

else:
pre_cache_input_id = node.input[cache_input_idx]
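
For context on the `cache_input_idx = 3` special case above: the quantized matmul operators carry extra scale and zero-point inputs, so the position of the B matrix shifts. Per the ONNX operator spec (this also explains the `expected_num_inputs = 4` check further down):

```python
# ONNX matmul input layouts:
#   MatMul(A, B)                                      -> B at index 1
#   MatMulInteger(A, B, a_zero_point, b_zero_point)   -> B at index 1,
#                                                        4 inputs total
#   QLinearMatMul(a, a_scale, a_zero_point,
#                 b, b_scale, b_zero_point,
#                 y_scale, y_zero_point)              -> B at index 3
```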

@@ -362,10 +367,24 @@ def create_cache(
name=f"concat.{cache_input_name_concat}",
)

for node in model.graph.node:
for input_idx, input_id in enumerate(node.input):
if input_id == pre_cache_input_id and node.name != concat_node.name:
node.input[input_idx] = cache_output_name_concat
for _node in model.graph.node:
for input_idx, input_id in enumerate(_node.input):
if input_id == pre_cache_input_id and _node.name != concat_node.name:
_node.input[input_idx] = cache_output_name_concat

if node.op_type == "MatMulInteger":
quantize_linear = graph.get_node_single_parent(node, cache_input_idx)
quantize_linear_parent = graph.get_node_single_parent(quantize_linear, 0)
if quantize_linear_parent is None:
quantize_linear_parent = concat_node

concat_node = move_quantize_linear_node(
quantize_linear=quantize_linear,
quantize_linear_parent=quantize_linear_parent,
concat=concat_node,
cache_input_idx=cache_input_idx,
graph=graph,
)

graph.add_node(concat_node)

@@ -551,6 +570,97 @@ def transpose_kv_cache_inputs_outputs(
)


def move_quantize_linear_node(
quantize_linear: NodeProto,
quantize_linear_parent: NodeProto,
concat: NodeProto,
cache_input_idx: int,
graph: ONNXGraph,
) -> NodeProto:
"""
Moves a QuantizeLinear node before the `concat` node, so
that the data that arrives from `ConcatNodeParent` to `concat`
is already quantized (see the diagram below). This is required
so that the `concat` node joins the quantized data from the
`ConcatNodeParent` with the quantized kv cache input.

Transforms
```
| ConcatNodeParent
| |
| | Key/Value Cache (uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| ...
| |
| QuantizeLinear
| |
| | ...
| | |
| |
| QLinearMatMul
```
to

```
| ConcatNodeParent
| |
| QuantizeLinear
| |
| | Key/Value Cache (uint8)
| | |
| | ...
| | |
| | |
| |
| Concat
| |
| |
| | ...
| | |
| |
| QLinearMatMul
```
:param quantize_linear: The QuantizeLinear node to move.
The node itself is reused: only its input and output
edges are rewired so that it ends up in the proper place.
:param quantize_linear_parent: The parent of the QuantizeLinear node.
:param concat: The concat node to move the QuantizeLinear node before.
:param cache_input_idx: The index of the cache input in the
quantized matmul node that consumes the concatenated cache.
:param graph: The graph to update.
:return: The updated Concat node.
"""
if quantize_linear.op_type != "QuantizeLinear":
raise ValueError(
f"Expected node: {quantize_linear.name} to have "
f"op_type: QuantizeLinear, but it has op_type: {quantize_linear.op_type}"
)
quantize_linear_child = graph.get_node_single_child(quantize_linear)
if quantize_linear_child.op_type != "MatMulInteger":
raise ValueError(
f"Expected the child of node: {quantize_linear.name} to have "
"op_type: MatMulInteger, but it has "
f"op_type: {quantize_linear_child.op_type}"
)

# remove the dependency on the QuantizeLinear node from its
# neighbouring nodes by connecting the output of its parent to its child
quantize_linear_child.input[cache_input_idx] = quantize_linear_parent.output[0]

# get the node that precedes the concat node and does not come from
# the kv cache input, then place the QuantizeLinear node after it
concat_node_parent = graph.get_node_parents(concat)[1]
quantize_linear.input[0] = concat_node_parent.output[0]
concat.input[1] = quantize_linear.output[0]
return concat
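
A self-contained toy of the rewiring performed above (all tensor and node names are hypothetical; this is a sketch, not the repository's test code):

```python
from onnx import helper

# before: proj_out -> Concat -> QuantizeLinear -> MatMulInteger
proj = helper.make_node("Transpose", ["v_raw"], ["proj_out"], name="proj")
concat = helper.make_node(
    "Concat", ["cache_in", "proj_out"], ["concat_out"], name="concat", axis=2
)
ql = helper.make_node(
    "QuantizeLinear", ["concat_out", "scale", "zp"], ["ql_out"], name="ql"
)
matmul = helper.make_node(
    "MatMulInteger", ["attn_q", "ql_out", "za", "zb"], ["ctx"], name="mm"
)

# 1) bypass QuantizeLinear at its old site: the matmul reads straight
#    from its parent's output (quantize_linear_parent in the function)
matmul.input[1] = concat.output[0]
# 2) re-point QuantizeLinear at the branch feeding Concat's non-cache input
ql.input[0] = proj.output[0]
# 3) Concat now joins the uint8 cache with the freshly quantized branch
concat.input[1] = ql.output[0]
# after: proj_out -> QuantizeLinear -> Concat -> MatMulInteger
```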


def _find_key_value_matmul_pairs(
graph: ONNXGraph,
) -> List[Tuple[NodeProto, NodeProto]]:
@@ -685,15 +795,26 @@ def _find_key_matmul_from_value_matmul(
def _value_input_idx(value_matmul: NodeProto, model: ModelProto) -> int:
graph = ONNXGraph(model)
# get idx of matmul that the value node is an input of
if len(value_matmul.input) != 2:
expected_num_inputs = 4 if value_matmul.op_type == "MatMulInteger" else 2

if len(value_matmul.input) != expected_num_inputs:
raise ValueError(
f"Expected value matmul to have 2 inputs, got {len(value_matmul.input)}"
f"Expected value matmul to have {expected_num_inputs} "
f"inputs, got {len(value_matmul.input)}"
)

softmax_input_idx = 0 # default to softmax being on left hand side
for idx, parent in enumerate(graph.get_node_parents(value_matmul)):
if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
softmax_input_idx = idx
break
if isinstance(parent, NodeProto):
# the softmax may feed the value matmul directly, or through a
# single intermediate node such as a QuantizeLinear (quantized
# scenario); in both cases the idx'th input descends from it
grandparent = graph.get_node_single_parent(parent, 0)
if parent.op_type == "Softmax" or (
isinstance(grandparent, NodeProto)
and grandparent.op_type == "Softmax"
):
softmax_input_idx = idx
break
return 1 - softmax_input_idx # return index that isn't the softmax
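
A toy of the detection above (hypothetical names, sketch only): in the quantized graph the Softmax reaches the value matmul through an intermediate QuantizeLinear, so the walk has to look one parent further up:

```python
from onnx import helper

softmax = helper.make_node("Softmax", ["scores"], ["probs"], name="softmax")
ql = helper.make_node(
    "QuantizeLinear", ["probs", "s", "zp"], ["probs_q"], name="ql"
)
# MatMulInteger(A, B, a_zero_point, b_zero_point): the 4-input case above
value_matmul = helper.make_node(
    "MatMulInteger", ["probs_q", "value_q", "za", "zb"], ["ctx"], name="vmm"
)
# the parent feeding input 0 is ql, whose own parent is the softmax,
# so softmax_input_idx == 0 and the returned value index is 1 - 0 == 1
```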

