From c4c5d2c3320b6ac3417160d65f229a487195950b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 21 Aug 2023 17:00:14 -0400 Subject: [PATCH 1/5] llama kv cache update --- .../exporters/transforms/kv_cache/__init__.py | 1 + .../exporters/transforms/kv_cache/configs.py | 20 +- .../transforms/kv_cache/transforms_llama.py | 180 ++++++++++++++++++ 3 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 src/sparseml/exporters/transforms/kv_cache/transforms_llama.py diff --git a/src/sparseml/exporters/transforms/kv_cache/__init__.py b/src/sparseml/exporters/transforms/kv_cache/__init__.py index 94ea5287166..bf650cb7395 100644 --- a/src/sparseml/exporters/transforms/kv_cache/__init__.py +++ b/src/sparseml/exporters/transforms/kv_cache/__init__.py @@ -22,4 +22,5 @@ from .transforms_base import * from .transforms_opt import * from .transforms_codegen import * +from .transforms_llama import * from .configs import * diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index 0bfbbfb1850..e54fa700f75 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -23,6 +23,9 @@ from sparseml.exporters.transforms.kv_cache.transforms_codegen import ( AdditionalTransformsCodeGen, ) +from sparseml.exporters.transforms.kv_cache.transforms_llama import ( + AdditionalTransformsLLAMA, +) from sparseml.exporters.transforms.kv_cache.transforms_opt import ( AdditionalTransformsOPT, ) @@ -112,10 +115,25 @@ class Config: multiply_batch_by_num_att_heads=True, ) +LLAMA_CONFIG = KeyValueCacheConfig( + model_name="llama", + additional_transforms=AdditionalTransformsLLAMA, + key_num_attention_heads="num_attention_heads", + key_num_embedding_hidden_size="hidden_size", + transpose_value_input=(0, 2, 1, 3), + transpose_key_input=None, + multiply_batch_by_num_att_heads=False, +) + def get_kv_cache_config( model_path: str, - supported_configs: 
List[BaseModel] = [OPT_CONFIG, CODEGEN_CONFIG, BLOOM_CONFIG], + supported_configs: List[BaseModel] = [ + OPT_CONFIG, + CODEGEN_CONFIG, + BLOOM_CONFIG, + LLAMA_CONFIG, + ], ) -> KeyValueCacheConfig: """ Get the kv cache config for the model at the given path. diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py new file mode 100644 index 00000000000..4f2435d2894 --- /dev/null +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py @@ -0,0 +1,180 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import numpy +import onnx +from onnx import ModelProto + +from sparseml.exporters.transforms.kv_cache.transforms_base import ( + AdditionalTransformsBase, +) +from sparseml.onnx.utils.graph_editor import ONNXGraph +from sparseml.onnx.utils.helpers import get_nodes_by_input_id + + +__all__ = ["AdditionalTransformsLLAMA"] + +_LOGGER = logging.getLogger(__name__) + + +class AdditionalTransformsLLAMA(AdditionalTransformsBase): + + POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]]) + CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]]) + + def transform(self, model: ModelProto) -> ModelProto: + """ + 1. Adds `positions` as an input to the model + 2. Adds `causal_mask` as an input to the model + 2. 
Finds the node that initially creates the `position_ids` tensor + 3. Updates the node to use the positions input instead of + computing it from the Range op + 4. Finds the nodes that initially create the `causal_mask` tensors + 5. Updates the nodes to use the causal_mask input instead of + computing it from the Expand op + 6. Update the masks to be floats, as expected by the model + + :param model: model to update + :return: updated model + """ + + model = self.add_positions_input(model) + model = self.add_causal_mask_input(model) + + position_ids_nodes = self.find_nodes_by_pattern( + model, pattern=self.POSITION_IDS_MATCHING_PATTERN + ) + + if len(position_ids_nodes) != 1: + raise ValueError( + "Expected to find exactly one node matching " + f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, " + f"found {len(position_ids_nodes)}" + ) + + model = self.inject_positions(model, position_ids_nodes, "Unsqueeze") + + causal_mask_nodes = self.find_nodes_by_pattern( + model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN + ) + model = self.inject_causal_mask(model, causal_mask_nodes, "Add") + model = self.adjust_causal_mask(model) + return model + + def adjust_causal_mask(self, model: ModelProto) -> ModelProto: + """ + Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change + the initial int64, to a mask of floats expected by the model. + + Transform: + ``` + | causal_mask + | | + | causal_mask_input_child + ``` + to: + ``` + | causal_mask (1 and 0) + | | + | Cast (output -> 1.0 and 0.0) + | | + | Sub (output -> 0.0 and -1.0) + | | + | Mul (output -> 0.0 and numpy.finfo(numpy.float32).min) + | | + | causal_mask_input_child + + The resulting node will change the input int64 mask + e.g. 
+ ``` + causal_mask = + [[[[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]]] + ``` + + to a mask of floats: + ``` + x = numpy.finfo(numpy.float32).min + causal_mask_adjusted = + [[[[0.0, 0.0, 0.0, x, x, x], + [0.0, 0.0, 0.0, 0.0, x, x], + [0.0, 0.0, 0.0, 0.0, 0.0, x], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] + ``` + + :param model: the model to update + :return: the updated model + """ + + graph = ONNXGraph(model) + + ones_initializer = onnx.helper.make_tensor( + name="ones_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[1.0], + ) + + floating_point_limit_initializer = onnx.helper.make_tensor( + name="floating_point_limit_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[-numpy.finfo(numpy.float32).min], + ) + + cast_node = onnx.helper.make_node( + "Cast", + inputs=[self.CAUSAL_MASK_NAME], + outputs=[f"{self.CAUSAL_MASK_NAME}_cast"], + to=onnx.TensorProto.FLOAT, + ) + + sub_node = onnx.helper.make_node( + "Sub", + inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name], + outputs=[f"{self.CAUSAL_MASK_NAME}_sub"], + ) + + mul_node = onnx.helper.make_node( + "Mul", + inputs=[ + f"{self.CAUSAL_MASK_NAME}_sub", + floating_point_limit_initializer.name, + ], + outputs=[f"{self.CAUSAL_MASK_NAME}_mul"], + ) + + new_nodes = [cast_node, sub_node, mul_node] + + # get the node that takes the causal mask as input + # and replace the input with the adjusted causal mask input + causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0] + + for idx, input_name in enumerate(causal_mask_input_child.input): + if input_name == self.CAUSAL_MASK_NAME: + causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul" + + for node in new_nodes: + graph.add_node(node) + self.log_match(node) + + model.graph.initializer.extend( + [ones_initializer, floating_point_limit_initializer] + ) + _LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input") + + return model From 
9a494f8c1ae81c3a66e93e930cff12014e7a54fd Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 22 Aug 2023 13:41:51 -0400 Subject: [PATCH 2/5] add llama transform to update slice nodes during kv cache injection --- .../transforms/kv_cache/transforms_llama.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py index 4f2435d2894..ced4034ca8b 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py @@ -15,7 +15,7 @@ import numpy import onnx -from onnx import ModelProto +from onnx import ModelProto, numpy_helper from sparseml.exporters.transforms.kv_cache.transforms_base import ( AdditionalTransformsBase, @@ -50,6 +50,7 @@ def transform(self, model: ModelProto) -> ModelProto: :return: updated model """ + model = self.update_slice_nodes_for_positions_input(model) model = self.add_positions_input(model) model = self.add_causal_mask_input(model) @@ -73,6 +74,32 @@ def transform(self, model: ModelProto) -> ModelProto: model = self.adjust_causal_mask(model) return model + def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProto: + """ + Update the Slice nodes in the attention heads such that ends attribute is set + to the max int value. This value is missing from the export and is required + for the position ids injection. 
+ """ + SLICE_MAX_INT_NAME = "slice_max_int" + arr = numpy.array(numpy.iinfo(numpy.intp).max).reshape( + 1, + ) + max_int_tensor = numpy_helper.from_array(arr, name=SLICE_MAX_INT_NAME) + + nodes_found = 0 + for node in model.graph.node: + if node.op_type == "Slice": + data = node.input[0] + if "onnx::" in data: + node.input[2] = SLICE_MAX_INT_NAME + nodes_found += 1 + + _LOGGER.info(f"Found {nodes_found} slice nodes to update") + + model.graph.initializer.append(max_int_tensor) + ONNXGraph(model).delete_orphaned_node_branches() + return model + def adjust_causal_mask(self, model: ModelProto) -> ModelProto: """ Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change From 4766f3c7eeae8fd97be5e90053c2a891ae73aec7 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 24 Aug 2023 12:33:07 -0400 Subject: [PATCH 3/5] move adjust_causal_masks, update docstring with additional details --- .../transforms/kv_cache/transforms_base.py | 109 +++++++++++++ .../transforms/kv_cache/transforms_llama.py | 146 ++++-------------- 2 files changed, 137 insertions(+), 118 deletions(-) diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py index 3f124ad5762..c9331800dfd 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py @@ -16,11 +16,14 @@ from copy import deepcopy from typing import Any, Dict, List, Optional +import numpy +import onnx from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper from sparseml.exporters.transforms.onnx_transform import OnnxTransform from sparseml.exporters.transforms.utils.matching import get_structural_matches from sparseml.onnx.utils.graph_editor import ONNXGraph +from sparseml.onnx.utils.helpers import get_nodes_by_input_id __all__ = ["AdditionalTransformsBase"] @@ -196,3 +199,109 @@ def _get_input_proto(self, model: ModelProto, input_name: str) 
-> ValueInfoProto f"{self.__name__} - unable to find '{input_name}' in model input" ) return input_proto + + def adjust_causal_mask(self, model: ModelProto) -> ModelProto: + """ + Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change + the initial int64, to a mask of floats expected by the model. + + Transform: + ``` + | causal_mask + | | + | causal_mask_input_child + ``` + to: + ``` + | causal_mask (1 and 0) + | | + | Cast (output -> 1.0 and 0.0) + | | + | Sub (output -> 0.0 and -1.0) + | | + | Mul (output -> 0.0 and numpy.finfo(numpy.float32).min) + | | + | causal_mask_input_child + + The resulting node will change the input int64 mask + e.g. + ``` + causal_mask = + [[[[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]]] + ``` + + to a mask of floats: + ``` + x = numpy.finfo(numpy.float32).min + causal_mask_adjusted = + [[[[0.0, 0.0, 0.0, x, x, x], + [0.0, 0.0, 0.0, 0.0, x, x], + [0.0, 0.0, 0.0, 0.0, 0.0, x], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] + ``` + + :param model: the model to update + :return: the updated model + """ + + graph = ONNXGraph(model) + + ones_initializer = onnx.helper.make_tensor( + name="ones_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[1.0], + ) + + floating_point_limit_initializer = onnx.helper.make_tensor( + name="floating_point_limit_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[-numpy.finfo(numpy.float32).min], + ) + + cast_node = onnx.helper.make_node( + "Cast", + inputs=[self.CAUSAL_MASK_NAME], + outputs=[f"{self.CAUSAL_MASK_NAME}_cast"], + to=onnx.TensorProto.FLOAT, + ) + + sub_node = onnx.helper.make_node( + "Sub", + inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name], + outputs=[f"{self.CAUSAL_MASK_NAME}_sub"], + ) + + mul_node = onnx.helper.make_node( + "Mul", + inputs=[ + f"{self.CAUSAL_MASK_NAME}_sub", + floating_point_limit_initializer.name, + ], + outputs=[f"{self.CAUSAL_MASK_NAME}_mul"], + ) + + new_nodes = 
[cast_node, sub_node, mul_node] + + # get the node that takes the causal mask as input + # and replace the input with the adjusted causal mask input + causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0] + + for idx, input_name in enumerate(causal_mask_input_child.input): + if input_name == self.CAUSAL_MASK_NAME: + causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul" + + for node in new_nodes: + graph.add_node(node) + self.log_match(node) + + model.graph.initializer.extend( + [ones_initializer, floating_point_limit_initializer] + ) + _LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input") + + return model diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py index ced4034ca8b..707ca3db2b5 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py @@ -14,14 +14,12 @@ import logging import numpy -import onnx from onnx import ModelProto, numpy_helper from sparseml.exporters.transforms.kv_cache.transforms_base import ( AdditionalTransformsBase, ) from sparseml.onnx.utils.graph_editor import ONNXGraph -from sparseml.onnx.utils.helpers import get_nodes_by_input_id __all__ = ["AdditionalTransformsLLAMA"] @@ -36,15 +34,17 @@ class AdditionalTransformsLLAMA(AdditionalTransformsBase): def transform(self, model: ModelProto) -> ModelProto: """ - 1. Adds `positions` as an input to the model - 2. Adds `causal_mask` as an input to the model - 2. Finds the node that initially creates the `position_ids` tensor - 3. Updates the node to use the positions input instead of + 1 Updates the Slice nodes in the attention heads by extending the `ends` + operator + 2. Adds `positions` as an input to the model + 3. Adds `causal_mask` as an input to the model + 4. Finds the node that initially creates the `position_ids` tensor + 5. 
Updates the node to use the positions input instead of computing it from the Range op - 4. Finds the nodes that initially create the `causal_mask` tensors - 5. Updates the nodes to use the causal_mask input instead of + 6. Finds the nodes that initially create the `causal_mask` tensors + 7. Updates the nodes to use the causal_mask input instead of computing it from the Expand op - 6. Update the masks to be floats, as expected by the model + 8. Update the masks to be floats, as expected by the model :param model: model to update :return: updated model @@ -76,9 +76,24 @@ def transform(self, model: ModelProto) -> ModelProto: def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProto: """ - Update the Slice nodes in the attention heads such that ends attribute is set - to the max int value. This value is missing from the export and is required - for the position ids injection. + Update the Slice nodes in the attention heads such that the `ends` operator is + set to the max int value. This value is missing from the export and is required + for the position ids injection. This is because the onnx export limits access to + the entire sin_cached and cos_cached tables, which results in an index error + with the position ids: + + https://github.com/huggingface/transformers/blob/ + 7a6efe1e9f756f585f2ffe5ada22cf6b15edd23b/src/transformers/models/llama/ + modeling_llama.py#L180. + + By updating the `ends` operator, access is allowed to the entire tables. + The Slice nodes are identified based on if they contain the `data` operator + as an input, which have the name `onnx::Slice_...`. Nodes with this name have + their `ends` operator updated to point to a 1x1 tensor containing the max + int value. 
+ + :param model: model to update + :return: updated model with Slice nodes in the attention heads updated """ SLICE_MAX_INT_NAME = "slice_max_int" arr = numpy.array(numpy.iinfo(numpy.intp).max).reshape( @@ -93,115 +108,10 @@ def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProt if "onnx::" in data: node.input[2] = SLICE_MAX_INT_NAME nodes_found += 1 + self.log_match(node) _LOGGER.info(f"Found {nodes_found} slice nodes to update") model.graph.initializer.append(max_int_tensor) ONNXGraph(model).delete_orphaned_node_branches() return model - - def adjust_causal_mask(self, model: ModelProto) -> ModelProto: - """ - Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change - the initial int64, to a mask of floats expected by the model. - - Transform: - ``` - | causal_mask - | | - | causal_mask_input_child - ``` - to: - ``` - | causal_mask (1 and 0) - | | - | Cast (output -> 1.0 and 0.0) - | | - | Sub (output -> 0.0 and -1.0) - | | - | Mul (output -> 0.0 and numpy.finfo(numpy.float32).min) - | | - | causal_mask_input_child - - The resulting node will change the input int64 mask - e.g. 
- ``` - causal_mask = - [[[[1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1]]]] - ``` - - to a mask of floats: - ``` - x = numpy.finfo(numpy.float32).min - causal_mask_adjusted = - [[[[0.0, 0.0, 0.0, x, x, x], - [0.0, 0.0, 0.0, 0.0, x, x], - [0.0, 0.0, 0.0, 0.0, 0.0, x], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] - ``` - - :param model: the model to update - :return: the updated model - """ - - graph = ONNXGraph(model) - - ones_initializer = onnx.helper.make_tensor( - name="ones_initializer", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=[1.0], - ) - - floating_point_limit_initializer = onnx.helper.make_tensor( - name="floating_point_limit_initializer", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=[-numpy.finfo(numpy.float32).min], - ) - - cast_node = onnx.helper.make_node( - "Cast", - inputs=[self.CAUSAL_MASK_NAME], - outputs=[f"{self.CAUSAL_MASK_NAME}_cast"], - to=onnx.TensorProto.FLOAT, - ) - - sub_node = onnx.helper.make_node( - "Sub", - inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name], - outputs=[f"{self.CAUSAL_MASK_NAME}_sub"], - ) - - mul_node = onnx.helper.make_node( - "Mul", - inputs=[ - f"{self.CAUSAL_MASK_NAME}_sub", - floating_point_limit_initializer.name, - ], - outputs=[f"{self.CAUSAL_MASK_NAME}_mul"], - ) - - new_nodes = [cast_node, sub_node, mul_node] - - # get the node that takes the causal mask as input - # and replace the input with the adjusted causal mask input - causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0] - - for idx, input_name in enumerate(causal_mask_input_child.input): - if input_name == self.CAUSAL_MASK_NAME: - causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul" - - for node in new_nodes: - graph.add_node(node) - self.log_match(node) - - model.graph.initializer.extend( - [ones_initializer, floating_point_limit_initializer] - ) - _LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input") - - return model From 
66a029877e4d84a782159c7118ef33bf28477b4a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 25 Aug 2023 12:18:09 -0400 Subject: [PATCH 4/5] move back causal mask --- .../transforms/kv_cache/transforms_base.py | 109 ------------------ .../transforms/kv_cache/transforms_llama.py | 108 +++++++++++++++++ 2 files changed, 108 insertions(+), 109 deletions(-) diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py index c9331800dfd..3f124ad5762 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_base.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_base.py @@ -16,14 +16,11 @@ from copy import deepcopy from typing import Any, Dict, List, Optional -import numpy -import onnx from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper from sparseml.exporters.transforms.onnx_transform import OnnxTransform from sparseml.exporters.transforms.utils.matching import get_structural_matches from sparseml.onnx.utils.graph_editor import ONNXGraph -from sparseml.onnx.utils.helpers import get_nodes_by_input_id __all__ = ["AdditionalTransformsBase"] @@ -199,109 +196,3 @@ def _get_input_proto(self, model: ModelProto, input_name: str) -> ValueInfoProto f"{self.__name__} - unable to find '{input_name}' in model input" ) return input_proto - - def adjust_causal_mask(self, model: ModelProto) -> ModelProto: - """ - Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change - the initial int64, to a mask of floats expected by the model. - - Transform: - ``` - | causal_mask - | | - | causal_mask_input_child - ``` - to: - ``` - | causal_mask (1 and 0) - | | - | Cast (output -> 1.0 and 0.0) - | | - | Sub (output -> 0.0 and -1.0) - | | - | Mul (output -> 0.0 and numpy.finfo(numpy.float32).min) - | | - | causal_mask_input_child - - The resulting node will change the input int64 mask - e.g. 
- ``` - causal_mask = - [[[[1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1]]]] - ``` - - to a mask of floats: - ``` - x = numpy.finfo(numpy.float32).min - causal_mask_adjusted = - [[[[0.0, 0.0, 0.0, x, x, x], - [0.0, 0.0, 0.0, 0.0, x, x], - [0.0, 0.0, 0.0, 0.0, 0.0, x], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] - ``` - - :param model: the model to update - :return: the updated model - """ - - graph = ONNXGraph(model) - - ones_initializer = onnx.helper.make_tensor( - name="ones_initializer", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=[1.0], - ) - - floating_point_limit_initializer = onnx.helper.make_tensor( - name="floating_point_limit_initializer", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=[-numpy.finfo(numpy.float32).min], - ) - - cast_node = onnx.helper.make_node( - "Cast", - inputs=[self.CAUSAL_MASK_NAME], - outputs=[f"{self.CAUSAL_MASK_NAME}_cast"], - to=onnx.TensorProto.FLOAT, - ) - - sub_node = onnx.helper.make_node( - "Sub", - inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name], - outputs=[f"{self.CAUSAL_MASK_NAME}_sub"], - ) - - mul_node = onnx.helper.make_node( - "Mul", - inputs=[ - f"{self.CAUSAL_MASK_NAME}_sub", - floating_point_limit_initializer.name, - ], - outputs=[f"{self.CAUSAL_MASK_NAME}_mul"], - ) - - new_nodes = [cast_node, sub_node, mul_node] - - # get the node that takes the causal mask as input - # and replace the input with the adjusted causal mask input - causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0] - - for idx, input_name in enumerate(causal_mask_input_child.input): - if input_name == self.CAUSAL_MASK_NAME: - causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul" - - for node in new_nodes: - graph.add_node(node) - self.log_match(node) - - model.graph.initializer.extend( - [ones_initializer, floating_point_limit_initializer] - ) - _LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input") - - return model diff 
--git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py index 707ca3db2b5..20322c3f9be 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py @@ -14,12 +14,14 @@ import logging import numpy +import onnx from onnx import ModelProto, numpy_helper from sparseml.exporters.transforms.kv_cache.transforms_base import ( AdditionalTransformsBase, ) from sparseml.onnx.utils.graph_editor import ONNXGraph +from sparseml.onnx.utils.helpers import get_nodes_by_input_id __all__ = ["AdditionalTransformsLLAMA"] @@ -115,3 +117,109 @@ def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProt model.graph.initializer.append(max_int_tensor) ONNXGraph(model).delete_orphaned_node_branches() return model + + def adjust_causal_mask(self, model: ModelProto) -> ModelProto: + """ + Insert a `Cast`, `Sub` and `Mul` nodes after the causal mask input to change + the initial int64, to a mask of floats expected by the model. + + Transform: + ``` + | causal_mask + | | + | causal_mask_input_child + ``` + to: + ``` + | causal_mask (1 and 0) + | | + | Cast (output -> 1.0 and 0.0) + | | + | Sub (output -> 0.0 and -1.0) + | | + | Mul (output -> 0.0 and numpy.finfo(numpy.float32).min) + | | + | causal_mask_input_child + + The resulting node will change the input int64 mask + e.g. 
+ ``` + causal_mask = + [[[[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]]] + ``` + + to a mask of floats: + ``` + x = numpy.finfo(numpy.float32).min + causal_mask_adjusted = + [[[[0.0, 0.0, 0.0, x, x, x], + [0.0, 0.0, 0.0, 0.0, x, x], + [0.0, 0.0, 0.0, 0.0, 0.0, x], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] + ``` + + :param model: the model to update + :return: the updated model + """ + + graph = ONNXGraph(model) + + ones_initializer = onnx.helper.make_tensor( + name="ones_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[1.0], + ) + + floating_point_limit_initializer = onnx.helper.make_tensor( + name="floating_point_limit_initializer", + data_type=onnx.TensorProto.FLOAT, + dims=[1], + vals=[-numpy.finfo(numpy.float32).min], + ) + + cast_node = onnx.helper.make_node( + "Cast", + inputs=[self.CAUSAL_MASK_NAME], + outputs=[f"{self.CAUSAL_MASK_NAME}_cast"], + to=onnx.TensorProto.FLOAT, + ) + + sub_node = onnx.helper.make_node( + "Sub", + inputs=[f"{self.CAUSAL_MASK_NAME}_cast", ones_initializer.name], + outputs=[f"{self.CAUSAL_MASK_NAME}_sub"], + ) + + mul_node = onnx.helper.make_node( + "Mul", + inputs=[ + f"{self.CAUSAL_MASK_NAME}_sub", + floating_point_limit_initializer.name, + ], + outputs=[f"{self.CAUSAL_MASK_NAME}_mul"], + ) + + new_nodes = [cast_node, sub_node, mul_node] + + # get the node that takes the causal mask as input + # and replace the input with the adjusted causal mask input + causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0] + + for idx, input_name in enumerate(causal_mask_input_child.input): + if input_name == self.CAUSAL_MASK_NAME: + causal_mask_input_child.input[idx] = f"{self.CAUSAL_MASK_NAME}_mul" + + for node in new_nodes: + graph.add_node(node) + self.log_match(node) + + model.graph.initializer.extend( + [ones_initializer, floating_point_limit_initializer] + ) + _LOGGER.info(f"Successfully adjusted the {self.CAUSAL_MASK_NAME} input") + + return model From 
51e21967d17ad0b13ff970275dc52a15462859ff Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 28 Aug 2023 23:13:38 -0400 Subject: [PATCH 5/5] update pattern to identify correct slice nodes; move constant to class level --- .../transforms/kv_cache/transforms_llama.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py index 20322c3f9be..35610611f24 100644 --- a/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py +++ b/src/sparseml/exporters/transforms/kv_cache/transforms_llama.py @@ -33,6 +33,7 @@ class AdditionalTransformsLLAMA(AdditionalTransformsBase): POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]]) CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]]) + SLICE_MAX_INT_NAME = "slice_max_int" def transform(self, model: ModelProto) -> ModelProto: """ @@ -89,27 +90,24 @@ def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProt modeling_llama.py#L180. By updating the `ends` operator, access is allowed to the entire tables. - The Slice nodes are identified based on if they contain the `data` operator - as an input, which have the name `onnx::Slice_...`. Nodes with this name have - their `ends` operator updated to point to a 1x1 tensor containing the max - int value. + The Slice nodes to update are those whose `data` input comes from a parent node + that itself has no inputs (found using the `get_node_single_parent` function). 
:param model: model to update :return: updated model with Slice nodes in the attention heads updated """ - SLICE_MAX_INT_NAME = "slice_max_int" arr = numpy.array(numpy.iinfo(numpy.intp).max).reshape( 1, ) - max_int_tensor = numpy_helper.from_array(arr, name=SLICE_MAX_INT_NAME) + max_int_tensor = numpy_helper.from_array(arr, name=self.SLICE_MAX_INT_NAME) nodes_found = 0 for node in model.graph.node: if node.op_type == "Slice": - data = node.input[0] - if "onnx::" in data: - node.input[2] = SLICE_MAX_INT_NAME + data_parent = ONNXGraph(model).get_node_single_parent(node, 0) + if data_parent is not None and len(data_parent.input) == 0: nodes_found += 1 + node.input[2] = self.SLICE_MAX_INT_NAME self.log_match(node) _LOGGER.info(f"Found {nodes_found} slice nodes to update")