Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix the KV Cache insertion logic for quantized OPT #1648

Merged
merged 13 commits into from
Jul 19, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

_LOGGER = logging.getLogger(__name__)

ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast"]
ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "QuantizeLinear"]
OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""

Expand Down Expand Up @@ -685,15 +685,26 @@ def _find_key_matmul_from_value_matmul(
def _value_input_idx(value_matmul: NodeProto, model: ModelProto) -> int:
graph = ONNXGraph(model)
# get idx of matmul that the value node is an input of
if len(value_matmul.input) != 2:
expected_num_inputs = 4 if value_matmul.op_type == "MatMulInteger" else 2

if len(value_matmul.input) != expected_num_inputs:
raise ValueError(
f"Expected value matmul to have 2 inputs, got {len(value_matmul.input)}"
f"Expected value matmul to have {expected_num_inputs} "
f"inputs, got {len(value_matmul.input)}"
)

softmax_input_idx = 0 # default to softmax being on left hand side
for idx, parent in enumerate(graph.get_node_parents(value_matmul)):
if isinstance(parent, NodeProto) and parent.op_type == "Softmax":
softmax_input_idx = idx
break
if isinstance(parent, NodeProto):
# if a parent is a softmax or the parent of value matmul is a direct
# child of a softmax (quantized scenario), then the softmax is the
# idx'th input to the value matmul
if (
parent.op_type == "Softmax"
or graph.get_node_single_parent(parent, 0).op_type == "Softmax"
):
softmax_input_idx = idx
break
return 1 - softmax_input_idx # return index that isn't the softmax


Expand Down
Loading