Skip to content

Commit

Permalink
[KV Cache Injection] Causal Mask for OPT (#1688)
Browse files Browse the repository at this point in the history
* initial implementation; testing now

* fix a small blunder

* cleanup

* initial implementation

* on to testing with deepsparse

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
  • Loading branch information
dbogunowicz and bogunowicz@arrival.com committed Jul 26, 2023
1 parent d927e9f commit db62ca0
Showing 1 changed file with 117 additions and 91 deletions.
208 changes: 117 additions & 91 deletions src/sparseml/exporters/transforms/kv_cache/transforms_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,122 +12,148 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from onnx import ModelProto, NodeProto
import numpy
import onnx
from onnx import ModelProto

from sparseml.exporters.transforms.kv_cache.transforms_base import (
AdditionalTransformsBase,
)
from sparseml.onnx.utils import ONNXGraph
from sparseml.onnx.utils.graph_editor import ONNXGraph
from sparseml.onnx.utils.helpers import get_nodes_by_input_id


__all__ = ["AdditionalTransformsOPT"]


# name position embeddings weights
_EMBED_POSITIONS_ID = "model.decoder.embed_positions.weight"


class AdditionalTransformsOPT(AdditionalTransformsBase):
"""
Base class for model architecture specific transforms to adjust graph
to take input_id positions as an argument rather than computing them
based on input. This provides a better source of truth rather than
computing in a static graph where factors such as number of tokens,
cache size, and padding may affect accurate, efficient static
computation of the position indices.
Positions should be the same shape as input_ids where each value
is the corresponding integer position of the token id in the overall
sequence. Padding tokens are not counted towards positions and
should be inputted as 0.
When running a model for a single input id with `n` previously
processed tokens (prompt seq len + number of tokens generated already)
This transform will replace the input to the position embeddings gather
by an explicit onnx graph input. Will delete any operations
used to compute the positions that are now no longer used in the
graph. Optionally keeps an offset `Add` node that is unique to the
OPT graph.
Transforms
```
| Graph computed positions
| |
| Add (Optional)
| |
| Gather(model.decoder.embed_positions.weight)
```
Into
```
| Explicit Graph input (deletes now orphaned nodes to compute positions)
| |
| Add (Optional)
| |
| Gather(model.decoder.embed_positions.weight)
```
1. Adds `positions` as an input to the model
2. Adds `causal_mask` as an input to the model
2. Finds the node that initially creates the `embed_position.weight` tensor
3. Updates the node to use the positions input instead of
computing it from the Sub op
4. Finds the nodes that initially create the `causal_mask` tensors
5. Updates the nodes to use the causal_mask input instead of
computing it from the Expand op
:param model: model to update
:return: updated model
"""

POSITION_EMBEDDINGS_IDS_MATCHING_PATTERN = dict(
op_type="Sub",
children_ops=[["Add"]],
)
CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]])

