Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incorrect Whisper long-form decoding timestamps #32003

Merged
merged 17 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/transformers/models/clvp/processing_clvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def __call__(self, *args, **kwargs):
inputs["attention_mask"] = encodings["attention_mask"]
return inputs

# Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/models/whisper/processing_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def batch_decode(self, *args, **kwargs):
This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""

if isinstance(args[0], dict) and "segments" in args[0]:
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved
segments = args[0].pop("segments")

kwargs = {"segments": segments, **kwargs}

args = tuple(args[0]["sequences"].unsqueeze(0))
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved

return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
Expand Down
44 changes: 32 additions & 12 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
]
return "".join(outputs)

def _compute_offsets(self, token_ids, time_precision=0.02):
def _compute_offsets(self, token_ids, time_precision=0.02, segments=None):
"""
Compute offsets for a given tokenized input

Expand All @@ -567,6 +567,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
segments (List[dict], `optional`):
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved
Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
Expand All @@ -587,7 +589,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)

last_slice = np.where(timestamp_tokens)[0][0]
for current_slice in consecutive:
for i, current_slice in enumerate(consecutive):
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
Expand All @@ -596,15 +598,27 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
text = self._filter_timestamp_ids(text)
offsets.append(
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)

if segments is not None:
offsets.append(
{
"text": text,
"timestamp": (
segments[0][i]["start"].item(),
segments[0][i]["end"].item(),
),
}
)
else:
offsets.append(
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)
last_slice = current_slice

return offsets
Expand Down Expand Up @@ -713,7 +727,13 @@ def decode(

# retrieve offsets
if output_offsets:
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
if "segments" in kwargs:
segments = kwargs["segments"]
else:
segments = None

kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved
offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments)

return {"text": text, "offsets": offsets}
return text

Expand Down
44 changes: 32 additions & 12 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre
return "".join(outputs)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
def _compute_offsets(self, token_ids, time_precision=0.02):
def _compute_offsets(self, token_ids, time_precision=0.02, segments=None):
"""
Compute offsets for a given tokenized input

Expand All @@ -209,6 +209,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
List of tokenized input ids. Can be obtained using the `__call__` method.
time_precision (`float`, `optional`, defaults to 0.02):
The time ratio to convert from token to time.
segments (List[dict], `optional`):
Timestamps obtained using long form generation in Whisper, to be used to replace predicted timestamps in token_ids.
"""
offsets = []
# ensure torch tensor of token ids is placed on cpu
Expand All @@ -229,7 +231,7 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)

last_slice = np.where(timestamp_tokens)[0][0]
for current_slice in consecutive:
for i, current_slice in enumerate(consecutive):
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
Expand All @@ -238,15 +240,27 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
text = self._filter_timestamp_ids(text)
offsets.append(
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)

if segments is not None:
offsets.append(
{
"text": text,
"timestamp": (
segments[0][i]["start"].item(),
segments[0][i]["end"].item(),
),
}
)
else:
offsets.append(
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
),
}
)
last_slice = current_slice

return offsets
Expand Down Expand Up @@ -359,7 +373,13 @@ def decode(

# retrieve offsets
if output_offsets:
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
if "segments" in kwargs:
segments = kwargs["segments"]
else:
segments = None
kamilakesbi marked this conversation as resolved.
Show resolved Hide resolved

offsets = self._compute_offsets(token_ids, time_precision=time_precision, segments=segments)

return {"text": text, "offsets": offsets}
return text

Expand Down
66 changes: 66 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,72 @@ def test_tiny_timestamp_generation(self):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_tiny_longform_timestamps_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)

sample = self._load_datasamples(1)
input_speech = np.concatenate(sample * 10)

input_features = processor(input_speech, return_tensors="pt", truncation=False, sampling_rate=16_000)
input_features = input_features.to(torch_device)

generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)

EXPECTED_TRANSCRIPT = [
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"offsets": [
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (0.0, 6.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (6.0, 12.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (12.0, 18.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (18.0, 24.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (24.0, 29.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (29.0, 35.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (35.0, 41.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (41.0, 47.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (47.0, 53.0),
},
{
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"timestamp": (53.0, 58.20000076293945),
},
],
}
]

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)
Expand Down
Loading