Fix bloom KV cache usage in ORTForCausalLM (#1152)
* fix bloom pkv usage with num_beams>1

* Update optimum/onnxruntime/modeling_decoder.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Update optimum/onnxruntime/modeling_decoder.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Update optimum/onnxruntime/modeling_decoder.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* remove transformers import

---------

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
fxmarty and michaelbenayoun authored Jul 6, 2023
1 parent bc5f825 commit 2eab7ab
Showing 3 changed files with 99 additions and 2 deletions.
56 changes: 54 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -17,7 +17,7 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
from huggingface_hub import hf_hub_download
@@ -36,6 +36,7 @@
from .base import ORTDecoder
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
from .modeling_ort import ORTModel
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
@@ -315,6 +316,7 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
init_cls: Type["ORTModelDecoder"],
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
@@ -514,7 +516,7 @@ def _from_pretrained(
else:
onnx_paths.append(decoder_merged_path)

return cls(
return init_cls(
ort_inference_sessions[0],
config,
decoder_with_past_session=ort_inference_sessions[1],
@@ -695,3 +697,53 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True

@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
**kwargs,
):
if config.model_type == "bloom":
return super()._from_pretrained(model_id, config, init_cls=ORTBloomForCausalLM, **kwargs)
return super()._from_pretrained(model_id, config, init_cls=ORTModelForCausalLM, **kwargs)


class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

# only last token for input_ids if past is not None
if past_key_values:
# the cache may be in the standard format (e.g. in contrastive search), convert to Bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = bloom_convert_to_bloom_cache(past_key_values)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
}

# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
standardized_past = bloom_convert_to_standard_cache(past, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)
return bloom_convert_to_bloom_cache(reordered_past)
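
With this change, ORTModelForCausalLM.from_pretrained dispatches to ORTBloomForCausalLM whenever config.model_type == "bloom", so beam search reorders Bloom's [batch_size * num_heads, ...] KV cache correctly. Below is a minimal usage sketch of the fixed path; it is not part of the diff, and the checkpoint name as well as the export=True / use_cache=True arguments are illustrative assumptions, not something mandated by this commit.

# Sketch: exercises the beam-search path fixed by this commit.
# Assumes optimum[onnxruntime] and transformers are installed; the checkpoint
# below is only an example choice.
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# _from_pretrained returns an ORTBloomForCausalLM instance for Bloom configs
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# num_beams > 1 triggers _reorder_cache, which previously mishandled Bloom's
# [batch_size * num_heads, ...] cache layout
generated = model.generate(**inputs, num_beams=4, max_new_tokens=16)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))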
1 change: 1 addition & 0 deletions optimum/onnxruntime/models/__init__.py
@@ -0,0 +1 @@

44 changes: 44 additions & 0 deletions optimum/onnxruntime/models/bloom.py
@@ -0,0 +1,44 @@
from typing import TYPE_CHECKING, Tuple


if TYPE_CHECKING:
import torch


def bloom_convert_to_standard_cache(
past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)


def bloom_convert_to_bloom_cache(
past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
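
The two helpers above are inverses of each other (given the batch_size argument). The following round trip with dummy tensors illustrates the shape comments in the new module; it is not part of the commit, the layer count and dimensions are arbitrary example values, and the import path simply follows the new optimum/onnxruntime/models/bloom.py file.

# Sketch: verifies the layout round trip described in the shape comments above.
import torch

from optimum.onnxruntime.models.bloom import (
    bloom_convert_to_bloom_cache,
    bloom_convert_to_standard_cache,
)

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Standard format: key [batch, heads, head_dim, seq], value [batch, heads, seq, head_dim]
standard = tuple(
    (
        torch.randn(batch_size, num_heads, head_dim, seq_length),
        torch.randn(batch_size, num_heads, seq_length, head_dim),
    )
    for _ in range(2)  # two layers, chosen arbitrarily for the example
)

bloom_cache = bloom_convert_to_bloom_cache(standard)
assert bloom_cache[0][0].shape == (batch_size * num_heads, head_dim, seq_length)
assert bloom_cache[0][1].shape == (batch_size * num_heads, seq_length, head_dim)

# Converting back recovers the original standard-format tensors exactly
roundtrip = bloom_convert_to_standard_cache(bloom_cache, batch_size=batch_size)
assert all(
    torch.equal(a, b)
    for layer_std, layer_rt in zip(standard, roundtrip)
    for a, b in zip(layer_std, layer_rt)
)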
