diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 82dcba4fdb7b96..5cfd2300346e34 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -587,11 +587,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
         consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -600,8 +609,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                         ),
                     }
                 )
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index 5019a9ebcda434..6b6fb3a199003a 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -229,11 +229,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
         consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -242,8 +251,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                         ),
                     }
                 )
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index f3d191b4d3c4c6..66a930499f73d1 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -2099,6 +2099,65 @@ def test_tiny_timestamp_generation(self):
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
+    @slow
+    def test_tiny_longform_timestamps_generation(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+
+        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(
+            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+        )
+        input_features = input_features.to(torch_device)
+
+        generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)
+
+        EXPECTED_TRANSCRIPT = [
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+                "timestamp": (0.0, 6.5600000000000005),
+            },
+            {
+                "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+                "timestamp": (6.5600000000000005, 11.24),
+            },
+            {
+                "text": " He tells us that at this festive season of the year, with Christmas and roast beef looming",
+                "timestamp": (11.24, 16.88),
+            },
+            {
+                "text": " before us, similarly drawn from eating and its results occur most readily to the mind.",
+                "timestamp": (16.88, 23.76),
+            },
+            {
+                "text": " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and",
+                "timestamp": (23.76, 29.44),
+            },
+            {"text": " can discover in it but little of rocky ithaka.", "timestamp": (29.44, 33.72)},
+            {
+                "text": " Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals",
+                "timestamp": (33.72, 40.32),
+            },
+            {"text": " are as national as a jingo poem.", "timestamp": (40.32, 44.72)},
+            {
+                "text": " Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used",
+                "timestamp": (44.72, 50.4),
+            },
+            {"text": " to flash his teeth.", "timestamp": (50.4, 52.96)},
+            {
+                "text": " And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like",
+                "timestamp": (52.96, 58.68),
+            },
+            {"text": " a shampoo and a Turkish bath next man.", "timestamp": (58.68, 61.96)},
+        ]
+
+        transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
+
     @slow
     def test_large_timestamp_generation(self):
         set_seed(0)
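Note on the fix above: Whisper's timestamp tokens restart from <|0.00|> at the start of every 30-second window, so during long-form decoding the raw positions are only window-relative. The patch makes _compute_offsets detect each restart (a start position lower than the running maximum) and bank the elapsed length in prev_segments_len, keeping the returned offsets monotonic across the whole audio. Below is a minimal standalone sketch of that accumulation logic; the helper name and toy inputs are illustrative only and not part of the patch.

# Sketch of the accumulation logic added in _compute_offsets, assuming
# positions are already integer multiples of `time_precision` (i.e. after
# subtracting `timestamp_begin`). Hypothetical helper, not library code.
TIME_PRECISION = 0.02  # seconds per timestamp position, as in the patch


def accumulate_offsets(segments, time_precision=TIME_PRECISION):
    """Map window-relative (start, end) positions to absolute seconds.

    `segments` is a list of integer (start, end) position pairs in which the
    counter restarts whenever a new 30 s window begins.
    """
    cur_max_timestamp = 0
    prev_segments_len = 0
    absolute = []
    for start, end in segments:
        if start < cur_max_timestamp:
            # The counter went backwards: a new window started, so bank the
            # length of everything decoded so far.
            prev_segments_len += cur_max_timestamp
        cur_max_timestamp = end
        absolute.append(
            ((start + prev_segments_len) * time_precision,
             (end + prev_segments_len) * time_precision)
        )
    return absolute


# Two windows; the position counter resets to 0 in the third segment.
print(accumulate_offsets([(0, 500), (500, 1200), (0, 300)]))
# [(0.0, 10.0), (10.0, 24.0), (24.0, 30.0)]  (up to float rounding)

Without the banked prev_segments_len, the third segment would come back as (0.0, 6.0) even though it is the second window of audio, which is exactly the regression the new test_tiny_longform_timestamps_generation test guards against.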