Skip to content

Commit

Permalink
update pattern to identify correct slice nodes; move constant to class level
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Aug 29, 2023
1 parent 66a0298 commit 51e2196
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/sparseml/exporters/transforms/kv_cache/transforms_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AdditionalTransformsLLAMA(AdditionalTransformsBase):

POSITION_IDS_MATCHING_PATTERN = dict(op_type="Range", children_ops=[["Unsqueeze"]])
CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Expand", children_ops=[["Add"]])
SLICE_MAX_INT_NAME = "slice_max_int"

def transform(self, model: ModelProto) -> ModelProto:
"""
Expand Down Expand Up @@ -89,27 +90,24 @@ def update_slice_nodes_for_positions_input(self, model: ModelProto) -> ModelProt
modeling_llama.py#L180.
By updating the `ends` operator, access is allowed to the entire tables.
The Slice nodes are identified based on if they contain the `data` operator
as an input, which have the name `onnx::Slice_...`. Nodes with this name have
their `ends` operator updated to point to a 1x1 tensor containing the max
int value.
The Slice nodes are identified based on the `data` operator which does not have
a parent input (as identified using the `get_node_single_parent` function).
:param model: model to update
:return: updated model with Slice nodes in the attention heads updated
"""
SLICE_MAX_INT_NAME = "slice_max_int"
arr = numpy.array(numpy.iinfo(numpy.intp).max).reshape(
1,
)
max_int_tensor = numpy_helper.from_array(arr, name=SLICE_MAX_INT_NAME)
max_int_tensor = numpy_helper.from_array(arr, name=self.SLICE_MAX_INT_NAME)

nodes_found = 0
for node in model.graph.node:
if node.op_type == "Slice":
data = node.input[0]
if "onnx::" in data:
node.input[2] = SLICE_MAX_INT_NAME
data_parent = ONNXGraph(model).get_node_single_parent(node, 0)
if data_parent is not None and len(data_parent.input) == 0:
nodes_found += 1
node.input[2] = self.SLICE_MAX_INT_NAME
self.log_match(node)

_LOGGER.info(f"Found {nodes_found} slice nodes to update")
Expand Down

0 comments on commit 51e2196

Please sign in to comment.