[KV Cache Injection] Causal Mask for CodeGen (#1676)
* initial implementation; testing now

* fix a small blunder

* cleanup

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
dbogunowicz and bogunowicz@arrival.com authored Jul 25, 2023
1 parent 0005056 commit 9368269
Showing 3 changed files with 183 additions and 63 deletions.
6 changes: 3 additions & 3 deletions src/sparseml/exporters/transforms/kv_cache/__init__.py
@@ -19,7 +19,7 @@
# isort:skip_file

from .cache_keys_and_values import *
-from .positions_adjustment_base import *
-from .positions_adjustment_opt import *
-from .positions_adjustment_codegen import *
+from .transforms_base import *
+from .transforms_opt import *
+from .transforms_codegen import *
from .configs import *
173 changes: 158 additions & 15 deletions src/sparseml/exporters/transforms/kv_cache/transforms_base.py
@@ -13,39 +13,182 @@
# limitations under the License.

from copy import deepcopy
+from typing import Any, Dict, List, Optional

-from onnx import ModelProto
+from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper

from sparseml.exporters.transforms.onnx_transform import OnnxTransform
+from sparseml.exporters.transforms.utils.matching import get_structural_matches
+from sparseml.onnx.utils.graph_editor import ONNXGraph


__all__ = ["AdditionalTransformsBase"]


class AdditionalTransformsBase(OnnxTransform):

POSITIONS_NAME = "positions" # matches intermediate var name in torch
POSITIONS_NAME = "positions"
CAUSAL_MASK_NAME = "causal_mask"

@classmethod
def add_positions_input(cls, model: ModelProto) -> ModelProto:
def add_causal_mask_input(self, model: ModelProto) -> ModelProto:
"""
Adds positions as an input to the model
Adds causal mask as an input to the model.
Causal mask is a boolean tensor of shape
[batch_size, 1, input_ids_length, sequence_length]
where the value is False if the position is masked and True
otherwise.
:param model: model to update
:return: updated model
"""
input_ids = self._get_input_proto(model, "input_ids")
attention_mask = self._get_input_proto(model, "attention_mask")

batch_size = input_ids.type.tensor_type.shape.dim[0].dim_param
input_ids_length = input_ids.type.tensor_type.shape.dim[1].dim_value
sequence_length = attention_mask.type.tensor_type.shape.dim[1].dim_param

causal_mask_input = helper.make_tensor_value_info(
name=self.CAUSAL_MASK_NAME,
elem_type=TensorProto.BOOL,
shape=[batch_size, 1, input_ids_length, sequence_length],
)
model.graph.input.append(causal_mask_input)
return model

def add_positions_input(self, model: ModelProto) -> ModelProto:
"""
Adds positions as an input to the model.
Positions is a tensor of shape and dtype
equal to input_ids.
:param model: model to update
:return: updated model
"""
# positions tensor should have shape equal to input_ids
input_ids_info = [
input_ids = self._get_input_proto(model, "input_ids")
positions_input = deepcopy(input_ids)
positions_input.name = self.POSITIONS_NAME
model.graph.input.append(positions_input)
return model

+    def find_nodes_by_pattern(
+        self, model: ModelProto, pattern: Dict[str, Any]
+    ) -> List[NodeProto]:
+        """
+        Find the all the nodes in the `model` that
+        match the specified `pattern`.
+        :param model: the ONNX model
+        :param pattern: a dictionary of arguments and variables
+            expected by the `get_structural_matches` function. For
+            more information, see the documentation for that function.
+        :return: a list of nodes that match the pattern
+        """
+        graph = ONNXGraph(model)
+        matches = get_structural_matches(graph, **pattern)
+        if not matches:
+            raise ValueError(f"Unable to find pattern:\n{pattern}\nin model")
+        return [match.node for match in matches]

+    def inject_causal_mask(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:
+        """
+        Injects causal mask to the graph, replacing the specified nodes.
+        :param model: the ONNX model to inject the causal mask into
+        :param nodes: the nodes to replace with the causal mask
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+        :return: the updated model
+        """

+        return self.swap_nodes_for_input(
+            model, nodes, self.CAUSAL_MASK_NAME, nodes_parent_op_type
+        )

+    def inject_positions(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:
+        """
+        Injects positions to the graph, replacing the specified nodes.
+        :param model: the ONNX model to inject the positions into
+        :param nodes: the nodes to replace with the positions
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+        :return: the updated model
+        """

+        return self.swap_nodes_for_input(
+            model, nodes, self.POSITIONS_NAME, nodes_parent_op_type
+        )

+    def swap_nodes_for_input(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        input_name: str,
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:

+        """
+        Injects the specified input to the graph, replacing the specified nodes.
+        :param model: the ONNX model to inject the input into
+        :param nodes: the nodes to replace with the input
+        :param input_name: the name of the input to replace the nodes with
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+        :return: the updated model
+        """

+        graph = ONNXGraph(model)
+        orphaned_nodes = []
+        for node in nodes:
+            child_node = graph.get_node_children(node)[0]

+            if nodes_parent_op_type:
+                assert child_node.op_type == nodes_parent_op_type, (
+                    f"Expected to find {nodes_parent_op_type} node, "
+                    f"found {child_node.op_type}"
+                )
+            output_to_replace = node.output[0]
+            self.log_match(node)
+            for idx, input_name_child_node in enumerate(child_node.input):
+                if input_name_child_node == output_to_replace:
+                    graph.update_node_input(child_node, input_name, idx)

+            orphaned_nodes.extend(graph.find_orphaned_nodes(node))

+        graph.delete_nodes(orphaned_nodes)
+        graph.update()
+        graph.delete_unused_initializers()

+        return model

+    def _get_input_proto(self, model: ModelProto, input_name: str) -> ValueInfoProto:
+        """
+        Get the input proto for the specified input name.
+        :param model: the ONNX model
+        :param input_name: the name of the input
+        """
+        input_proto = [
            input_info
            for input_info in model.graph.input
-            if input_info.name == "input_ids"
+            if input_info.name == input_name
        ][0]
-        if not input_ids_info:
+        if not input_proto:
            raise RuntimeError(
-                f"{cls.__name__} - unable to find 'input_ids' in model input"
+                f"{self.__name__} - unable to find '{input_name}' in model input"
            )

-        positions_input_info = deepcopy(input_ids_info)
-        positions_input_info.name = cls.POSITIONS_NAME
-        model.graph.input.append(positions_input_info)
-        return model
+        return input_proto
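
Note (not part of the commit): `add_causal_mask_input` only registers the `causal_mask` graph input; the boolean values are supplied by whatever runtime feeds the exported model. As a rough illustration of the semantics described in the docstring above (True = position may be attended to, False = masked), a matching mask could be built as below, where `build_causal_mask` is a hypothetical helper and the shapes follow the [batch_size, 1, input_ids_length, sequence_length] layout that the transform registers:

import numpy as np

def build_causal_mask(batch_size: int, input_ids_length: int, sequence_length: int) -> np.ndarray:
    # True = attend, False = masked, matching the add_causal_mask_input docstring
    mask = np.zeros((batch_size, 1, input_ids_length, sequence_length), dtype=bool)
    # the new tokens occupy the last `input_ids_length` slots of the sequence;
    # token i may attend to every cached position and to itself
    offset = sequence_length - input_ids_length
    for i in range(input_ids_length):
        mask[:, :, i, : offset + i + 1] = True
    return mask

# e.g. a 4-token prompt with an empty KV cache yields a lower-triangular mask
print(build_causal_mask(1, 4, 4)[0, 0].astype(int))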
67 changes: 22 additions & 45 deletions src/sparseml/exporters/transforms/kv_cache/transforms_codegen.py
@@ -12,78 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from onnx import ModelProto, NodeProto
+from onnx import ModelProto

from sparseml.exporters.transforms.kv_cache.transforms_base import (
    AdditionalTransformsBase,
)
-from sparseml.exporters.transforms.utils.matching import get_structural_matches
-from sparseml.onnx.utils import ONNXGraph


__all__ = ["AdditionalTransformsCodeGen"]


class AdditionalTransformsCodeGen(AdditionalTransformsBase):

-    # The pattern that matches the node that creates
-    # the `position_ids` tensor
-    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range")
+    # The patterns that match nodes that create
+    # the `position_ids` and `causal_mask` tensors
+    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]])
+    CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Slice", children_ops=[["Where"]])

    def transform(self, model: ModelProto) -> ModelProto:
        """
        1. Adds `positions` as an input to the model
+        2. Adds `causal_mask` as an input to the model
        2. Finds the node that initially creates the `position_ids` tensor
        3. Updates the node to use the positions input instead of
           computing it from the Range op
+        4. Finds the nodes that initially create the `causal_mask` tensors
+        5. Updates the nodes to use the causal_mask input instead of
+           computing it from the Slice op
        :param model: model to update
        :return: updated model
        """
        model = self.add_positions_input(model)
-        position_ids_node = self.find_position_ids_range_node(model)
-        model = self._update_position_embeddings_for_graph_input(
-            model, position_ids_node
-        )
-        return model
+        model = self.add_causal_mask_input(model)

-    def find_position_ids_range_node(self, model: ModelProto) -> NodeProto:
-        """
-        Find the node that creates the `position_ids` tensor
-        :param model: the ONNX model
-        :return: the node that creates the `position_ids` tensor
-        """
-        graph = ONNXGraph(model)
-        position_ids_node = get_structural_matches(
-            graph, **self.POSITION_IDS_MATCHING_PATTERN
+        position_ids_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.POSITION_IDS_MATCHING_PATTERN
        )
-        if len(position_ids_node) != 1:
+        if len(position_ids_nodes) != 1:
            raise ValueError(
-                f"Expected to find 1 position node, found {len(position_ids_node)}"
+                "Expected to find exactly one node matching "
+                f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, "
+                f"found {len(position_ids_nodes)}"
            )
-        return position_ids_node[0].node

-    def _update_position_embeddings_for_graph_input(
-        self, model: ModelProto, position_embeddings_ids_node: NodeProto
-    ) -> ModelProto:

-        graph = ONNXGraph(model)
-        node = position_embeddings_ids_node
-        child_node = graph.get_node_children(node)[0]
-        # child_node is the `Unsqueeze` node
-        assert (
-            child_node.op_type == "Unsqueeze"
-        ), f"Expected to find `Unsqueeze` node, found {child_node.op_type}"
-        output_to_replace = node.output[0]
-        self.log_match(child_node)
-        for idx, input_name in enumerate(child_node.input):
-            if input_name == output_to_replace:
-                graph.update_node_input(child_node, self.POSITIONS_NAME, idx)

-        orphaned_nodes = graph.find_orphaned_nodes(node)
-        [self.log_match(node) for node in orphaned_nodes]
-        graph.delete_nodes(orphaned_nodes)
-        graph.update()
-        graph.delete_unused_initializers()
+        model = self.inject_positions(model, position_ids_nodes, "Unsqueeze")

+        causal_mask_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
+        )
+        model = self.inject_causal_mask(model, causal_mask_nodes, "Where")
        return model
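
For context (again, not part of the commit), the new transform can be exercised end to end roughly as follows. This is a sketch under assumptions: the model paths are placeholders, the transform is assumed to be constructible with no arguments, and in the normal flow it would presumably be applied through the KV-cache export configs in this module rather than invoked by hand.

import onnx

from sparseml.exporters.transforms.kv_cache import AdditionalTransformsCodeGen

# placeholder path to an ONNX export of a CodeGen model with KV cache outputs
model = onnx.load("codegen_deployment/model.onnx")

# adds `positions` and `causal_mask` graph inputs, then rewires the
# Range -> Unsqueeze and Slice -> Where subgraphs to read from them
model = AdditionalTransformsCodeGen().transform(model)

onnx.save(model, "codegen_deployment/model_kv_cache.onnx")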
