
Add torch.compile Support For Mamba #31247

Conversation

@zhenglongjiepheonix (Contributor) commented Jun 4, 2024:

torch.compile support for mamba! Closes #31246

@zhenglongjiepheonix (Contributor, Author):

It seems the mamba cache is not compatible with the current cache design used in generate, yet we have to initialize the cache before stepping into model.forward in order to make dynamo happy. A mamba-specific conditional check in get_cache might not be what we want because it's too narrow a patch. We could let users create and pass a mamba cache themselves when using torch.compile, or is there a way to set the cache up for users inside generate? @ArthurZucker
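
A minimal sketch of the "let the user create the cache" option mentioned above, reusing the MambaCache constructor and the cache_params kwarg as they appear elsewhere in this PR; illustrative only, not the final API:

import torch
from transformers import AutoTokenizer, MambaForCausalLM
# at the time of this PR, `MambaCache` lives in modeling_mamba
from transformers.models.mamba.modeling_mamba import MambaCache

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16).to("cuda")
input_ids = tokenizer("Hey how are you doing today?", return_tensors="pt").input_ids.to("cuda")

# user-allocated cache; a compiled forward (or generate) would then just reuse it
# instead of having to create it inside the traced graph
cache = MambaCache(config=model.config, batch_size=input_ids.shape[0], dtype=model.dtype, device=model.device)
logits = model(input_ids, cache_params=cache).logits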

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Looking good!
I think for this we want a general solution that would also work for hybrid caches (like Jamba / Mamba2 / Zamba, etc.).
Here it's possible to init the cache before going into the forward if you extend NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} with "mamba"? It's not too bad 😅
Otherwise we could redefine the StaticCache for mamba to be the MambaCache class.
cc @zucchini-nlp and @gante 😉
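
A minimal sketch of the mapping extension suggested above, assuming the StaticCache / SlidingWindowCache imports are available where NEED_SETUP_CACHE_CLASSES_MAPPING is defined in src/transformers/generation/utils.py; illustrative only:

from transformers.cache_utils import SlidingWindowCache, StaticCache
from transformers.models.mamba.modeling_mamba import MambaCache

NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
    "sliding_window": SlidingWindowCache,
    "mamba": MambaCache,  # proposed addition so generate can set the cache up before the first forward
}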

@zhenglongjiepheonix changed the title from "[WIP] Add torch.compile Support For Mamba" to "Add torch.compile Support For Mamba" on Jun 7, 2024
@ArthurZucker (Collaborator) left a comment:

nice 🔥

src/transformers/models/mamba/modeling_mamba.py (three review threads, outdated, resolved)
Comment on lines 128 to 129
    def is_initialized(self, layer_idx):
        return self.is_cache_initialized[layer_idx]
Collaborator:

This can be checked with cache_positions instead, no?

@zhenglongjiepheonix (Contributor, Author):

I think this is fine, just like we need a flag in Whisper. Here cache_positions is not that meaningful, because we always know how to update and read the cache even when cache_positions is not passed.

Collaborator:

Theoretically yes, but this adds some complexity, which is not needed in the cache API. Checking the cache positions is more reliable, and is what we want to go with.

  • You don't have to reset and set another tensor, which is also a win.

Let's just use the cache positions.
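
A rough sketch of what "check with the cache positions" could look like, using the shape-based variant that comes up later in this thread; the helper name and signature are assumptions for illustration:

import torch

def in_prefill_stage(cache_position: torch.LongTensor, conv_kernel_size: int) -> bool:
    # At prefill, `cache_position` spans the full convolution window; during decoding,
    # generate advances a single position per step. Checking the shape keeps the test
    # data-independent, which matters for fullgraph compilation (see the discussion below).
    return cache_position.shape[0] == conv_kernel_size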

@ArthurZucker (Collaborator):

Could you share benchmark results?

@zhenglongjiepheonix (Contributor, Author) commented Jun 8, 2024:

Could you share benchmark results?

Sure. With mamba-1.4b on a single A100-SXM4-80GB in float16, batch_size=1, inference mode; times are per generated token:

  • slow_forward: [benchmark screenshot]
  • cuda_kernel_forward: [benchmark screenshot]
  • compile: [benchmark screenshot]

And here is the function I used for benchmarking

import torch
from transformers import AutoTokenizer, MambaForCausalLM
# at the time of this PR, `MambaCache` lives in modeling_mamba
from transformers.models.mamba.modeling_mamba import MambaCache


@torch.no_grad()
def perf():
    tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer("Hey how are you doing today ? " * 100, return_tensors="pt", padding=True).to("cuda")

    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16)
    model.config.use_cache = True
    model.to("cuda")

    input_ids = inputs.input_ids
    cache = MambaCache(model.config, 1, device=input_ids.device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # prefill is run eagerly: compiling it takes far too long (see the note below)
    logits = model(input_ids, cache_params=cache).logits
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

    # only the single-token decoding forward is compiled
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    torch.cuda.synchronize()
    for i in range(10):
        start.record()
        logits = model(next_token.clone(), cache_params=cache).logits
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        end.record()
        torch.cuda.synchronize()
        print(f"Step {i}, Total time: {start.elapsed_time(end)} ms, next_token = {next_token.int()}")

As the results above show, the first and second decoding steps take a long time. In my script I also skipped compilation of the prefill stage, because it takes forever (nearly one hour) to compile. So, focusing only on the decoding phase, we get a steady 8x speedup, even compared with the CUDA kernel implementation.

@ArthurZucker (Collaborator) left a comment:

Let's use cache positions for the check, and be careful of BC!
Otherwise great work!

src/transformers/models/mamba/modeling_mamba.py (outdated, resolved)

src/transformers/models/mamba/modeling_mamba.py (outdated, resolved)
tests/models/mamba/test_modeling_mamba.py (outdated, resolved)
@ArthurZucker (Collaborator) left a comment:

Almost done! Again, great work! Let's just define a good API for future models; RecurrentGemma, for example, will also benefit from this!

Comment on lines 1746 to 1766:

elif generation_config.cache_implementation == "mamba":
    from ..models.mamba.modeling_mamba import MambaCache, MambaConfig

    if not isinstance(self.config, MambaConfig):
        raise ValueError(
            "You can only specify `cache_implementation` to `mamba` if you are using mamba model"
        )

    if hasattr(self, "_cache"):
        assert isinstance(self._cache, MambaCache), "Only `MambaCache` can be used on mamba model"
        need_new_cache = self._cache.conv_states.shape[1] != batch_size
    else:
        need_new_cache = True

    if need_new_cache:
        self._cache = MambaCache(
            config=self.config, batch_size=batch_size, dtype=self.dtype, device=self.device
        )
    else:
        self._cache.reset()
    model_kwargs["cache_params"] = self._cache
Collaborator:

The problem with this is that it does not scale to new models. It's not something we want to do at all, TBH.
The simplest is to import MambaCache and add it to the mapping: "mamba": MambaCache.

needs_new_cache should be specific to the cache class.
Maybe this is the best approach, as for each new cache class it gives a correct way to say whether or not we reset!
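
A sketch of the "specific to the cache class" idea: the method name needs_new_cache is taken from the comment above and is illustrative, not the merged API; the batch-size check mirrors the conv_states.shape[1] test in the quoted diff.

class MambaCache:
    # ... existing __init__ allocates self.conv_states with the batch dimension at index 1 ...

    def needs_new_cache(self, batch_size: int) -> bool:
        # The conv buffer is allocated per batch, so a different batch size means the
        # existing cache cannot be reused and a fresh MambaCache must be built.
        return self.conv_states.shape[1] != batch_size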

    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
    conv_state = self.conv_states[layer_idx]
    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
Collaborator:

A bool flag should be dynamo-compatible, but I trust you on this one, and it's fairly small, so LGTM.

            hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)  # [batch, intermediate_size, 1] : decoding
        else:

            if cache_position.shape[0] == self.conv_kernel_size:
Collaborator:

Might be worth adding a comment in the code to explain the trick.
I'm more in favor of using cache_position[0] to detect decoding if it works; if not, then a small comment!

        cache_params = MambaCache(
            self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
Collaborator:

Okay. cache_position[0] > 0 breaks fullgraph, I guess?

        input_ids = input_ids[:, -1].unsqueeze(-1)
    if use_cache:
        # `cache_position` should have been initialized in `generate`
        assert cache_position is not None
Collaborator:

Let's raise an error rather than using asserts.
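
Sketched, the requested change is roughly this (the exact error message is illustrative):

if use_cache:
    # `cache_position` should have been initialized in `generate`
    if cache_position is None:
        raise ValueError(
            "`cache_position` should not be None, it should have been initialized in `generate`"
        )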

@ArthurZucker (Collaborator) left a comment:

Looks good to me! cc @gante if you can have a look for the generate changes!

@@ -1751,7 +1758,8 @@ def generate(
)

use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
cache_name = getattr(self, "cache_name", "past_key_values")
Collaborator:

Might be better to set this as a class attribute: everything that inherits from Cache gets "past_key_values" and mamba gets "cache_params". WDYT?

@zhenglongjiepheonix (Contributor, Author) commented Jul 13, 2024:

Currently MambaCache does not inherit from Cache, because the Cache APIs only suit transformer models with KV states. So you mean making cache_name a class attribute of Cache and MambaCache, with values past_key_values and cache_params respectively?

Collaborator:

cache_name = "cache_params" for the mamba cache class, and cache_name = "past_key_values" for Cache classes!
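
A sketch of that suggestion (the class attributes shown here are illustrative; the PR ends up resolving the name differently, see below):

class Cache:
    # kwarg name a model expects for this kind of cache
    cache_name = "past_key_values"

class MambaCache:
    cache_name = "cache_params"

# generate could then look the name up from the cache class instead of the model,
# e.g. cache_name = cache_cls.cache_name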

Member:

I also prefer a more verbose version for now. I spent some time looking for the cache_name variable in this review, which is not a good indicator of readability :D

e.g.

if "mamba" in self.__class__.__name__.lower():
    cache_var_name = "cache_params"
else:
    cache_var_name = "past_key_values"

@zhenglongjiepheonix (Contributor, Author):

OK, I guess it's a way to check whether we are using a mamba-related model. There is an issue with attaching cache_name to Cache: we need to know which cache we are creating in order to know the cache name, which becomes circular when we try to check whether users passed both cache_implementation and a cache instance. Let's go with this for now.

src/transformers/generation/utils.py (resolved)
Comment on lines +664 to +668
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
Collaborator:

I think there might be a more compile-friendly way to do this, but that will be a TODO. https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab might have answers; since I don't, LGTM for now!

@zhenglongjiepheonix (Contributor, Author) commented Jul 13, 2024:

Yes, it's just a matter of using data-independent ops. It could be either a flag or a shape-dependent check to tell which stage we are in; a bool flag in forward would also do the trick, but we introduced cache_position to address exactly this. Another way to think about it: we effectively alter the length of the hidden states by applying (positive or negative) padding before updating the cache, so cache_position needs to stay aligned with the hidden states after padding rather than before.
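
A small illustration of the padding/truncation point (the helper name is made up for illustration; shapes are simplified, and conv_kernel stands for config.conv_kernel):

import torch
import torch.nn.functional as F

def pad_or_truncate_to_conv_window(hidden_states: torch.Tensor, conv_kernel: int) -> torch.Tensor:
    # hidden_states: (batch, channels, seq_len) at prefill time.
    # A positive left pad fills with zeros when the prompt is shorter than the window;
    # a negative left pad truncates when it is longer. Either way the result has length
    # `conv_kernel`, so cache_position = torch.arange(conv_kernel) stays aligned with it.
    return F.pad(hidden_states, (conv_kernel - hidden_states.shape[-1], 0))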

Collaborator:

Got it, thanks!

@gante (Member) left a comment:

generate changes look (mostly) good to me 🤗


@ArthurZucker (Collaborator):

Looks good!

@zhenglongjiepheonix merged commit c75969e into huggingface:main on Jul 18, 2024
23 of 24 checks passed
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
* modify mamba cache

* set up cache

* add test

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* use_cache_position

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* fix

* cache in generate

* [run-slow] mamba

* address comments

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix cache name

* [run-slow] mamba
@ArthurZucker (Collaborator):

Congrats on the merge! 🔥

MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024
itazap pushed a commit that referenced this pull request Jul 25, 2024
dataKim1201 pushed a commit to dataKim1201/transformers that referenced this pull request Oct 7, 2024