diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index f19f218569a293..85e7dd04d8b249 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1033,7 +1033,7 @@ def new_chunk():
                     chunk["text"] = resolved_text
                     if return_timestamps == "word":
                         chunk["words"] = _collate_word_timestamps(
-                            tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+                            tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
                         )
 
                     chunks.append(chunk)
@@ -1085,7 +1085,7 @@ def new_chunk():
         chunk["text"] = resolved_text
         if return_timestamps == "word":
             chunk["words"] = _collate_word_timestamps(
-                tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+                tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
             )
         chunks.append(chunk)
 
@@ -1217,12 +1217,16 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
     return total_sequence, []
 
 
-def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
+def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language):
     words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
+
+    optional_language_field = {"language": language} if return_language else {}
+
     timings = [
         {
             "text": word,
             "timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
+            **optional_language_field,
         }
         for word, indices in zip(words, token_indices)
     ]
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 11bbde4143f7e8..82c5580f0ea2cc 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -322,7 +322,6 @@ def test_torch_large_with_input_features(self):
 
     @slow
     @require_torch
-    @slow
     def test_return_timestamps_in_preprocess(self):
         pipe = pipeline(
             task="automatic-speech-recognition",
@@ -332,10 +331,10 @@ def test_return_timestamps_in_preprocess(self):
         )
         data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
         sample = next(iter(data))
-        pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe")
 
         res = pipe(sample["audio"]["array"])
         self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."})
+
         res = pipe(sample["audio"]["array"], return_timestamps=True)
         self.assertEqual(
             res,
@@ -344,9 +343,8 @@ def test_return_timestamps_in_preprocess(self):
                 "chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
             },
         )
-        pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
-        res = pipe(sample["audio"]["array"], return_timestamps="word")
 
+        res = pipe(sample["audio"]["array"], return_timestamps="word")
         # fmt: off
         self.assertEqual(
             res,
@@ -366,6 +364,63 @@ def test_return_timestamps_in_preprocess(self):
         )
         # fmt: on
 
+    @slow
+    @require_torch
+    def test_return_timestamps_and_language_in_preprocess(self):
+        pipe = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            chunk_length_s=8,
+            stride_length_s=1,
+            return_language=True,
+        )
+        data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
+        sample = next(iter(data))
+
+        res = pipe(sample["audio"]["array"])
+        self.assertEqual(
+            res,
+            {
+                "text": " Conquered returned to its place amidst the tents.",
+                "chunks": [{"language": "english", "text": " Conquered returned to its place amidst the tents."}],
+            },
+        )
+
+        res = pipe(sample["audio"]["array"], return_timestamps=True)
+        self.assertEqual(
+            res,
+            {
+                "text": " Conquered returned to its place amidst the tents.",
+                "chunks": [
+                    {
+                        "timestamp": (0.0, 3.36),
+                        "language": "english",
+                        "text": " Conquered returned to its place amidst the tents.",
+                    }
+                ],
+            },
+        )
+
+        res = pipe(sample["audio"]["array"], return_timestamps="word")
+        # fmt: off
+        self.assertEqual(
+            res,
+            {
+                'text': ' Conquered returned to its place amidst the tents.',
+                'chunks': [
+                    {"language": "english", 'text': ' Conquered', 'timestamp': (0.5, 1.2)},
+                    {"language": "english", 'text': ' returned', 'timestamp': (1.2, 1.64)},
+                    {"language": "english", 'text': ' to', 'timestamp': (1.64, 1.84)},
+                    {"language": "english", 'text': ' its', 'timestamp': (1.84, 2.02)},
+                    {"language": "english", 'text': ' place', 'timestamp': (2.02, 2.28)},
+                    {"language": "english", 'text': ' amidst', 'timestamp': (2.28, 2.8)},
+                    {"language": "english", 'text': ' the', 'timestamp': (2.8, 2.98)},
+                    {"language": "english", 'text': ' tents.', 'timestamp': (2.98, 3.48)},
+                ],
+            },
+        )
+        # fmt: on
+
     @slow
     @require_torch
     def test_return_timestamps_in_preprocess_longform(self):
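
A minimal usage sketch of the behavior this patch enables, mirroring the new `test_return_timestamps_and_language_in_preprocess` test above (the checkpoint, dataset, and chunking arguments are copied from that test, not requirements of the feature): with `return_language=True` on the pipeline, each word-level chunk now carries a `"language"` key alongside `"text"` and `"timestamp"`.

```python
from datasets import load_dataset
from transformers import pipeline

# Same setup as the new test: whisper-tiny with chunked inference and language reporting enabled.
pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=8,
    stride_length_s=1,
    return_language=True,
)

# One LibriSpeech test sample, streamed; any 16 kHz mono float array works as input.
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
sample = next(iter(data))

out = pipe(sample["audio"]["array"], return_timestamps="word")
# Each word-level chunk now includes the detected language, e.g.:
# {"language": "english", "text": " Conquered", "timestamp": (0.5, 1.2)}
print(out["chunks"][0])
```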