
Generate: remove deprecated code due to Cache and cache_position being default #31898

Merged

gante merged 9 commits into huggingface:main from short_prep_inputs on Jul 14, 2024

Conversation

@gante (Member) commented Jul 10, 2024

What does this PR do?

  1. Simplifies prepare_inputs_for_generation on models using Cache (a rough sketch follows this list)
  2. Removes the unused _reorder_cache function on models using Cache
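
For context, the simplified hook has roughly the shape sketched below. This is reconstructed from the diff snippets quoted in the review threads further down; the exact merged body varies per model, and the `model_inputs` assembly (e.g. the omitted `position_ids` handling, the first-step `inputs_embeds` check) is an assumption rather than verbatim code.

```python
# Rough sketch of the post-PR prepare_inputs_for_generation for a Cache-based model.
# Reconstructed from the diff snippets quoted below; position_ids / attention-mask
# handling is omitted and the exact merged body varies per model.
def prepare_inputs_for_generation(
    self, input_ids, cache_position, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    # With a cache, keep only the tokens that are not yet cached
    if past_key_values is not None:
        if inputs_embeds is not None:  # Exception 1: input_ids may be missing entries
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case
            input_ids = input_ids[:, cache_position]

    # Assumption for this sketch: inputs_embeds are only used on the first generation step
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
```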

Slow tests run (and passing / same failures as in main):

  • (Cache integration tests) RUN_SLOW=1 py.test -vv tests/utils/test_cache_utils.py
  • (generate integration tests) RUN_SLOW=1 py.test -vv tests/generation/test_utils.py
  • (reference LLM model, llama) RUN_SLOW=1 py.test -vv tests/models/llama/test_modeling_llama.py
  • (reference MoE model, mixtral) RUN_SLOW=1 py.test -vv tests/models/mixtral/test_modeling_mixtral.py
  • Slow tests for ALL other models in the diff

@@ -689,13 +689,16 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

if (

gante (Member, Author):
TL;DR cache_position now always exists, regardless of use_cache
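
For illustration, the unconditional bookkeeping looks roughly like the toy snippet below; the dict key matches the quoted code, while the concrete numbers and the `num_new_tokens` value are invented for the example.

```python
import torch

# Toy illustration: cache_position is always present in model_kwargs now, so the
# update needs no use_cache guard. Numbers are invented; num_new_tokens is 1 for
# plain autoregressive decoding.
model_kwargs = {"cache_position": torch.arange(5)}  # a 5-token prompt was just processed
num_new_tokens = 1

model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
print(model_kwargs["cache_position"])  # tensor([5]) -> position of the next token to decode
```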

Comment on lines -1253 to -1154
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    attention_mask = attention_mask[:, -max_cache_length:]

gante (Member, Author):
_update_causal_mask handles the corner case this was originally meant to cover
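
As a toy illustration of the corner case the removed branch covered (values invented): with a bounded cache of `max_cache_length`, once the cached length plus the new tokens would exceed it, only the last `max_cache_length` mask columns still matter, and this cropping now happens inside `_update_causal_mask` instead.

```python
import torch

# Toy example of the removed branch's intent. Values are invented for illustration.
max_cache_length = 8                 # e.g. a bounded (static) cache with 8 slots
cache_length = 8                     # cache is already full
input_ids = torch.tensor([[42]])     # one new token arrives
attention_mask = torch.ones(1, 9)    # 8 cached positions + 1 new one

if cache_length + input_ids.shape[1] > max_cache_length:
    attention_mask = attention_mask[:, -max_cache_length:]

print(attention_mask.shape)  # torch.Size([1, 8]) -> only the last 8 positions are kept
```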

ArthurZucker (Collaborator):
💘

Comment on lines -1295 to -1197
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    reordered_past = ()
    for layer_past in past_key_values:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
        )
    return reordered_past

gante (Member, Author):
This was used with legacy caches only
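
For reference, a toy reproduction of what this static method did on a legacy tuple-of-tuples cache; Cache objects carry equivalent reordering themselves, which is why the per-model copy could go. Shapes and values below are invented.

```python
import torch

# Toy reproduction of the removed helper on a legacy tuple cache (shapes invented).
batch, heads, seq, head_dim = 4, 2, 3, 8
legacy_cache = tuple(
    (torch.randn(batch, heads, seq, head_dim), torch.randn(batch, heads, seq, head_dim))
    for _ in range(2)  # two layers of (key, value)
)
beam_idx = torch.tensor([1, 1, 0, 3])  # beams 0 and 1 both continue from old beam 1

reordered = tuple(
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
    for layer_past in legacy_cache
)
# Beam 0 of the reordered cache is old beam 1's cache
assert torch.equal(reordered[0][0][0], legacy_cache[0][0][1])
```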

ArthurZucker (Collaborator):
nice!

if inputs_embeds is not None:  # Exception 1
    input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
    input_ids = input_ids[:, cache_position]

gante (Member, Author):
This line holds the logic for end-to-end generate compilation. The other two lines are exceptions to ensure we don't lose BC. The comment at the top should be MUCH clearer now.
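
A toy walk-through of the default case above (numbers invented): during decoding `input_ids` holds the whole sequence so far, and `cache_position` selects only the positions that are not yet in the cache.

```python
import torch

# Toy walk-through of the default case above (numbers invented): 4 tokens are
# already cached, so cache_position points at the single new position and only
# that column of input_ids is fed to the model.
input_ids = torch.tensor([[11, 12, 13, 14, 15]])  # full sequence so far
cache_position = torch.tensor([4])                # only position 4 is new

if input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]

print(input_ids)  # tensor([[15]]) -> just the uncached token goes through the model
```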

-        cache_position=None,
-        use_cache=True,
-        **kwargs,
+        self, input_ids, cache_position, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs

gante (Member, Author):
Note: cache_position is now a mandatory input
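
A sketch of how the tensor is typically built when driving the model manually; `generate()` constructs it internally, and the variable names here are purely illustrative.

```python
import torch

# cache_position enumerates the absolute positions of the tokens in this forward
# pass. Variable names are illustrative; generate() builds this tensor internally.
past_length = 6   # tokens already stored in the cache
new_tokens = 1    # tokens being fed in this forward pass
cache_position = torch.arange(past_length, past_length + new_tokens)
print(cache_position)  # tensor([6])
```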

@gante gante requested a review from ArthurZucker July 10, 2024 18:58

@ArthurZucker (Collaborator) left a comment
SOOOOO much cleaner

src/transformers/models/llama/modeling_llama.py (outdated review thread, resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante changed the title from "Generate: remove deprecated code due to DynamicCache and cache_position being default" to "Generate: remove deprecated code due to Cache and cache_position being default" on Jul 11, 2024
@gante gante mentioned this pull request Jul 11, 2024
@gante gante marked this pull request as ready for review July 11, 2024 19:38
@gante gante requested a review from ArthurZucker July 11, 2024 19:38

@gante (Member, Author) commented Jul 11, 2024

@ArthurZucker ready for a final check :) (tons of slow tests ran on my end, should be safe to merge)

@ArthurZucker (Collaborator) left a comment

LGTM then! I trust our tests for this, would be nice to see the results of the full suite! cc @ydshieh

src/transformers/generation/utils.py (outdated review thread, resolved)

Collaborator: nice!

@ydshieh (Collaborator) commented Jul 12, 2024

> LGTM then! I trust our tests for this, would be nice to see the results of the full suite! cc @ydshieh

Do you want me to trigger a full (GitHub Action) CI for this PR during this weekend (before merge)?

@ArthurZucker (Collaborator)

yeah, would be nice!

@ydshieh (Collaborator) commented Jul 12, 2024

OK will do (once today's CI run is over)

@gante (Member, Author) commented Jul 12, 2024

@ydshieh please ping me when the run is over 🤗

btw, there are MANY broken tests in the list of models changed in this PR (on main), mostly SDPA, FA2, and integration tests :o I should work on it 🤔

@ydshieh (Collaborator) commented Jul 12, 2024

@gante

The run (triggered for this PR) is likely to be over tomorrow morning (if I trigger it this evening). I will let you know in any case.

@ydshieh (Collaborator) commented Jul 12, 2024

CI is running here

@gante (Member, Author) commented Jul 14, 2024

slow CI looks good (same issues as in main), merging 🤗

@gante gante merged commit 739a631 into huggingface:main Jul 14, 2024
23 checks passed
@gante gante deleted the short_prep_inputs branch July 14, 2024 14:17
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
…being default (huggingface#31898)

* tmp commit

* shorter

* nit

* explicit kwargs

* propagate changes

* mass propagation with a few manual touches (let's see how CI behaves)

* fix cacheless case

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* make fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024
itazap pushed a commit that referenced this pull request Jul 25, 2024