From 474bc5ce6448f8571cd9230f2f55a139d7616d6d Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Tue, 16 Jul 2024 16:40:54 +0200 Subject: [PATCH 01/16] fix lo form timestamps in decode_batch --- .../models/whisper/processing_whisper.py | 8 ++++ .../models/whisper/tokenization_whisper.py | 47 ++++++++++++++----- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index f22aae143e6bc4..5618fe08bf5a80 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -84,6 +84,14 @@ def batch_decode(self, *args, **kwargs): This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ + + if isinstance(args[0], dict) and 'segments' in args[0]: + segments = args[0].pop('segments') + + kwargs = {"segments": segments, **kwargs} + + args = tuple(args[0]['sequences'].unsqueeze(0)) + return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index f19f218569a293..01d949a1d8ff48 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -558,7 +558,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre ] return "".join(outputs) - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): """ Compute offsets for a given tokenized input @@ -567,6 +567,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. + segments (List[dict], `optional`, defaults to None): + Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -587,7 +589,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02): consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1) last_slice = np.where(timestamp_tokens)[0][0] - for current_slice in consecutive: + for i, current_slice in enumerate(consecutive): sliced_tokens = token_ids[last_slice:current_slice] if len(sliced_tokens) > 1: start_timestamp_position = sliced_tokens[0].item() - timestamp_begin @@ -596,15 +598,28 @@ def _compute_offsets(self, token_ids, time_precision=0.02): sliced_tokens = self._preprocess_token_ids(sliced_tokens) text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - offsets.append( - { - "text": text, - "timestamp": ( - start_timestamp_position * time_precision, - end_timestamp_position * time_precision, - ), - } - ) + + if segments is not None: + + offsets.append( + { + "text": text, + "timestamp": ( + segments[0][i]['start'].item(), + segments[0][i]['end'].item(), + ), + } + ) + else: + offsets.append( + { + "text": text, + "timestamp": ( + start_timestamp_position * time_precision, + end_timestamp_position * time_precision, + ), + } + ) last_slice = current_slice return offsets @@ -713,7 +728,14 @@ def decode( # retrieve offsets if output_offsets: - offsets = self._compute_offsets(token_ids, time_precision=time_precision) + + if "segments" in kwargs: + segments = kwargs['segments'] + else: + segments = None + + offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments ) + return {"text": text, "offsets": offsets} return text @@ -852,6 +874,7 @@ def get_prompt_ids(self, text: str, return_tensors="np"): return batch_encoding["input_ids"] def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int): + if not isinstance(token_ids, list): token_ids = self._convert_to_list(token_ids) From 743cc59a4d455a3efab989bda0d5f43a8ef2f03e Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:21:49 +0100 Subject: [PATCH 02/16] Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 01d949a1d8ff48..243f708196709a 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -567,7 +567,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - segments (List[dict], `optional`, defaults to None): + segments (List[dict], `optional`): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] From 06cebb1c269dc3b70dfc14df284d05c465f6c41a Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:21:55 +0100 Subject: [PATCH 03/16] Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 243f708196709a..d86ef7e912cdb6 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -874,7 +874,6 @@ def get_prompt_ids(self, text: str, return_tensors="np"): return batch_encoding["input_ids"] def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int): - if not isinstance(token_ids, list): token_ids = self._convert_to_list(token_ids) From 362ab6a3cf5bf0f6f02fe2bfe05d1eeedcde79d2 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Wed, 17 Jul 2024 13:15:29 +0200 Subject: [PATCH 04/16] add test --- .../models/whisper/processing_whisper.py | 10 +- .../models/whisper/tokenization_whisper.py | 22 ++--- tests/models/whisper/test_modeling_whisper.py | 97 +++++++++++++++++++ 3 files changed, 113 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index 5618fe08bf5a80..a363b73cdfb178 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -84,14 +84,14 @@ def batch_decode(self, *args, **kwargs): This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ - - if isinstance(args[0], dict) and 'segments' in args[0]: + + if isinstance(args[0], dict) and 'segments' in args[0]: segments = args[0].pop('segments') - + kwargs = {"segments": segments, **kwargs} - + args = tuple(args[0]['sequences'].unsqueeze(0)) - + return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 01d949a1d8ff48..06b9abfe4848ca 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -568,7 +568,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. segments (List[dict], `optional`, defaults to None): - Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. + Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -598,9 +598,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): sliced_tokens = self._preprocess_token_ids(sliced_tokens) text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - - if segments is not None: - + + if segments is not None: + offsets.append( { "text": text, @@ -610,7 +610,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): ), } ) - else: + else: offsets.append( { "text": text, @@ -728,14 +728,14 @@ def decode( # retrieve offsets if output_offsets: - - if "segments" in kwargs: + + if "segments" in kwargs: segments = kwargs['segments'] - else: + else: segments = None - + offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments ) - + return {"text": text, "offsets": offsets} return text @@ -874,7 +874,7 @@ def get_prompt_ids(self, text: str, return_tensors="np"): return batch_encoding["input_ids"] def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int): - + if not isinstance(token_ids, list): token_ids = self._convert_to_list(token_ids) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 5fc66f9a20551d..5a7ab791f1a28e 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2001,6 +2001,103 @@ def test_tiny_timestamp_generation(self): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + @slow + def test_tiny_longform_timestamps_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + + sample = self._load_datasamples(1) + input_speech = np.concatenate(sample * 10) + + input_features = processor(input_speech, return_tensors="pt", truncation=False, sampling_rate=16_000) + input_features = input_features.to(torch_device) + + generated_ids = model.generate(**input_features, return_timestamps=True, return_segments = True) + + EXPECTED_TRANSCRIPT = [ + { + "text": ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', + "offsets": [ + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 0.0, + 6.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 6.0, + 12.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 12.0, + 18.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 18.0, + 24.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 24.0, + 29.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 29.0, + 35.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 35.0, + 41.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 41.0, + 47.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 47.0, + 53.0 + ) + }, + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": ( + 53.0, + 58.20000076293945 + ) + } + ] + } + ] + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow def test_large_timestamp_generation(self): set_seed(0) From 9483e0946df15163f6c8667e296c4522cb1eab49 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Wed, 17 Jul 2024 13:18:42 +0200 Subject: [PATCH 05/16] make style --- .../models/whisper/processing_whisper.py | 6 +- .../models/whisper/tokenization_whisper.py | 13 ++-- tests/models/whisper/test_modeling_whisper.py | 59 +++++-------------- 3 files changed, 22 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index a363b73cdfb178..b09e45fe2bcbf5 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -85,12 +85,12 @@ def batch_decode(self, *args, **kwargs): refer to the docstring of this method for more information. """ - if isinstance(args[0], dict) and 'segments' in args[0]: - segments = args[0].pop('segments') + if isinstance(args[0], dict) and "segments" in args[0]: + segments = args[0].pop("segments") kwargs = {"segments": segments, **kwargs} - args = tuple(args[0]['sequences'].unsqueeze(0)) + args = tuple(args[0]["sequences"].unsqueeze(0)) return self.tokenizer.batch_decode(*args, **kwargs) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index d29fdb6bf55f7a..9dcf892ba08ec2 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -568,7 +568,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. segments (List[dict], `optional`): - Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. + Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -600,13 +600,12 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): text = self._filter_timestamp_ids(text) if segments is not None: - offsets.append( { "text": text, "timestamp": ( - segments[0][i]['start'].item(), - segments[0][i]['end'].item(), + segments[0][i]["start"].item(), + segments[0][i]["end"].item(), ), } ) @@ -728,13 +727,12 @@ def decode( # retrieve offsets if output_offsets: - if "segments" in kwargs: - segments = kwargs['segments'] + segments = kwargs["segments"] else: segments = None - offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments ) + offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) return {"text": text, "offsets": offsets} return text @@ -874,7 +872,6 @@ def get_prompt_ids(self, text: str, return_tensors="np"): return batch_encoding["input_ids"] def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int): - if not isinstance(token_ids, list): token_ids = self._convert_to_list(token_ids) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 5a7ab791f1a28e..688ee4f6656d92 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2001,7 +2001,6 @@ def test_tiny_timestamp_generation(self): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - @slow def test_tiny_longform_timestamps_generation(self): set_seed(0) @@ -2015,83 +2014,53 @@ def test_tiny_longform_timestamps_generation(self): input_features = processor(input_speech, return_tensors="pt", truncation=False, sampling_rate=16_000) input_features = input_features.to(torch_device) - generated_ids = model.generate(**input_features, return_timestamps=True, return_segments = True) + generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True) EXPECTED_TRANSCRIPT = [ { - "text": ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "offsets": [ { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 0.0, - 6.0 - ) + "timestamp": (0.0, 6.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 6.0, - 12.0 - ) + "timestamp": (6.0, 12.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 12.0, - 18.0 - ) + "timestamp": (12.0, 18.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 18.0, - 24.0 - ) + "timestamp": (18.0, 24.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 24.0, - 29.0 - ) + "timestamp": (24.0, 29.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 29.0, - 35.0 - ) + "timestamp": (29.0, 35.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 35.0, - 41.0 - ) + "timestamp": (35.0, 41.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 41.0, - 47.0 - ) + "timestamp": (41.0, 47.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 47.0, - 53.0 - ) + "timestamp": (47.0, 53.0), }, { "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "timestamp": ( - 53.0, - 58.20000076293945 - ) - } - ] + "timestamp": (53.0, 58.20000076293945), + }, + ], } ] From 9d9438a200cf94d77e099b31c7f1df16b9cb7fb5 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Wed, 17 Jul 2024 13:25:48 +0200 Subject: [PATCH 06/16] fix copies --- .../models/clvp/processing_clvp.py | 1 - .../whisper/tokenization_whisper_fast.py | 44 ++++++++++++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/clvp/processing_clvp.py b/src/transformers/models/clvp/processing_clvp.py index 4e015cea1f8475..ebccab89d0fca3 100644 --- a/src/transformers/models/clvp/processing_clvp.py +++ b/src/transformers/models/clvp/processing_clvp.py @@ -73,7 +73,6 @@ def __call__(self, *args, **kwargs): inputs["attention_mask"] = encodings["attention_mask"] return inputs - # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index d1205d1a8ec01b..c1eb1bb76ea71c 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -200,7 +200,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): """ Compute offsets for a given tokenized input @@ -209,6 +209,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. + segments (List[dict], `optional`): + Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -229,7 +231,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02): consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1) last_slice = np.where(timestamp_tokens)[0][0] - for current_slice in consecutive: + for i, current_slice in enumerate(consecutive): sliced_tokens = token_ids[last_slice:current_slice] if len(sliced_tokens) > 1: start_timestamp_position = sliced_tokens[0].item() - timestamp_begin @@ -238,15 +240,27 @@ def _compute_offsets(self, token_ids, time_precision=0.02): sliced_tokens = self._preprocess_token_ids(sliced_tokens) text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - offsets.append( - { - "text": text, - "timestamp": ( - start_timestamp_position * time_precision, - end_timestamp_position * time_precision, - ), - } - ) + + if segments is not None: + offsets.append( + { + "text": text, + "timestamp": ( + segments[0][i]["start"].item(), + segments[0][i]["end"].item(), + ), + } + ) + else: + offsets.append( + { + "text": text, + "timestamp": ( + start_timestamp_position * time_precision, + end_timestamp_position * time_precision, + ), + } + ) last_slice = current_slice return offsets @@ -359,7 +373,13 @@ def decode( # retrieve offsets if output_offsets: - offsets = self._compute_offsets(token_ids, time_precision=time_precision) + if "segments" in kwargs: + segments = kwargs["segments"] + else: + segments = None + + offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) + return {"text": text, "offsets": offsets} return text From 588e765c3cdaa311d84264d355bfae7b77c10a14 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:38:35 +0100 Subject: [PATCH 07/16] Update src/transformers/models/whisper/tokenization_whisper_fast.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper_fast.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index c1eb1bb76ea71c..9229a8addf8dfe 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -373,10 +373,7 @@ def decode( # retrieve offsets if output_offsets: - if "segments" in kwargs: - segments = kwargs["segments"] - else: - segments = None + segments = kwargs.get("segments") offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) From 4f2dc502d9261e9dc26da7b5c57805f38407f2e4 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:38:47 +0100 Subject: [PATCH 08/16] Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 9dcf892ba08ec2..f1eef946d1cbed 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -727,11 +727,7 @@ def decode( # retrieve offsets if output_offsets: - if "segments" in kwargs: - segments = kwargs["segments"] - else: - segments = None - + segments = kwargs.get("segments") offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) return {"text": text, "offsets": offsets} From 21cbe42952635cd7480b2950363deea057dddbd9 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:39:13 +0100 Subject: [PATCH 09/16] Update src/transformers/models/whisper/processing_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/whisper/processing_whisper.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index b09e45fe2bcbf5..96bc1fadb59e40 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -86,10 +86,7 @@ def batch_decode(self, *args, **kwargs): """ if isinstance(args[0], dict) and "segments" in args[0]: - segments = args[0].pop("segments") - - kwargs = {"segments": segments, **kwargs} - + kwargs["segments"] = args[0].pop("segments") args = tuple(args[0]["sequences"].unsqueeze(0)) return self.tokenizer.batch_decode(*args, **kwargs) From e31add36a77696617ad83a40f73dc4e514e11d36 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:39:27 +0100 Subject: [PATCH 10/16] Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index f1eef946d1cbed..aab08d184d456a 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -567,7 +567,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - segments (List[dict], `optional`): + segments (List[dict], *optional*): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] From 5c96e6ceeb97a0b652e814d44f5a4d2b3a4fcdb2 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Thu, 18 Jul 2024 11:05:28 +0200 Subject: [PATCH 11/16] apply review suggestions --- src/transformers/models/whisper/processing_whisper.py | 2 ++ .../models/whisper/tokenization_whisper.py | 10 +++++----- .../models/whisper/tokenization_whisper_fast.py | 10 +++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index 96bc1fadb59e40..3046a9c5f370b5 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -85,6 +85,8 @@ def batch_decode(self, *args, **kwargs): refer to the docstring of this method for more information. """ + # If segments are present in args, we are performing long-form generation and need to return long form timestamps. + # The long-form timestamps are already present in segments and should be passed as kwargs to batch_decode. if isinstance(args[0], dict) and "segments" in args[0]: kwargs["segments"] = args[0].pop("segments") args = tuple(args[0]["sequences"].unsqueeze(0)) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index aab08d184d456a..d380e3e56311ac 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -558,7 +558,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre ] return "".join(outputs) - def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): + def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=None): """ Compute offsets for a given tokenized input @@ -567,7 +567,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - segments (List[dict], *optional*): + longform_timestamps (List[dict], *optional*): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] @@ -599,13 +599,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - if segments is not None: + if longform_timestamps is not None: offsets.append( { "text": text, "timestamp": ( - segments[0][i]["start"].item(), - segments[0][i]["end"].item(), + longform_timestamps[0][i]["start"].item(), + longform_timestamps[0][i]["end"].item(), ), } ) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 9229a8addf8dfe..09498cb9278c51 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -200,7 +200,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets - def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): + def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=None): """ Compute offsets for a given tokenized input @@ -209,7 +209,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - segments (List[dict], `optional`): + longform_timestamps (List[dict], *optional*): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] @@ -241,13 +241,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segments=None): text = self._decode(sliced_tokens) text = self._filter_timestamp_ids(text) - if segments is not None: + if longform_timestamps is not None: offsets.append( { "text": text, "timestamp": ( - segments[0][i]["start"].item(), - segments[0][i]["end"].item(), + longform_timestamps[0][i]["start"].item(), + longform_timestamps[0][i]["end"].item(), ), } ) From 7e2cff932fb7f6559596479e2dfe8e682acb3e9d Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Thu, 18 Jul 2024 11:16:19 +0200 Subject: [PATCH 12/16] fix --- src/transformers/models/whisper/tokenization_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index d380e3e56311ac..9372390392f06b 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -728,7 +728,7 @@ def decode( # retrieve offsets if output_offsets: segments = kwargs.get("segments") - offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) + offsets = self._compute_offsets(token_ids, time_precision=time_precision, longform_timestamps=segments) return {"text": text, "offsets": offsets} return text From b3d1ce51828516031df10056c3eadd47fa393ba6 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Thu, 18 Jul 2024 11:21:01 +0200 Subject: [PATCH 13/16] fix copies --- src/transformers/models/whisper/tokenization_whisper_fast.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 09498cb9278c51..85ad1a783d23e7 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -374,8 +374,7 @@ def decode( # retrieve offsets if output_offsets: segments = kwargs.get("segments") - - offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments) + offsets = self._compute_offsets(token_ids, time_precision=time_precision, longform_timestamps=segments) return {"text": text, "offsets": offsets} return text From fdea6e8ba686953a7c0488de0999dd452e0cf8fc Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Thu, 18 Jul 2024 12:41:08 +0200 Subject: [PATCH 14/16] fix --- src/transformers/models/whisper/processing_whisper.py | 2 +- src/transformers/models/whisper/tokenization_whisper.py | 6 ++++-- .../models/whisper/tokenization_whisper_fast.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index 3046a9c5f370b5..07ece4314b249b 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -88,7 +88,7 @@ def batch_decode(self, *args, **kwargs): # If segments are present in args, we are performing long-form generation and need to return long form timestamps. # The long-form timestamps are already present in segments and should be passed as kwargs to batch_decode. if isinstance(args[0], dict) and "segments" in args[0]: - kwargs["segments"] = args[0].pop("segments") + kwargs["longform_timestamps"] = args[0].pop("segments") args = tuple(args[0]["sequences"].unsqueeze(0)) return self.tokenizer.batch_decode(*args, **kwargs) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 9372390392f06b..6b94619add8b00 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -727,8 +727,10 @@ def decode( # retrieve offsets if output_offsets: - segments = kwargs.get("segments") - offsets = self._compute_offsets(token_ids, time_precision=time_precision, longform_timestamps=segments) + longform_timestamps = kwargs.get("longform_timestamps") + offsets = self._compute_offsets( + token_ids, time_precision=time_precision, longform_timestamps=longform_timestamps + ) return {"text": text, "offsets": offsets} return text diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 85ad1a783d23e7..540056df8bd807 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -373,8 +373,10 @@ def decode( # retrieve offsets if output_offsets: - segments = kwargs.get("segments") - offsets = self._compute_offsets(token_ids, time_precision=time_precision, longform_timestamps=segments) + longform_timestamps = kwargs.get("longform_timestamps") + offsets = self._compute_offsets( + token_ids, time_precision=time_precision, longform_timestamps=longform_timestamps + ) return {"text": text, "offsets": offsets} return text From 1a0c349aa2b6f3f6e2429b694b8a57291cf9fa19 Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Fri, 19 Jul 2024 09:00:12 +0100 Subject: [PATCH 15/16] Update src/transformers/models/whisper/tokenization_whisper_fast.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/whisper/tokenization_whisper_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 540056df8bd807..edc9c778023f3f 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -209,7 +209,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=N List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - longform_timestamps (List[dict], *optional*): + longform_timestamps (`List[dict]`, *optional*): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = [] From 42b8478816afdffcd87652ee3692781804e983b8 Mon Sep 17 00:00:00 2001 From: kamilakesbi Date: Fri, 19 Jul 2024 10:05:49 +0200 Subject: [PATCH 16/16] fix-copies --- src/transformers/models/whisper/tokenization_whisper_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index edc9c778023f3f..540056df8bd807 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -209,7 +209,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02, longform_timestamps=N List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, `optional`, defaults to 0.02): The time ratio to convert from token to time. - longform_timestamps (`List[dict]`, *optional*): + longform_timestamps (List[dict], *optional*): Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids. """ offsets = []