Transformers 4.36 use_cache issue #28056

Closed

dakinggg opened this issue Dec 15, 2023 · 34 comments

@dakinggg
Contributor

System Info

  • transformers version: 4.36.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.25.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Sorry that I don't really have a minimal reproducer here, as I'm using a different training framework, but I still think this might be useful for you.

Running training on Llama 2 7B with activation checkpointing has some issues in 4.36. Compared to training with 4.35.2:

  • if using flash attention, training produces higher loss, is slower, and uses more memory
  • if not using flash attention, crashes with ValueError: Attention mask should be of size (2, 1, 4096, 8192), but is torch.Size([2, 1, 4096, 4096])

If I explicitly set use_cache=False (shouldn't have any impact during training because there is no cache), results with 4.36 are similar to 4.35.2.

Expected behavior

No regression from 4.35.2 -> 4.36.

@ArthurZucker
Collaborator

Thanks, pinging @gante as well since he worked on the cache refactoring; let's keep this in mind.

@younesbelkada
Contributor

Hi @dakinggg
Thanks very much for reporting. I believe one of your issues (or maybe both of them!) would be solved by #28031.
Can you try with transformers main?

@dakinggg
Contributor Author

Hey @younesbelkada unfortunately I don't think that fix will work for me, as I use a different training framework to handle activation checkpointing.

It'd be great to understand and fix the root cause so that transformers models are fully usable with raw pytorch. Thanks as always for the quick responses!

@younesbelkada
Contributor

Thanks @dakinggg, ok, sounds great. I'll spend some time understanding the root cause and why it used to fail on transformers main, and provide an update here!

@younesbelkada
Contributor

younesbelkada commented Dec 15, 2023

Hi @dakinggg

I had a deeper look; consider the snippet below:

import torch
from torch.optim import Adam
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

MODEL_ID = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer("hello world what's up", return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
print(inputs)

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", attn_implementation="eager", torch_dtype=torch.float16)
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=['q_proj', 'v_proj'], inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

optimizer = Adam(model.parameters(), lr=1e-5)
model.train()

for i in range(10):
  outputs = model(labels=inputs['input_ids'], **inputs)
  loss = outputs.loss
  print(loss)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

In case the fix in #28031 is not applied, here is what happens (step by step):

1- outputs = model(labels=inputs['input_ids'], **inputs) will work perfectly fine because:
1.1- use_cache is, in most cases, set to True in the model config, so it passes this logic:

and the model will create a non-None past_key_values.
1.2- use_cache is then force-set to False here, but it is too late because past_key_values has already been created above.
1.3- Since past_key_values is set to a non-None value, it passes this line as well,
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
therefore populating past_key_value for each layer. Note that at this point the attention mask has shape (batch_size, 1, seq_len, seq_len), and each layer's cache already holds key/value states for all seq_len tokens.

Once all past key values are populated, the script calls loss.backward(), which fails because:
2- every module's forward is called again
2.1- when the attention layers are called again, past_key_value is still non-None from the previous pass, so this line runs:

kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
leading to kv_seq_len being set to 2*seq_len
2.2- A ValueError is raised here, since the shapes no longer match: https://github.com/huggingface/transformers/blob/2788f8d8d5f9cee2fe33a9292b0f3570bd566a6d/src/transformers/models/llama/modeling_llama.py#L714C13-L714C69
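The two passes above can be mimicked with a toy, stdlib-only model. This is a sketch under assumptions: `ToyCache` and `toy_attention` are hypothetical stand-ins (not the transformers `Cache` or `LlamaAttention` classes), shapes are plain tuples, and the checkpoint recompute is modelled as calling the attention function a second time with the same inputs and the same cache object.

```python
# Toy model of the failure path; ToyCache and toy_attention are hypothetical
# stand-ins, not the transformers Cache or LlamaAttention code.


class ToyCache:
    """Tracks how many tokens each layer has cached, like a dynamic KV cache."""

    def __init__(self):
        self.lengths = {}

    def get_usable_length(self, layer_idx):
        # Tokens already cached for this layer (0 if nothing is cached yet).
        return self.lengths.get(layer_idx, 0)

    def update(self, new_tokens, layer_idx):
        # Append the new keys/values for this layer.
        self.lengths[layer_idx] = self.get_usable_length(layer_idx) + new_tokens


def toy_attention(q_len, mask_shape, cache, layer_idx=0):
    """Mimics the eager attention path: grow kv_seq_len, then check the mask."""
    bsz = mask_shape[0]
    kv_seq_len = q_len
    if cache is not None:
        kv_seq_len += cache.get_usable_length(layer_idx)
        cache.update(q_len, layer_idx)
    expected = (bsz, 1, q_len, kv_seq_len)
    if mask_shape != expected:
        raise ValueError(
            f"Attention mask should be of size {expected}, but is {mask_shape}"
        )
    return kv_seq_len


batch, seq_len = 2, 4096
cache = ToyCache()  # non-None because config.use_cache defaulted to True
mask = (batch, 1, seq_len, seq_len)  # built while the cache was still empty

print(toy_attention(seq_len, mask, cache))  # first forward: 4096

try:
    # Recompute during backward: same inputs, but the cache now holds seq_len tokens.
    toy_attention(seq_len, mask, cache)
except ValueError as err:
    print(err)  # Attention mask should be of size (2, 1, 4096, 8192), but is (2, 1, 4096, 4096)
```

Passing `cache=None` (the use_cache=False case) makes both calls return 4096, which matches the observation that explicitly disabling the cache avoids the crash.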

I don't fully understand what happens under the hood when one uses torch's gradient checkpointing (GC), but the fix I proposed in #28031 circumvents this issue by making sure no dummy past_key_values are created when we are training with gradient checkpointing. Hence, force-setting use_cache to False above the line here:

fixes the issue, as we always did before the cache refactor.

The proposed fix works for peft but should be universal across training frameworks, unless you patch the Llama/Mistral modeling classes with other classes, in which case you should apply the same patch there as well.

Let me know if anything is unclear!
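A simplified view of why the order of the two checks described above matters. `forward_buggy` and `forward_fixed` are hypothetical functions, an empty list stands in for an empty cache object, and the real decoder forward does much more; only the position of the gradient-checkpointing guard relative to cache creation is modelled.

```python
# Hypothetical simplification of the decoder-level control flow; the function
# names and arguments are illustrative, not the transformers API.


def forward_buggy(use_cache, gradient_checkpointing, training):
    """The cache is created before the GC/training guard runs."""
    past_key_values = None
    if use_cache:
        past_key_values = []  # stands in for an empty cache object
    if gradient_checkpointing and training and use_cache:
        use_cache = False  # too late: the cache already exists
    return use_cache, past_key_values is not None


def forward_fixed(use_cache, gradient_checkpointing, training):
    """The guard runs first, so no cache is created while training with GC."""
    if gradient_checkpointing and training and use_cache:
        use_cache = False
    past_key_values = [] if use_cache else None
    return use_cache, past_key_values is not None


# config.use_cache defaults to True, so training with GC enabled looks like:
print(forward_buggy(True, gradient_checkpointing=True, training=True))  # (False, True)
print(forward_fixed(True, gradient_checkpointing=True, training=True))  # (False, False)
```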

@dakinggg
Contributor Author

That mostly makes sense... I didn't quite understand why it wasn't an issue in previous versions, though. Shouldn't we just never compute past kv during training, regardless of gradient checkpointing? Even if it worked, it's not good to create past kv when we're not generating, as it uses significant extra memory.

As to why the model's forward gets called again: when you use activation checkpointing, you don't save all of the activations for the backward pass, only some of them, and then you recompute the rest.
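To make the recompute step concrete, here is a minimal pure-Python sketch of activation checkpointing. The `checkpoint` helper is hypothetical (it is not `torch.utils.checkpoint`) and lists stand in for tensors. Recomputation is only safe if the forward depends solely on its saved inputs; a cache that the first forward mutated in place is hidden state, so the recompute would see different data.

```python
# Pure-Python sketch of the idea; `checkpoint` is a hypothetical helper, not
# torch.utils.checkpoint, and lists stand in for tensors.


def checkpoint(fn, x):
    """Run fn(x) but keep only its input; return a closure that re-runs it."""
    out = fn(x)  # activations created inside fn are not kept around
    saved_input = x

    def recompute():
        # In the backward pass, the forward is executed again from saved_input.
        return fn(saved_input)

    return out, recompute


calls = []


def layer(x):
    calls.append(list(x))  # record every time the forward actually runs
    return [2 * v for v in x]


out, recompute = checkpoint(layer, [1, 2, 3])
print(recompute() == out)  # True: the output depends only on the input
print(len(calls))  # 2: the original forward plus the recompute
```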

@younesbelkada
Contributor

younesbelkada commented Dec 15, 2023

Thanks @dakinggg for your reply!
The reason why it was not failing before is that here:

past_key_values_length = past_key_values[0][0].shape[2]
past_key_values was always None during training, so that block was never reached, whereas now past_key_values is always created during training because the model falls back to the config's use_cache to decide whether to create it.
Thanks also for explaining GC, it makes sense.

@dakinggg
Contributor Author

Ahh I see. So it seems to me that the proper fix is to go back to the old behavior where past_key_values is always None during training. We don't ever want to create them unless we are doing generation. I can certainly explicitly set use_cache=False in my code, but this will be a huge pitfall for others if the old behavior is not maintained.

@dakinggg
Contributor Author

dakinggg commented Dec 15, 2023

Related, IMO the proper place to default use_cache to True is in prepare_inputs_for_generation, not in the model config.
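A rough sketch of the design suggested here, where forward() never builds a cache by default and only the generation path opts in. `ToyModel` is hypothetical; its method names echo transformers, but the bodies are illustrative and not the library's API.

```python
# Hypothetical sketch of the suggestion; ToyModel is illustrative and its
# methods only borrow names from transformers, they are not its API.


class ToyModel:
    def forward(self, input_ids, use_cache=None):
        # No config-level default: caching is off unless explicitly requested.
        use_cache = bool(use_cache)
        past_key_values = list(input_ids) if use_cache else None
        return {"past_key_values": past_key_values}

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Generation is the one code path that turns caching on by default.
        kwargs.setdefault("use_cache", True)
        return {"input_ids": input_ids, **kwargs}


model = ToyModel()
train_out = model.forward([1, 2, 3])  # training-style call
gen_out = model.forward(**model.prepare_inputs_for_generation([1, 2, 3]))
print(train_out["past_key_values"])  # None
print(gen_out["past_key_values"])  # [1, 2, 3]
```

The trade-off is backward compatibility: code that relied on config.use_cache=True to get a cache from a plain forward call would have to pass use_cache=True explicitly.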

@ArthurZucker
Collaborator

Yep, we'll add a fix

@dakinggg
Contributor Author

Thanks!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@dakinggg
Contributor Author

Not sure if this has been fixed.

@ArthurZucker
Collaborator

Not yet, I'll be doing some more refactoring to past key values, in hope to fix these issues as well

@lorabit110
Contributor

lorabit110 commented Jan 16, 2024

Do we know why it produces higher loss? Should we use 4.35.2 before the refactoring is done?

@ArthurZucker
Collaborator

With flash attention, yes, we know where the issue comes from (#28142 has more details). It was fixed on main and should be released this week.

@yangjianxin1

I'm hitting the same problem. I installed transformers from the main branch, but it doesn't work. Has this problem been solved? Thanks! @ArthurZucker

@kiddyboots216

This is still an issue on the latest version of Transformers it seems?

@ArthurZucker
Collaborator

I don't think this is still an issue? I can't find the exact reproducer, cc @dakinggg?

@dakinggg
Contributor Author

dakinggg commented Apr 4, 2024

@ArthurZucker, confirmed this is still an issue. I end up with something like this:

saved metadata: {'shape': torch.Size([2, 4096, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
recomputed metadata: {'shape': torch.Size([2, 8192, 32, 128]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
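The saved/recomputed report above appears to come from PyTorch's non-reentrant checkpointing, which compares the metadata of tensors saved during the forward pass with those produced by the recompute. The sketch below mimics that check in pure Python under stated assumptions: every name is hypothetical, the (batch, seq_len, heads, head_dim) key layout is read off the log, and the cache is modelled as state that grows on each forward.

```python
# Pure-Python sketch; shapes are tuples and every name here is hypothetical,
# not PyTorch or transformers internals.

BATCH, SEQ, HEADS, HEAD_DIM = 2, 4096, 32, 128


class ToyCache:
    def __init__(self):
        self.seq_len = 0

    def update(self, new_len):
        # Append new keys/values and return the total cached length.
        self.seq_len += new_len
        return self.seq_len


def key_states_shape(q_len, cache):
    """Shape of the key tensor passed to flash attention for one layer."""
    kv_len = cache.update(q_len) if cache is not None else q_len
    return (BATCH, kv_len, HEADS, HEAD_DIM)


def check_recompute(forward):
    """Compare what the forward saved with what the recompute produces."""
    saved = forward()  # original forward pass
    recomputed = forward()  # re-run during backward
    if saved != recomputed:
        raise RuntimeError(
            f"saved metadata: {saved}, recomputed metadata: {recomputed}"
        )
    return saved


cache = ToyCache()
try:
    check_recompute(lambda: key_states_shape(SEQ, cache))
except RuntimeError as err:
    print(err)  # saved metadata: (2, 4096, 32, 128), recomputed metadata: (2, 8192, 32, 128)

print(check_recompute(lambda: key_states_shape(SEQ, None)))  # (2, 4096, 32, 128)
```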

@dakinggg
Contributor Author

dakinggg commented Apr 4, 2024

The same fix, explicitly specifying use_cache=False when loading the model for training, has continued to work.


@ArthurZucker
Collaborator

Running the shared snippet https://github.com//issues/28056#issuecomment-1858319673 on main (the only reproducer I could find) seems to work.

@ArthurZucker
Collaborator

@dakinggg if you have a better snippet, I'm down to test and fix.

@dakinggg
Contributor Author

I suspect you all resolved the issue when using the huggingface trainer (so your code snippet works), but not when using other libraries, which may enable activation checkpointing differently, so any checks that you've put in will not know that activation checkpointing is being done. I don't have a small snippet; my test is just running a training job using LLM Foundry.

@kiddyboots216

I did not use the huggingface trainer; I used FSDP with a custom training loop, and just loading with use_cache=False works for me.

@dakinggg
Contributor Author

Yes, it works fine with use_cache=False. IMO you should not have to specify use_cache=False in order for training to work outside of transformers.

@kiddyboots216

I agree.

@ArthurZucker
Collaborator

Or use_cache should be set to False by default.
The behaviour with use_cache=True cannot be guaranteed to work outside transformers, but we are trying our best to abstract this a bit. I think there are currently very few things in modeling_llama that rely on use_cache apart from this:

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

This should automatically disable use_cache. We could probably set use_cache = False whenever self.training is True, for a safer path.

The other might be:

        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

and that is something we are stuck with because of backward compatibility (BC)...
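The two guards discussed in this thread behave differently for frameworks that apply activation checkpointing themselves. In this hypothetical sketch (`ToyModel` and its attributes are illustrative, not `PreTrainedModel`), the gradient_checkpointing flag is assumed to be set only by the library's own enabling path, so an external framework leaves it False; the flag-based guard then keeps the cache, while a guard on `self.training` alone disables it.

```python
# Hypothetical comparison of two use_cache guards; ToyModel and its attributes
# are illustrative stand-ins, not the transformers PreTrainedModel API.


class ToyModel:
    def __init__(self, config_use_cache=True):
        self.config_use_cache = config_use_cache
        self.training = False
        # In this sketch, only the library's own enabling path sets this flag;
        # a framework that wraps layers itself leaves it False.
        self.gradient_checkpointing = False

    def resolve_flag_guard(self, use_cache=None):
        use_cache = self.config_use_cache if use_cache is None else use_cache
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False
        return use_cache

    def resolve_training_guard(self, use_cache=None):
        use_cache = self.config_use_cache if use_cache is None else use_cache
        if self.training:
            use_cache = False
        return use_cache


model = ToyModel()
model.training = True  # an external framework applies checkpointing on its own
print(model.resolve_flag_guard())  # True: a cache would still be built
print(model.resolve_training_guard())  # False: no cache while training
```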

@rohitgr7
Contributor

Facing the same issue with FSDP + activation checkpointing. For now I've disabled the cache using model.config.use_cache=False.

@gante
Member

gante commented May 29, 2024

@ArthurZucker I think requiring users to set use_cache explicitly is okay, as it is hard to automatically detect every situation (especially those created by external libs). We can, however, throw a better exception suggesting use_cache=False in situations like this.

Can someone on this thread share what is the exception message you're getting on main if you DON'T set use_cache=False? That way we can improve the error message with the right suggestion, solving the bad usage experience across all external libs 🤗
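One possible shape for the improved error message requested here, as a sketch only: `check_mask_shape` and its wording are hypothetical and do not exist in transformers. It keeps the original message and appends a use_cache=False hint (the two workarounds reported in this thread) when the mismatch looks like extra cached keys during training.

```python
# Hypothetical sketch: check_mask_shape and its message are illustrative, not
# code or wording that exists in transformers.


def check_mask_shape(mask_shape, bsz, q_len, kv_seq_len, training):
    expected = (bsz, 1, q_len, kv_seq_len)
    if mask_shape == expected:
        return
    msg = f"Attention mask should be of size {expected}, but is {mask_shape}."
    if training and kv_seq_len > q_len:
        # Extra cached keys while training usually mean the forward was re-run
        # (for example by activation checkpointing) with caching still enabled.
        msg += (
            " A key/value cache seems to be active during training. If you use"
            " activation checkpointing outside transformers, try loading the"
            " model with use_cache=False (or set model.config.use_cache = False)."
        )
    raise ValueError(msg)


try:
    check_mask_shape((2, 1, 4096, 4096), bsz=2, q_len=4096, kv_seq_len=8192, training=True)
except ValueError as err:
    print(err)
```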


@dakinggg
Contributor Author

Still an issue; I will try to get the trace again this week. IIRC it's the standard PyTorch error for activation checkpoint metadata inconsistency. Also, why can't transformers just set use_cache to False if self.training?

