
Should I be getting more speedup/memory reduction from FlashAttention2 with Mistral? #27329

Closed
2 of 4 tasks
cassianlewis opened this issue Nov 6, 2023 · 3 comments

Comments

@cassianlewis

System Info

transformers: 4.35.0
python: 3.9.13

Who can help?

@SunMarc
@younesbelkada
@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Setup model

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# 4-bit NF4 quantization with double quantization and bf16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# load base LLM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    use_flash_attention_2=True,
)
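(As a quick sanity check, not part of the original script, one can confirm FA2 is actually active by inspecting the attention class of a decoder layer; this assumes the usual Mistral module layout in this transformers version:)

# Should print "MistralFlashAttention2" when FA2 is enabled,
# "MistralAttention" otherwise.
print(type(model.model.layers[0].self_attn).__name__)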

Run code for different batch sizes

import gc
from time import time

results = []
for n in range(1, 25):
    print(f'Processing {n} examples')
    # batch the same ~2k-token context n times (no padding needed since prompts are identical)
    tokenized_prompt = tokenizer([context] * n, return_tensors="pt").to(model.device)

    length = len(tokenized_prompt['input_ids'][0]) + 1
    print(length)

    t0 = time()
    with torch.no_grad():
        output = model.generate(
            inputs=tokenized_prompt['input_ids'],
            max_new_tokens=400,
            repetition_penalty=1.2,
        )
    t1 = time()

    time_taken = t1 - t0
    mem_usage = memory()  # user-defined helper reporting GPU memory usage
    new_token_length = len(output[0]) - length
    tokens_per_second = new_token_length * n / time_taken
    time_per_batch = time_taken / n

    print('Time taken = ', time_taken)
    print(f'Tokens/s = {tokens_per_second}')

    gc.collect()
    torch.cuda.empty_cache()

    results.append({
        'batch_size': n,
        'time_taken': time_taken,
        'tokens_per_second': tokens_per_second,
        'memory_usage': mem_usage,
        'time_per_batch': time_per_batch,
    })

Expected behavior

Results

Very little speedup/memory improvement:
[plot: flash]

Profiling

With FA2:
[screenshot: profiling with FA2 (2023-11-06 18:22)]

Without FA2:
[screenshot: profiling without FA2 (2023-11-06 18:16)]

I would expect better performance given these profiles.

@cassianlewis changed the title from "Should I be getting more speedup/memory reduction from FlashAttention2?" to "Should I be getting more speedup/memory reduction from FlashAttention2 with Mistral?" on Nov 6, 2023
@younesbelkada
Contributor

Hi @cassianlewis
Thanks a lot for the extensive benchmark here!
Several things in your setup can add overhead. I think we should test without bnb_4bit_use_double_quant=True and make sure compute_dtype and torch_dtype match: for better efficiency, I would try torch_dtype=torch.float16 together with bnb_4bit_compute_dtype=torch.float16. Also note that the input context length has a big impact on the final latency. Please take a look at the experiments I ran here: #26464 (comment) and let me know if you have more questions.

@cassianlewis
Author

cassianlewis commented Nov 7, 2023

Hi @younesbelkada, thanks for the reply.
I did actually test with fp16 and the results were largely the same:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

# load base LLM model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir='m',
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    use_flash_attention_2=True,
)

Results

Slight improvements in time/memory. FYI, the input is ~2k tokens (with no padding), so FA2 may show a bigger benefit with longer input sequences.
[plot: flash_fp16]

I then tested it with max_new_tokens = 1 (i.e. just the prefill) and saw a much larger discrepancy:

[plot: flash_s]
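(For reference, this prefill-only variant is just the benchmark loop above with generation limited to a single new token, reusing model and tokenized_prompt from that loop; roughly:)

# Prefill-only timing: with max_new_tokens=1 the measurement is
# dominated by the prompt (prefill) forward pass.
with torch.no_grad():
    output = model.generate(
        inputs=tokenized_prompt['input_ids'],
        max_new_tokens=1,
    )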

This is more in line with what you posted in #26464 (comment).

So it looks like, for decoding at least, the speedups are fairly minimal for the kind of input sequence lengths/batch sizes I'm using.

@younesbelkada
Contributor

younesbelkada commented Nov 7, 2023

Thanks a lot for this benchmark! Yes, I think this result is pretty much in line with my findings.
Flash Attention is indeed great for prefill (which one of your plots shows). For generation it is not as fast as for prefill, but note that you can fit a larger batch size within the same memory budget (~10 for native vs ~16 for FA2). I think FA-2 + HF transformers really shines in the context of training / fine-tuning, because you can fit a much larger sequence length / batch size and hence improve the efficiency of your training setup.
If we want to increase generation throughput, one needs to use a static KV cache + flash decoding.
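(Neither of those is available in the transformers 4.35 used here. As a rough sketch under the assumption of a later release, roughly 4.38+, where a static KV cache is exposed via the generation config, it would look something like the following; flash decoding itself is not part of transformers and is not shown:)

# Sketch (assumption: transformers >= 4.38, where a static KV cache exists;
# not available in the 4.35 used in this issue).
model.generation_config.cache_implementation = "static"  # pre-allocate the KV cache

with torch.no_grad():
    output = model.generate(
        inputs=tokenized_prompt['input_ids'],
        max_new_tokens=400,
    )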
