
mamba generation throughput lower than original due to DecodingCGCache #29699

Open
y1xia0w opened this issue Mar 17, 2024 · 22 comments
Labels: Compilation (Issues related to torchdynamo and torchinductor), Feature request (Request for a new feature), Good Difficult Issue

Comments

@y1xia0w

y1xia0w commented Mar 17, 2024

System Info

Python 3.10.13, CUDA 12.1
GPU = NVIDIA GeForce RTX 2080 Ti. Max memory = 10.747 GB.

torch==2.2.1
torchaudio==2.1.0
torchvision==0.16.0
tokenizers==0.15.2
transformers==git+https://github.com/huggingface/transformers@dd1c9052159ae824c8acef7c2552f9fad5ca020a
triton==2.2.0
causal_conv1d==git+https://github.com/Dao-AILab/causal-conv1d.git@96456720c00393a5c32872d8352d7a7ec31fb3db#egg=causal_conv1d
mamba_ssm==git+https://github.com/state-spaces/mamba.git@9127d1f47f367f5c9cc49c73ad73557089d02cb8#egg=mamba_ssm

Who can help?

text models: @ArthurZucker and @younesbelkada
generate: @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The key model initialization and generation parts are given below.

Original code repo

In the original code repo

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")
model.eval()

# input_ids and max_length come from the benchmark setup (100-token prompt, 1000 generated tokens)
model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True
    )

The throughput for generating a 1K-token sequence is then:

Number of parameters: 129135360
Prompt length: 100, generation length: 1000
Prompt processing + decoding time: 1011 ms

Using the HF library

from transformers import MambaForCausalLM
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
model.eval()

model.generate(
        input_ids=input_ids,
        max_length=max_length
    )

The throughput for generating a 1K-token sequence is then:

Number of parameters: 129135360
Prompt length: 100, generation length: 1000
state-spaces/mamba-130m-hf prompt processing + decoding time: 15970ms

Expected behavior

Setting cg=True is confirmed to be the part that has a significant impact on generation performance for Mamba.

I have tried:

  1. Passing use_cache=True as follows does not affect the results:
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", use_cache=True)
or
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", cache_params={use_cache: True})
or
model.config.use_cache = True
  2. Modifying the Mamba model to force the argument use_cache=True in MambaModel, which still does not work.

I assume this is related to #29605, but modifying the argument directly does not seem to solve the problem.
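
For reference, a minimal sketch of the kind of timing harness behind the numbers above (the prompt construction and the choice of max_length are assumptions, not the exact script used):

import torch
from time import time
from transformers import MambaForCausalLM

model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("cuda")
model.eval()

# Assumption: an arbitrary 100-token prompt, matching the reported prompt length.
input_ids = torch.randint(0, model.config.vocab_size, (1, 100), device="cuda")

torch.cuda.synchronize()
start = time()
with torch.no_grad():
    # 100 prompt tokens + 1000 generated tokens
    out = model.generate(input_ids=input_ids, max_length=1100, use_cache=True)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time() - start) * 1000:.0f} ms")
print(out.shape)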

@javiermcebrian

Hi! Thanks for raising this issue. I agree. I've noticed the same thing over the past few weeks while testing Mamba; previously I was using the mamba-ssm repo and generation was much faster thanks to DecodingCGCache. I've been tracking that other issue too: although it looks like just a variable refactor, it includes a change to the prepare_inputs_for_generation function to pass the kwargs (https://github.com/huggingface/transformers/pull/29605/files#diff-e1d4758c08973fdac2c23a8a3710872d943ce8509035205da4a681bc4dcaf1c3R694). I didn't create an issue because I assumed that PR would be merged soon given its simplicity, but it seems it hasn't been. Also, I'm not sure whether that PR solves the whole issue.

@ArthurZucker
Collaborator

ArthurZucker commented Mar 19, 2024

Hey! Thanks both, I'll dive into this a bit; contributions are also welcome.
The use_cache flag should be using the cache, and in my tests it was. I'll check.

@gante
Member

gante commented Mar 20, 2024

Hi folks 👋

It doesn't look like a caching issue, but a compilation one -- under the hood, the cg flag in the original repo triggers compilation (cg probably stands for cuda graphs).

Our implementation is not compatible with fullgraph compilation (the equivalent), which may explain the difference.

@ArthurZucker in case you want to reproduce:

from transformers import MambaForCausalLM
import torch
from time import time

model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", device_map="auto")
model.eval()
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_ids = torch.arange(100, device="cuda").unsqueeze(0)
max_length = 1000

start = time()
model = torch.compile(model)
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
)
print(f"Time: {time() - start:.2f}s")
print(out.shape)
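
One caveat when timing this snippet: the first call pays the compilation cost, so a warm-up run before the timed region gives a steadier number. A minimal sketch, reusing the model, input_ids, and max_length defined above:

# Warm-up: trigger compilation outside the timed region.
_ = model.generate(input_ids=input_ids, max_length=max_length)

torch.cuda.synchronize()
start = time()
out = model.generate(input_ids=input_ids, max_length=max_length)
torch.cuda.synchronize()
print(f"Time (after warm-up): {time() - start:.2f}s")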

@ArthurZucker
Collaborator

Might be the einsums that were all replaced by normal operations; I saw some issues related to this recently as well.

@y1xia0w
Author

y1xia0w commented Mar 25, 2024

@gante Nice catch and thank you for your effort to address it!
@ArthurZucker Could you please elaborate a bit more and also share the other related issues? Thank you very much!

@ArthurZucker
Collaborator

#29544 is what made me think of that!

@javiermcebrian

Hi! Thanks for the findings :)
@ArthurZucker is there any update on this issue? I saw that #29544 was merged, but I'm not sure whether the compilation issue described here, the one affecting the CUDA graphs path, has been fixed.
Thanks!

@ArthurZucker
Collaborator

ArthurZucker commented Apr 23, 2024

I did not test compilation with Mamba, and I can't investigate right now 😢 I would love it if you can.
To be fair, we need to compare greedy generation only, not sampling and not beam search, making sure apples are compared with apples. As said, compilation is not currently supported.
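
For such a comparison, a sketch of pinning the transformers side to greedy decoding (assuming the same model, input_ids, and max_length as in the snippets above):

out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    do_sample=False,  # greedy decoding
    num_beams=1,      # no beam search
)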

@ArthurZucker
Collaborator

I'll set this as a feature request: add support for torch.compile for Mamba

@ArthurZucker added the Feature request, Good Difficult Issue, and Compilation labels on Apr 23, 2024
@javiermcebrian

Hi @ArthurZucker :)
I tested using exactly the same generation config, just greedy search with no sampling. The original Mamba repo was faster than HF during generation.
So, just to understand: the only possible fix would be supporting compilation for Mamba so that the CUDA graphs work as in the original repo?
Thanks!

@gante
Member

gante commented May 3, 2024

Added it to our generate + torch.compile tracker, which we are actively tackling :)

@junchen14

Yes, same issue here.

Is there any update on a solution?

@ArthurZucker
Collaborator

#30139 and #31247 should help!

@javiermcebrian

Nice to see those updates! Are there any other pending issues before this one (#29699) is complete? Thanks!

@ArthurZucker
Collaborator

I think it's completed

@javiermcebrian

Alright! Was the code above re-tested by anyone to compare speed again? Using which transformers release? Thanks!

@y1xia0w
Author

y1xia0w commented Aug 13, 2024

@ArthurZucker Thank you for the updates, nice job!

I have tested the throughput again; there seems to be some improvement, but the difference remains significant.

Package info:

torch==2.2.1
transformers==4.44.0
triton==2.2.0
mamba-ssm==git+https://github.com/state-spaces/mamba.git@9127d1f47f367f5c9cc49c73ad73557089d02cb8
causal_conv1d==git+https://github.com/Dao-AILab/causal-conv1d.git@96456720c00393a5c32872d8352d7a7ec31fb3db

I ran the generation 5 times with the same 100-token prompt length and 1000-token generation length and averaged the results:

# HF transformers v4.44.0, no flash_attention_2 support
# Time: 14.75s
# torch.Size([1, 1000])

# mamba_ssm, without cg=True
# Time: 13.99s
# torch.Size([1, 1000])

# mamba_ssm, with cg=True
# Time: 1.33s
# torch.Size([1, 1000])

I assume the reasons for the generation speed difference are:

  • The original Mamba adopts the equivalent of the flash-attention mechanism (fused CUDA kernels) in its implementation, which significantly accelerates generation.
  • This flash-attn-style capability for the Mamba architecture hasn't been supported in the v4.44.0 transformers library (see the quick environment check sketched after the quoted documentation below).

Reference:
Quote from the Mamba documentation in the HF library:

The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the mamba-ssm and the causal_conv1d repositories. Make sure to install them if your hardware supports them!
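
As a quick sanity check, a minimal sketch for verifying whether these optional fast-path packages are importable in the environment:

import importlib.util

for pkg in ("mamba_ssm", "causal_conv1d"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'available' if found else 'missing -- slow fallback path will be used'}")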

@ArthurZucker
Collaborator

When you said you tested, did you use torch.compile(..., mode="reduce-overhead")?

@ArthurZucker
Collaborator

Did you read this comment in the linked PR for compile support (#31247 (comment))?

@ArthurZucker
Collaborator

cg=True uses CUDA graphs; to compare apples with apples you have to use the provided perf snippet.

@uuuuuurrrrrr

May I ask where to find the throughput benchmarking code for calling the model through transformers?

@ArthurZucker
Collaborator

You can find it here:

# Imports added for completeness; depending on the transformers version, MambaCache may
# instead need to be imported from transformers.models.mamba.modeling_mamba.
import torch
from transformers import AutoTokenizer, MambaForCausalLM
from transformers.cache_utils import MambaCache


@torch.no_grad
def perf():
    tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer("Hey how are you doing today ? " * 100, return_tensors="pt", padding=True).to('cuda')

    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16)
    model.config.use_cache = True
    model.to('cuda')

    input_ids = inputs.input_ids
    cache = MambaCache(model.config, 1, device=input_ids.device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Prefill: process the prompt once and pick the first generated token.
    logits = model(input_ids, cache_params=cache).logits
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

    # Compile only the single-token decoding step; "reduce-overhead" enables CUDA graphs.
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    torch.cuda.synchronize()
    for i in range(10):
        start.record()
        logits = model(next_token.clone(), cache_params=cache).logits
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

        end.record()
        torch.cuda.synchronize()
        print(f'Step {i}, Total time: {start.elapsed_time(end)} ms, next_token = {next_token.int()}')
(#31247 (comment))