def transform(self, model: ModelProto) -> ModelProto:
model = self.add_positions_input(model)
position_embeddings_node = self.find_embed_positions_gather_node(model)
model = self._update_position_embeddings_for_graph_input(
model, position_embeddings_node
model = self.add_causal_mask_input(model)

position_embeddings_nodes = self.find_nodes_by_pattern(
model, pattern=self.POSITION_EMBEDDINGS_IDS_MATCHING_PATTERN
)
return model
if len(position_embeddings_nodes) != 1:
raise ValueError(
"Expected to find exactly one node matching "
f"the pattern {self.POSITION_EMBEDDINGS_IDS_MATCHING_PATTERN}, "
f"found {len(position_embeddings_nodes)}"
)
model = self.inject_positions(model, position_embeddings_nodes, "Add")

@classmethod
def find_embed_positions_gather_node(cls, model: ModelProto) -> NodeProto:
for node in model.graph.node:
if node.op_type != "Gather":
continue
if node.input[0] == _EMBED_POSITIONS_ID:
# found the embed_positions_gather_node
return node
raise RuntimeError(
f"Unable to find position embeddings gather node with id "
f"{_EMBED_POSITIONS_ID} in {cls.__name__}"
causal_mask_nodes = self.find_nodes_by_pattern(
model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
)
model = self.inject_causal_mask(model, causal_mask_nodes, "Add")
model = self.adjust_causal_mask(model)
return model

def adjust_causal_mask(self, model: ModelProto) -> ModelProto:
"""
Insert a `Where`` node after the causal mask input to change
the initial boolean mask, to a mask of floats expected by the model.
Transform:
```
| causal_mask
| |
| causal_mask_input_child
```
to:
```
| causal_mask
| |
| Where
| |
| causal_mask_input_child
The resulting node will change the input boolean mask
e.g.
```
causal_mask =
[[[[True, True, True, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, True, False],
[True, True, True, True, True, True]]]]
```
to a mask of floats:
```
x = numpy.finfo(numpy.float32).min
causal_mask_adjusted =
[[[[0.0, 0.0, 0.0, x, x, x],
[0.0, 0.0, 0.0, 0.0, x, x],
[0.0, 0.0, 0.0, 0.0, 0.0, x],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]]
```
:param model: the model to update
:return: the updated model
"""

def _update_position_embeddings_for_graph_input(
self, model: ModelProto, position_embeddings_node: NodeProto
) -> ModelProto:
graph = ONNXGraph(model)

# select target node to update as positions input
position_embeddings_parent = graph.get_node_single_parent(
position_embeddings_node, index=1
condition_true = 0.0
condition_false = numpy.finfo(numpy.float32).min

condition_true_initializer = onnx.helper.make_tensor(
name="condition_true",
data_type=onnx.TensorProto.FLOAT,
# TODO: how to set proper dims?
# This works, but is kinda ugly
dims=[1, 1, 1, 1],
vals=[condition_true],
)

if not isinstance(position_embeddings_parent, NodeProto):
raise RuntimeError(
f"Unable to find input to position embeddings node: "
f"{position_embeddings_node.name} as a node in the given model"
)
condition_false_initializer = onnx.helper.make_tensor(
name="condition_false",
data_type=onnx.TensorProto.FLOAT,
dims=[1, 1, 1, 1],
vals=[condition_false],
)

if position_embeddings_parent.op_type == "Add":
# OPT has a special Add offset for position ids, allow this
# to be where positions are fed instead
target_update_node = position_embeddings_parent
target_input_idx = 0 # assume positions are first input to the Add
else:
target_update_node = position_embeddings_node
target_input_idx = 1 # gather idxs

# reroute target node input to the positions graph input
old_positions_input = target_update_node.input[target_input_idx]
target_update_node.input[target_input_idx] = self.POSITIONS_NAME
graph.update()
self.log_match(target_update_node)

nodes_to_delete = graph.find_orphaned_nodes(
graph.get_node_by_output_id(old_positions_input)
where_node_output = f"{self.CAUSAL_MASK_NAME} adjusted"

where_node = onnx.helper.make_node(
"Where",
inputs=[self.CAUSAL_MASK_NAME, "condition_true", "condition_false"],
outputs=[where_node_output],
)

graph.add_node(where_node)
model.graph.initializer.extend(
[condition_false_initializer, condition_true_initializer]
)
[self.log_match(node) for node in nodes_to_delete]

graph.delete_nodes(nodes_to_delete)
graph.update()
graph.delete_unused_initializers()
# get the node that takes the causal mask as input
# and replace the input with the adjusted causal mask input
causal_mask_input_child = get_nodes_by_input_id(model, self.CAUSAL_MASK_NAME)[0]

for idx, input_name in enumerate(causal_mask_input_child.input):
if input_name == self.CAUSAL_MASK_NAME:
causal_mask_input_child.input[idx] = where_node_output

return model

0 comments on commit db62ca0

Please sign in to comment.