Skip to content

Commit

Permalink
Export of Openai Whisper with batched prompts (#19854)
Browse files Browse the repository at this point in the history
Adds an example to demonstrate the export of openai whipser
implemenation with batch_size > 1 and addition of prompts for each audio
snippet.

Also handles the scenario for when prompts are not of the same size. For
example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1], the
final decoder_input_ids will look as such after padding:
`[prev_token, p1_id_1, p1_id_2, start_token, lang_token,
transcribe_token]
[prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token,
transcribe_token]`

---------

Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
  • Loading branch information
shubhambhokare1 and kunal-vaishnavi committed Apr 3, 2024
1 parent 19793de commit be831e1
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,13 @@ def main(argv=None):
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
# Verify batched decoding with prompts for whisper openai implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def chain_model(args):
decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
decoder_model.graph.name = "decoder subgraph"

config = WhisperConfig.from_pretrained(args.model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path)
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

# Create inputs/outputs for WhisperBeamSearch op
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def create_dummy(
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
model_impl: str = "hf",
): # -> WhisperDecoderInputs:
"""Create dummy inputs for WhisperDecoder.
Expand Down Expand Up @@ -170,7 +171,7 @@ def create_dummy(
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
head_size,
]

Expand Down Expand Up @@ -228,6 +229,7 @@ def export_onnx(
past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
model_impl=decoder.model_impl,
)
input_list = inputs.to_list()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# license information.
# --------------------------------------------------------------------------

import copy
import logging
import os
import tempfile
Expand Down Expand Up @@ -51,12 +50,15 @@ def forward(
self,
encoder_input_ids: torch.Tensor,
decoder_input_ids: torch.Tensor = None,
remove_hooks: bool = False,
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
decinit_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks
)
return decinit_out, encoder_hidden_states, present
else:
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
Expand Down Expand Up @@ -131,9 +133,7 @@ def export_onnx(
)
input_list = inputs.to_list()

# TODO : Investigate whether copy of model if needed
cloned_model = copy.deepcopy(model).to(device)
out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids)
out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True)
present = out[2]
present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)

Expand Down
192 changes: 140 additions & 52 deletions onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,22 +314,13 @@ def optimize_onnx(
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)

@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
def pt_transcription_for_verify_onnx(
processor: WhisperProcessor,
pt_model: torch.nn.Module,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)

# Try to import `datasets` pip package
try:
from datasets import load_dataset
Expand All @@ -342,14 +333,18 @@ def verify_onnx(
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
input_features_ = []
if batch_size == 1:
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
else:
input_features_ = [
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
]
assert len(input_features_) == batch_size
input_features = torch.cat((input_features_[0], input_features_[1]))

max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
Expand All @@ -362,10 +357,97 @@ def verify_onnx(
"early_stopping": True,
"use_cache": True,
}
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()

if prompt_mode:
prompts = ["John has doubts", "Maria has grave doubts"]
prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
pt_transcription = []
pt_outputs = []
# The looping for model.generate is necessary here due to the limitation as per
# https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
# prompt_ids input requires a tensor of rank 1
for i in range(batch_size):
inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
inputs["input_features"] = input_features_[i].to(device)
pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
pt_outputs.append(pt_output)
pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
inputs["input_features"] = input_features
del inputs["prompt_ids"]
else:
prompt_ids = []
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
pt_outputs = list(pt_outputs)
del inputs["early_stopping"]
del inputs["use_cache"]
return inputs, pt_transcription, pt_outputs, prompt_ids

@staticmethod
def select_transcription_options(
batch_size: int,
prompt_mode: bool,
):
if batch_size > 1 and prompt_mode:
expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_options = {
expected_transcription_no_comma_prompt1,
expected_transcription_no_comma_prompt2,
expected_transcription_misspelled_prompt1,
expected_transcription_misspelled_prompt2,
}
else:
expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
return expected_transcription_options

@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
processor,
pt_model,
device,
batch_size=batch_size,
prompt_mode=prompt_mode,
)

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
Expand All @@ -386,8 +468,24 @@ def verify_onnx(
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
if not prompt_mode:
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
else:
# This logic handles the scenario for when prompts are not of the same size
# For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
# The final decoder_input_ids will look as such after padding
# [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
# [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
ort_prompts = []
for i in range(batch_size):
ort_prompts.append(decoder_prompt_ids[i].tolist())
max_len = max(len(p) for p in ort_prompts)
padded_prompts = []
for p in ort_prompts:
padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
padded_prompts.append(padded_prompt + forced_decoder_ids)
inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
elif name == "cross_qk_layer_head":
Expand All @@ -398,36 +496,26 @@ def verify_onnx(
inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0][0]

expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]

parity = (
pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
)
ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

parity = 1
for i in range(batch_size):
parity *= (
pt_transcription[i] in expected_transcription_options
and ort_transcription[i] in expected_transcription_options
)
max_diff = 0

if not parity:
if pt_outputs.shape != ort_outputs.shape:
diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
else:
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)
for i in range(batch_size):
if pt_outputs[i].shape != ort_outputs[i].shape:
diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
else:
diff = pt_outputs[i] - ort_outputs[i]
max_diff_i = max(diff.min(), diff.max(), key=abs)
max_diff = max(max_diff, max_diff_i)

if max_diff != 0:
logger.warning(f"PyTorch outputs: {pt_transcription}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(
tokens,
audio_features,
past=None,
remove_hooks=False,
):
# Create a kv_cache for past_values
past_kv_cache = dict()
Expand All @@ -44,8 +45,9 @@ def forward(
past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]

hooks = None
if not self.kv_cache:
self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()
self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()

logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)

Expand Down Expand Up @@ -73,4 +75,10 @@ def forward(
present_self = [
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
]

# Remove forward hooks to avoid model cloning step
if hooks is not None and remove_hooks:
self.kv_cache = {}
for hook in hooks:
hook.remove()
return logits, present_self

0 comments on commit be831e1

Please sign in to comment.