
Whisper: move to tensor cpu before converting to np array at decode time #31954

Merged · 2 commits · Jul 14, 2024

Conversation

@gante (Member) commented on Jul 14, 2024:

What does this PR do?

Follow-up to #27818.

`pytest --doctest-modules src/transformers/models/whisper/generation_whisper.py -vv` started failing on `main` because of the PR above.

In a nutshell: when Whisper runs on GPU, the generated tensors live on the GPU as well. The new decoding code called `token_ids.numpy()`, which fails when the `token_ids` tensor is on GPU. This PR moves the tensor to the CPU before the NumPy conversion :)
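For illustration only (not part of the PR's diff), a minimal sketch of the failure and the fix; the toy token ids are made up, and the GPU branch only runs if CUDA is available:

```python
import torch

token_ids = torch.tensor([[50258, 50259, 50359]])  # made-up Whisper-style ids

if torch.cuda.is_available():
    token_ids = token_ids.to("cuda")
    # token_ids.numpy() raises here:
    # TypeError: can't convert cuda:0 device type tensor to numpy.
    # Use Tensor.cpu() to copy the tensor to host memory first.

# The fix: .cpu() is a no-op for CPU tensors and copies GPU tensors to host memory,
# so the conversion works on both devices.
token_ids_np = token_ids.cpu().numpy()
print(type(token_ids_np), token_ids_np.shape)
```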

cc @sanchit-gandhi

@gante requested a review from @amyeroberts on July 14, 2024 at 14:00.
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment:

Thanks for fixing!

Just a question about the properties of `token_ids`:

```diff
-            token_ids = token_ids.numpy()
+            if hasattr(token_ids, "numpy"):
+                if "torch" in str(type(token_ids)):
+                    token_ids = token_ids.cpu().numpy()
```
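For context, a hedged sketch of a framework-agnostic conversion helper built around the check in the hunk above; the helper name `_to_list` and the non-torch fallback are assumptions for illustration, not code from this PR:

```python
import numpy as np

def _to_list(token_ids):
    """Hypothetical helper: normalize token ids to a plain Python list."""
    if hasattr(token_ids, "numpy"):
        # torch tensors may live on GPU, so move them to CPU before converting;
        # other frameworks' tensors (e.g. TensorFlow) convert directly.
        if "torch" in str(type(token_ids)):
            token_ids = token_ids.cpu().numpy()
        else:
            token_ids = token_ids.numpy()
    if isinstance(token_ids, np.ndarray):
        token_ids = token_ids.tolist()
    return token_ids
```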
@amyeroberts (Collaborator):

Following from this - will `token_ids` ever have a grad? In that case, this will also fail on the `cpu` call.

@gante (Member, Author):

`token_ids`, the output of `generate`, will not have gradients :) `generate` is decorated with `@no_grad`.
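A small illustration (not from the PR) of why the absence of gradients matters for the conversion; it only relies on standard PyTorch behaviour:

```python
import torch

x = torch.ones(2, requires_grad=True)
y = x * 2
# y.numpy() would raise:
# RuntimeError: Can't call numpy() on Tensor that requires grad.
# Use tensor.detach().numpy() instead.

with torch.no_grad():    # generate() runs under no_grad, so its outputs carry no grad
    z = x * 2

print(z.requires_grad)   # False
print(z.cpu().numpy())   # safe: no detach() needed
```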

@gante merged commit a5c642f into huggingface:main on Jul 14, 2024. 19 checks passed.
@gante deleted the fix_whisper_doctest branch on July 14, 2024 at 15:39.
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024