
Generate: remove deprecated public decoding functions and streamline logic 🧼 #29956

Merged · 15 commits merged into huggingface:main from v_4_41_decoding_functions · May 1, 2024

Conversation

gante (Member) commented Mar 29, 2024

What does this PR do?

🧼 🧼 🧼

Support for calling the internal decoding functions directly as part of our public API was scheduled for removal in v4.41 (the next release). The motivation for the removal was flexibility and conciseness: exposing multiple public interfaces for the same functionality forced us to duplicate logic in many places, and the duplication grew every time we added a new decoding method.

Due to this removal from the public API, a few things were changed or removed as a logical consequence:

  1. No more documentation/examples in the internal decoding functions. We have a page with examples showing how to call the internal decoding methods from `generate`;
  2. The arguments to the decoding functions are no longer optional -- the decoding functions are now called exclusively from `generate`. As such, we can remove a lot of boilerplate (`x = x if x is not None else self.generation_config.x`);
  3. No more tests for the internal decoding methods -- already removed in this PR.
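To picture the boilerplate described in point 2, here is a minimal sketch (hypothetical class and method names, not the library's actual code) of the pattern the PR removes: every decoding method used to default every optional argument from the generation config, and making the arguments mandatory deletes all of that defaulting.

```python
class GenerationConfig:
    """Hypothetical stand-in for transformers' GenerationConfig."""
    def __init__(self, max_length=20, pad_token_id=0):
        self.max_length = max_length
        self.pad_token_id = pad_token_id

class Before:
    def __init__(self):
        self.generation_config = GenerationConfig()

    def _greedy_search(self, max_length=None, pad_token_id=None):
        # This defaulting pattern was repeated for every argument of every
        # decoding method -- the boilerplate the PR deletes.
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        return max_length, pad_token_id

class After(Before):
    def _greedy_search(self, max_length, pad_token_id):
        # Arguments are now mandatory: only `generate` calls this method,
        # and it always resolves them from the config beforehand.
        return max_length, pad_token_id
```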

Tests ran locally:

  • generate doctests (`pytest --doctest-modules src/transformers/generation -vv`)
  • generate integration tests (`RUN_SLOW=1 py.test tests/generation/test_utils.py -vv`)
  • cache integration tests (`RUN_SLOW=1 py.test tests/test_cache_utils.py -vv`) -- same failures as in main
  • llama slow tests (`RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv`)
  • whisper slow tests (`RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv`) -- same failures as in main

@gante gante changed the title Generate: remove deprecated public decoding functions and streamline logic Generate: remove deprecated public decoding functions and streamline logic 🧼 Mar 29, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante marked this pull request as ready for review March 29, 2024 17:01
gante (Member Author) commented Mar 29, 2024

Tests on all models are passing 🙌 (the failing pipeline test seems unrelated, and passes locally on my end)

zucchini-nlp (Member) left a comment:

Yaaay, thanks for clean-up! Looks so much nicer

@gante gante force-pushed the v_4_41_decoding_functions branch from e3c994f to db8cc74 Compare April 23, 2024 09:44
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
gante (Member Author):

functions not public -> no docs -> remove link to the docs

Comment on lines +496 to +497
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
gante (Member Author):

(see comment on L1434)

Comment on lines -1434 to -1479
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
gante (Member Author):

Instead of pulling information from generation_config to pass around through model_kwargs, let's use generation_config directly.

A single object to hold all generation parameterization.
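The difference between the two plumbing styles can be sketched as follows (a toy stand-in for the real GenerationConfig, not the library's implementation): previously the flags were copied out of the config into `model_kwargs` and read back later; now the config object itself travels with the call.

```python
from dataclasses import dataclass

@dataclass
class GenConfig:
    """Hypothetical stand-in for transformers' GenerationConfig."""
    output_attentions: bool = False
    output_hidden_states: bool = True

def read_flag_before(model_kwargs):
    # Old style: flags were copied from the config into model_kwargs,
    # passed around, and fished back out by name.
    return model_kwargs["output_hidden_states"]

def read_flag_after(generation_config):
    # New style: read the flag from the single config object directly.
    return generation_config.output_hidden_states

cfg = GenConfig()
model_kwargs = {
    "output_attentions": cfg.output_attentions,
    "output_hidden_states": cfg.output_hidden_states,
}
```

Both styles yield the same value; the second simply avoids the intermediate copy and the chance of the two drifting apart.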

Comment on lines 1581 to 1582
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
gante (Member Author):

These were being passed through model_kwargs before

top_k: Optional[int] = 1,
penalty_alpha: Optional[float] = 0,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
gante (Member Author):

This one was not needed -- contrastive search does not sample
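An illustrative sketch of why (names and functions here are hypothetical, not the library's code): a logits_warper reshapes the distribution before a token is *drawn*, so it only matters for sampling. A method that always ranks candidates deterministically, as contrastive search does, has no use for one.

```python
import math
import random

def temperature_warper(logits, temperature=0.7):
    # Warpers reshape the distribution before sampling (temperature shown here).
    return [x / temperature for x in logits]

def sample_next(logits, warper):
    # Sampling path: warp, normalize, then draw a token index.
    warped = warper(logits)
    total = sum(math.exp(x) for x in warped)
    weights = [math.exp(x) / total for x in warped]
    return random.choices(range(len(weights)), weights=weights)[0]

def deterministic_next(logits):
    # Deterministic path: no draw, no distribution reshaping --
    # a logits_warper argument here would be dead weight.
    return max(range(len(logits)), key=lambda i: logits[i])
```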

@@ -1674,6 +1680,9 @@ def generate(
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
gante (Member Author):

In most decoding methods, eos_token_id doesn't need to be passed -- it was only used when the decoding method was called directly and stopping_criteria was not passed.

However, the beam methods still need it.
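The mechanism can be sketched like this (a hedged illustration with made-up helper names, not the actual transformers stopping-criteria API): generate builds an EOS-based stopping criterion from the config, so the decoding loop checks the criteria list instead of receiving eos_token_id itself.

```python
def make_eos_criterion(eos_token_id):
    # Wraps the EOS check as a stopping criterion; in the real library this
    # role is played by a StoppingCriteria object built inside generate.
    def should_stop(last_token_id):
        return last_token_id == eos_token_id
    return should_stop

def decoding_loop(next_tokens, stopping_criteria):
    # The loop only consults the criteria list -- no eos_token_id argument.
    generated = []
    for token in next_tokens:
        generated.append(token)
        if any(criterion(token) for criterion in stopping_criteria):
            break
    return generated

criteria = [make_eos_criterion(2)]
```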

@@ -1945,69 +1940,9 @@ def _contrastive_search(
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.

Examples:
gante (Member Author):

function not public -> let's remove the example (preparing its inputs will be more challenging now, as we no longer have API guarantees)

Comment on lines -2261 to -2299
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
gante (Member Author):

In the main generate body, we set pad_token_id to eos_token_id in this situation -- this exception will never be reached
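A minimal sketch of that defaulting (function name is illustrative): because a missing pad_token_id falls back to eos_token_id in the main generate body, the state "eos defined but pad undefined" can no longer reach the removed check.

```python
def resolve_pad_token_id(pad_token_id, eos_token_id):
    # Mirrors the fallback described above: if only eos_token_id is defined,
    # pad_token_id inherits its value, making the downstream check unreachable.
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    return pad_token_id
```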

gante (Member Author) commented Apr 23, 2024

ping @ArthurZucker -- ready for review :)

ArthurZucker (Collaborator) left a comment:

Late review but very nice cleanup sir! 🤗

Comment on lines 2256 to 2261
pad_token_id: Optional[int],
output_attentions: bool,
output_hidden_states: bool,
output_scores: bool,
output_logits: bool,
return_dict_in_generate: bool,
ArthurZucker (Collaborator):

Theoretically, some of these could be taken from the config / the generation config if it inherits them. But that's a nit.

gante (Member Author):

Yes! That's a good idea

@gante gante force-pushed the v_4_41_decoding_functions branch from db8cc74 to a8b6e58 Compare May 1, 2024 13:42
gante (Member Author) commented May 1, 2024

Reran slow tests locally, all seems good 👍

@gante gante merged commit d57ffb4 into huggingface:main May 1, 2024
23 checks passed
@gante gante deleted the v_4_41_decoding_functions branch May 1, 2024 16:38
itazap pushed a commit that referenced this pull request May 14, 2024
4 participants