diff --git a/src/transformers/models/clvp/processing_clvp.py b/src/transformers/models/clvp/processing_clvp.py
index 4e015cea1f8475..ebccab89d0fca3 100644
--- a/src/transformers/models/clvp/processing_clvp.py
+++ b/src/transformers/models/clvp/processing_clvp.py
@@ -73,7 +73,6 @@ def __call__(self, *args, **kwargs):
         inputs["attention_mask"] = encodings["attention_mask"]
         return inputs
 
-    # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py
index f22aae143e6bc4..07ece4314b249b 100644
--- a/src/transformers/models/whisper/processing_whisper.py
+++ b/src/transformers/models/whisper/processing_whisper.py
@@ -84,6 +84,13 @@ def batch_decode(self, *args, **kwargs):
         This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
         refer to the docstring of this method for more information.
         """
+
+        # If segments are present in args, we are performing long-form generation and need to return long-form timestamps.
+        # The long-form timestamps are already present in the segments and should be passed to batch_decode as a kwarg.
+        if isinstance(args[0], dict) and "segments" in args[0]:
+            kwargs["longform_timestamps"] = args[0].pop("segments")
+            args = tuple(args[0]["sequences"].unsqueeze(0))
+
         return self.tokenizer.batch_decode(*args, **kwargs)
 
     def decode(self, *args, **kwargs):
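To make the new `batch_decode` branch concrete, here is a minimal sketch of the structure it assumes `generate(..., return_segments=True)` to return. The token ids and segment boundaries below are made-up illustrative values; the branch only relies on the "sequences" and "segments" keys and on each segment carrying "start"/"end" tensors.

    import torch

    # Hypothetical stand-in for the dict returned by generate(..., return_segments=True);
    # the real dict carries more fields per segment.
    outputs = {
        "sequences": torch.tensor([50258, 50364, 2221, 13, 50664]),  # made-up token ids
        "segments": [  # one list of segment dicts per batch item
            [{"start": torch.tensor(0.0), "end": torch.tensor(6.0)}],
        ],
    }

    kwargs = {}
    if isinstance(outputs, dict) and "segments" in outputs:
        # The segments travel on to the tokenizer as the `longform_timestamps` kwarg ...
        kwargs["longform_timestamps"] = outputs.pop("segments")
        # ... while `sequences` is re-wrapped into a one-element positional args tuple,
        # matching what `self.tokenizer.batch_decode(*args, **kwargs)` expects.
        args = tuple(outputs["sequences"].unsqueeze(0))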
""" offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -587,7 +589,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02): consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1) last_slice = np.where(timestamp_tokens)[0][0] - for current_slice in consecutive: + for i, current_slice in enumerate(consecutive): sliced_tokens = token_ids[last_slice:current_slice] if len(sliced_tokens) > 1: start_timestamp_position = sliced_tokens[0].item() - timestamp_begin @@ -596,15 +598,27 @@ def _compute_offsets(self, token_ids, time_precision=0.02): sliced_tokens = self._preprocess_token_ids(sliced_tokens) text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - offsets.append( - { - "text": text, - "timestamp": ( - start_timestamp_position * time_precision, - end_timestamp_position * time_precision, - ), - } - ) + + if longform_timestamps is not None: + offsets.append( + { + "text": text, + "timestamp": ( + longform_timestamps[0][i]["start"].item(), + longform_timestamps[0][i]["end"].item(), + ), + } + ) + else: + offsets.append( + { + "text": text, + "timestamp": ( + start_timestamp_position * time_precision, + end_timestamp_position * time_precision, + ), + } + ) last_slice = current_slice return offsets @@ -713,7 +727,11 @@ def decode( # retrieve offsets if output_offsets: - offsets = self._compute_offsets(token_ids, time_precision=time_precision) + longform_timestamps = kwargs.get("longform_timestamps") + offsets = self._compute_offsets( + token_ids, time_precision=time_precision, longform_timestamps=longform_timestamps + ) + return {"text": text, "offsets": offsets} return text diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index d1205d1a8ec01b..540056df8bd807 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -200,7 +200,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=None): """ Compute offsets for a given tokenized input @@ -209,6 +209,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. + longform_timestamps (List[dict], *optional*): + Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. 
""" offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -229,7 +231,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02): consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1) last_slice = np.where(timestamp_tokens)[0][0] - for current_slice in consecutive: + for i, current_slice in enumerate(consecutive): sliced_tokens = token_ids[last_slice:current_slice] if len(sliced_tokens) > 1: start_timestamp_position = sliced_tokens[0].item() - timestamp_begin @@ -238,15 +240,27 @@ def _compute_offsets(self, token_ids, time_precision=0.02): sliced_tokens = self._preprocess_token_ids(sliced_tokens) text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - offsets.append( - { - "text": text, - "timestamp": ( - start_timestamp_position * time_precision, - end_timestamp_position * time_precision, - ), - } - ) + + if longform_timestamps is not None: + offsets.append( + { + "text": text, + "timestamp": ( + longform_timestamps[0][i]["start"].item(), + longform_timestamps[0][i]["end"].item(), + ), + } + ) + else: + offsets.append( + { + "text": text, + "timestamp": ( + start_timestamp_position * time_precision, + end_timestamp_position * time_precision, + ), + } + ) last_slice = current_slice return offsets @@ -359,7 +373,11 @@ def decode( # retrieve offsets if output_offsets: - offsets = self._compute_offsets(token_ids, time_precision=time_precision) + longform_timestamps = kwargs.get("longform_timestamps") + offsets = self._compute_offsets( + token_ids, time_precision=time_precision, longform_timestamps=longform_timestamps + ) + return {"text": text, "offsets": offsets} return text diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 5fc66f9a20551d..688ee4f6656d92 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2001,6 +2001,72 @@ def test_tiny_timestamp_generation(self): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow + def test_tiny_longform_timestamps_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + + sample = self._load_datasamples(1) + input_speech = np.concatenate(sample * 10) + + input_features = processor(input_speech, return_tensors="pt", truncation=False, sampling_rate=16_000) + input_features = input_features.to(torch_device) + + generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True) + + EXPECTED_TRANSCRIPT = [ + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. 
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 5fc66f9a20551d..688ee4f6656d92 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -2001,6 +2001,72 @@ def test_tiny_timestamp_generation(self):
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
+    @slow
+    def test_tiny_longform_timestamps_generation(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+
+        sample = self._load_datasamples(1)
+        input_speech = np.concatenate(sample * 10)
+
+        input_features = processor(input_speech, return_tensors="pt", truncation=False, sampling_rate=16_000)
+        input_features = input_features.to(torch_device)
+
+        generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)
+
+        EXPECTED_TRANSCRIPT = [
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                "offsets": [
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (0.0, 6.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (6.0, 12.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (12.0, 18.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (18.0, 24.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (24.0, 29.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (29.0, 35.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (35.0, 41.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (41.0, 47.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (47.0, 53.0),
+                    },
+                    {
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+                        "timestamp": (53.0, 58.20000076293945),
+                    },
+                ],
+            }
+        ]
+
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
     @slow
     def test_large_timestamp_generation(self):
         set_seed(0)
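The test above doubles as end-to-end documentation. Condensed, and with a synthetic silent waveform standing in for real audio (an assumption for illustration; any input longer than 30 seconds routes through long-form generation), the user-facing flow looks like this:

    import numpy as np

    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # 60s of silence at 16kHz stands in for a long-form recording.
    long_audio = np.zeros(16_000 * 60, dtype=np.float32)
    inputs = processor(long_audio, return_tensors="pt", truncation=False, sampling_rate=16_000)

    # return_segments=True makes generate() return a dict with "sequences" and "segments".
    outputs = model.generate(**inputs, return_timestamps=True, return_segments=True)

    # batch_decode detects the segments and emits absolute, long-form timestamps.
    transcript = processor.batch_decode(outputs, skip_special_tokens=True, output_offsets=True)
    print(transcript[0]["offsets"])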