Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLAMA] KV Cache Injection #1709

Merged
merged 7 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sparseml/exporters/transforms/kv_cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
from .transforms_base import *
from .transforms_opt import *
from .transforms_codegen import *
from .transforms_llama import *
from .configs import *
20 changes: 19 additions & 1 deletion src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
AdditionalTransformsCodeGen,
)
from sparseml.exporters.transforms.kv_cache.transforms_llama import (
AdditionalTransformsLLAMA,
)
from sparseml.exporters.transforms.kv_cache.transforms_opt import (
AdditionalTransformsOPT,
)
Expand Down Expand Up @@ -112,10 +115,25 @@ class Config:
multiply_batch_by_num_att_heads=True,
)

LLAMA_CONFIG = KeyValueCacheConfig(
model_name="llama",
additional_transforms=AdditionalTransformsLLAMA,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=(0, 2, 1, 3),
transpose_key_input=None,
multiply_batch_by_num_att_heads=False,
)


def get_kv_cache_config(
model_path: str,
supported_configs: List[BaseModel] = [OPT_CONFIG, CODEGEN_CONFIG, BLOOM_CONFIG],
supported_configs: List[BaseModel] = [
OPT_CONFIG,
CODEGEN_CONFIG,
BLOOM_CONFIG,
LLAMA_CONFIG,
],
) -> KeyValueCacheConfig:
"""
Get the kv cache config for the model at the given path.
Expand Down
207 changes: 207 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/transforms_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import numpy
import onnx
from onnx import ModelProto, numpy_helper

from sparseml.exporters.transforms.kv_cache.transforms_base import (
AdditionalTransformsBase,
)
from sparseml.onnx.utils.graph_editor import ONNXGraph
from sparseml.onnx.utils.helpers import get_nodes_by_input_id


__all__ = ["AdditionalTransformsLLAMA"]

_LOGGER = logging.getLogger(__name__)


class AdditionalTransformsLLAMA(AdditionalTransformsBase):

POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]])
CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]])

def transform(self, model: ModelProto) -> ModelProto:
"""
1. Adds `positions` as an input to the model
2. Adds `causal_mask` as an input to the model
dsikka marked this conversation as resolved.
Show resolved Hide resolved
2. Finds the node that initially creates the `position_ids` tensor
3. Updates the node to use the positions input instead of
computing it from the Range op
4. Finds the nodes that initially create the `causal_mask` tensors
5. Updates the nodes to use the causal_mask input instead of
computing it from the Expand op
6. Update the masks to be floats, as expected by the model

:param model: model to update
:return: updated model
"""

model = self.update_slice_nodes_for_positions_input(model)
model = self.add_positions_input(model)
model = self.add_causal_mask_input(model)

position_ids_nodes = self.find_nodes_by_pattern(
model, pattern=self.POSITION_IDS_MATCHING_PATTERN
)

if len(position_ids_nodes) != 1:
raise ValueError(
"Expected to find exactly one node matching "
f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, "
f"found {len(position_ids_nodes)}"
)

model = self.inject_positions(model, position_ids_nodes, "Unsqueeze")

causal_mask_nodes = self.find_nodes_by_pattern(
model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
)
model = self.inject_causal_mask(model, causal_mask_nodes, "Add")
model = self.adjust_causal_mask(model)
return model

def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProto:
"""
Update the Slice nodes in the attention heads such that ends attribute is set
to the max int value. This value is missing from the export and is required
for the position ids injection.
"""
dsikka marked this conversation as resolved.
Show resolved Hide resolved
SLICE_MAX_INT_NAME = "slice_max_int"
dsikka marked this conversation as resolved.
Show resolved Hide resolved
arr = numpy.array(numpy.iinfo(numpy.intp).max).reshape(
1,
)
max_int_tensor = numpy_helper.from_array(arr, name=SLICE_MAX_INT_NAME)

nodes_found = 0
for node in model.graph.node:
if node.op_type == "Slice":
dsikka marked this conversation as resolved.
Show resolved Hide resolved
data = node.input[0]
if "onnx::" in data:
dsikka marked this conversation as resolved.
Show resolved Hide resolved
node.input[2] = SLICE_MAX_INT_NAME
nodes_found += 1
dsikka marked this conversation as resolved.
Show resolved Hide resolved

_LOGGER.info(f"Found {nodes_found} slice nodes to update")

model.graph.initializer.append(max_int_tensor)
ONNXGraph(model).delete_orphaned_node_branches()
return model

def adjust_causal_mask(self, model: ModelProto) -> ModelProto:
dsikka marked this conversation as resolved.
Show resolved Hide resolved
"""
Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change
the initial int64, to a mask of floats expected by the model.

Transform:
```
| causal_mask
| |
| causal_mask_input_child
```
to:
```
| causal_mask (1 and 0)
| |
| Cast (output -> 1.0 and 0.0)
| |
| Sub (output -> 0.0 and -1.0)
| |
| Mul (output -> 0.0 and numpy.finfo(numpy.float32).min)
| |
| causal_mask_input_child

The resulting node will change the input int64 mask
e.g.
```
causal_mask =
[[[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]]]]
```

to a mask of floats:
```
x = numpy.finfo(numpy.float32).min
causal_mask_adjusted =
[[[[0.0, 0.0, 0.0, x, x, x],
[0.0, 0.0, 0.0, 0.0, x, x],
[0.0, 0.0, 0.0, 0.0, 0.0, x],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]]
```

:param model: the model to update
:return: the updated model
"""

graph = ONNXGraph(model)

ones_initializer = onnx.helper.make_tensor(
name="ones_initializer",
data_type=onnx.TensorProto.FLOAT,
dims=[1],
vals=[1.0],
)

floating_point_limit_initializer = onnx.helper.make_tensor(
name="floating_point_limit_initializer",
data_type=onnx.TensorProto.FLOAT,
dims=[1],
vals=[-numpy.finfo(numpy.float32).min],
)

cast_node = onnx.helper.make_node(
"Cast",
inputs=[self.CAUSAL_MASK_NAME],
outputs=[f"{self.CAUSAL_MASK_NAME}_cast"],
to=onnx.TensorProto.FLOAT,
)

sub_node = onnx.helper.make_node(
"Sub",
inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name],
outputs=[f"{self.CAUSAL_MASK_NAME}_sub"],
)

mul_node = onnx.helper.make_node(
"Mul",
inputs=[
f"{self.CAUSAL_MASK_NAME}_sub",
floating_point_limit_initializer.name,
],
outputs=[f"{self.CAUSAL_MASK_NAME}_mul"],
)

new_nodes = [cast_node, sub_node, mul_node]

# get the node that takes the causal mask as input
# and replace the input with the adjusted causal mask input
causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0]

for idx, input_name in enumerate(causal_mask_input_child.input):
if input_name == self.CAUSAL_MASK_NAME:
causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul"

for node in new_nodes:
graph.add_node(node)
self.log_match(node)

model.graph.initializer.extend(
[ones_initializer, floating_point_limit_initializer]
)
_LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input")

return model
Loading