
FlashAttention works with a single GPU, but crashes with accelerate DP on multiple GPUs (FlashAttention only support fp16 and bf16 data type) #822

Open
Andcircle opened this issue Feb 10, 2024 · 8 comments


@Andcircle

System Info

`Accelerate` version: 0.22.0
Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.29
Python version: 3.8.10
Numpy version: 1.23.1
PyTorch version (GPU?): 2.0.1+cu117 (True)
PyTorch XPU available: False
PyTorch NPU available: False
System RAM: 1121.81 GB
GPU type: NVIDIA A100-SXM4-80GB

transformers              4.37.2
trl                       0.7.11.dev0
flash-attn                2.5.2


Reproduction

The following script works as expected on 1 GPU, but when running on multiple GPUs with DP it gives the following error:
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

import os
import wandb

import torch
from accelerate import Accelerator
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)

from trl import DataCollatorForCompletionOnlyLM

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

import sys
project_root = '/'.join(os.path.dirname(__file__).split('/')[:-1])
print(project_root)
sys.path.append(project_root)
from utils.meta_loader import write_meta, read_meta

import transformers

# bench
alpha = 16
rank = 64
batch_size = 2
length = 4096
accumulate_steps = 1
lr = 5e-5

train_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/train")
eval_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/test")    

run_name = "llava_debug"
save_dir = "/mnt/localssd/llava_debug"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    # llm_int8_skip_modules=["multi_modal_projector"]
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    # "llava-hf/bakLlava-v1-hf",
    quantization_config=bnb_config,
    trust_remote_code=True, 
    device_map={'':torch.cuda.current_device()},
    torch_dtype=torch.float16,
    use_flash_attention_2=True
    )

target_modules = [
    "*language_model.*q_proj", 
    "*language_model.*k_proj", 
    "*language_model.*v_proj", 
    "*language_model.*o_proj", 
    "*language_model.*gate_proj", 
    "*language_model.*up_proj", 
    "*language_model.*down_proj", 
    "*language_model.*lm_head"]

modules_to_save = ["linear_1", "linear_2"]
    
peft_config = LoraConfig(
    lora_alpha=alpha,
    lora_dropout=0.1,
    r=rank,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    modules_to_save=modules_to_save
)

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)

training_arguments = TrainingArguments(
    output_dir=save_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumulate_steps,
    optim="paged_adamw_32bit",
    save_steps=500,
    logging_steps=10,
    learning_rate=lr,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=100,
    warmup_ratio=0.03,
    # group_by_length=True,
    lr_scheduler_type="constant",
    run_name=run_name,
    evaluation_strategy="steps",
    eval_steps=200,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    # weight_decay=0.01,
    # dataloader_num_workers=NUM_PROC//2
)


model.config.use_cache = False  # not used during fine-tuning

def test_data_collator(datas):
    result = {}
    input_ids = [torch.Tensor(d['input_ids']) for d in datas]
    attention_mask = [torch.Tensor(d['attention_mask']) for d in datas]
    pixel_values = [torch.Tensor(d['pixel_values']) for d in datas]
    labels = [torch.Tensor(d['labels']) for d in datas]
    
    result['input_ids'] = torch.concat(input_ids).type(torch.int64)
    result['attention_mask'] = torch.concat(attention_mask).type(torch.int64)
    result['pixel_values'] = torch.concat(pixel_values)
    result['labels'] = torch.concat(labels).type(torch.int64)
    return result
    

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
    data_collator=test_data_collator
)

trainer.train()

### Expected behavior

The behavior is expected to be the same on a single GPU and on multiple GPUs.
@tridao
Contributor

tridao commented Feb 10, 2024

I'm not familiar with accelerate or how transformers uses FlashAttention, you'd probably get better help asking on those repos.

@ArthurZucker

I am getting a similar issue (without training) with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested, all the inputs' dtypes were bfloat16 and I still got the issue.
The reproducer is here with attn_implementation="flash_attention_2" and the corresponding PR on transformers.

- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
- flash_attn=2.5.3 + torch nightly (so 2.3-ish)
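
For context, here is a minimal sketch of what such a setup looks like; it assumes an illustrative Llama checkpoint and a single CUDA device, and it is not the linked reproducer:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint, not the one from the linked reproducer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # FlashAttention requires fp16/bf16
    attn_implementation="flash_attention_2",  # route attention through flash-attn
    device_map="cuda",
)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```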

@ArthurZucker

ArthurZucker commented Feb 12, 2024

>>> from flash_attn import flash_attn_func
>>> import torch
>>> print(torch.__version__)
2.3.0.dev20240208+cu121
>>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True)

....

File ~/miniconda3/envs/py310/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type

this doesn't work for me again, might be because I have.
cc @tridao not sure how relevant this is

@tridao
Contributor

tridao commented Feb 12, 2024

> this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is

The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
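
For illustration, a minimal sketch of a call that satisfies those constraints (arbitrary shapes, assuming flash-attn 2.x on an Ampere-or-newer GPU):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
# CUDA tensors in fp16/bf16 with shape (batch, seqlen, nheads, headdim)
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape, out.dtype)  # torch.Size([2, 1024, 16, 64]) torch.bfloat16
```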

@ArthurZucker

ArthurZucker commented Feb 12, 2024

The error happens before that check, but it seems to be torch nightly: the transformers snippet works with torch 2.2! (vs. getting `FlashAttention only support fp16 and bf16 data type` with nightly), so that is more reliable.
(With my snippet on torch 2.2 I get `RuntimeError: q must be on CUDA`, so a different error.)

@tridao
Contributor

tridao commented Feb 12, 2024

> I am getting a similar issue (without training) with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested, all the inputs' dtypes were bfloat16 and I still got the issue. The reproducer is here with attn_implementation="flash_attention_2" and the corresponding PR on transformers.
>
> - `transformers` version: 4.38.0.dev0
> - Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
> - Python version: 3.10.0
> - Huggingface_hub version: 0.20.3
> - Safetensors version: 0.4.2
> - Accelerate version: 0.27.0
> - Accelerate config: not found
> - PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
> - Tensorflow version (GPU?): not installed (NA)
> - Flax version (CPU?/GPU?/TPU?): not installed (NA)
> - Jax version: not installed
> - JaxLib version: not installed
> - Using GPU in script?: <fill in>
> - Using distributed or parallel set-up in script?: <fill in>
> - flash_attn=2.5.3 + torch nightly (so 2.3-ish)

I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).

@yiakwy-xpu-ml-framework-team

yiakwy-xpu-ml-framework-team commented Mar 19, 2024

> > this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is
>
> The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).

Yeah, flash attention uses (batch, seqlen, nheads, headdim) to represent inputs; however, in many libraries (Triton, for example) there are reasons to use (batch, nheads, seqlen, headdim) for easier arrangement of the layout.

Actually they are equivalent under this mapping:

    def permute(self, x: torch.Tensor) -> torch.Tensor:
        # split the packed hidden dim into (nheads, headdim), then move heads before the
        # sequence dim: (batch, seqlen, nheads*headdim) -> (batch, nheads, seqlen, headdim)
        new_x_shape = x.size()[:-1] + (self.nheads, self.headdim)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
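
As a self-contained illustration (the function name is made up for this sketch), converting from the (batch, nheads, seqlen, headdim) layout to the layout flash_attn_func expects:

```python
import torch

def bhsd_to_bshd(x: torch.Tensor) -> torch.Tensor:
    # (batch, nheads, seqlen, headdim) -> (batch, seqlen, nheads, headdim), as flash-attn expects
    return x.permute(0, 2, 1, 3).contiguous()

q_bhsd = torch.randn(2, 16, 1024, 64, dtype=torch.bfloat16, device="cuda")
q_flash = bhsd_to_bshd(q_bhsd)
print(q_flash.shape)  # torch.Size([2, 1024, 16, 64])
```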

But it is weird that the error (I have tested with the latest version) says "FlashAttention only support fp16 and bf16 data type".

// mha_fwd https://github.com/Dao-AILab/flash-attention/blob/6bbc532388e61185a92e2a563126739967b4c8c5/csrc/flash_attn/flash_api.cpp#L339-L339
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }

I have checked the repo; we would need to update the C++ templates to support more dtypes (I have experience with near-memory chip op libraries). Currently I have to do these unnecessary casts to help teams use flash attention v2:

    if q.dtype == torch.float32:
        # flash-attn kernels only accept fp16/bf16, so cast fp32 inputs down
        q = q.to(torch.float16, non_blocking=True)
        k = k.to(torch.float16, non_blocking=True)
        v = v.to(torch.float16, non_blocking=True)
    elif q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
        capability = torch.cuda.get_device_capability()
        if capability[0] <= 8:
            raise RuntimeError("Flash attention for FP8 (needs Hopper TE support) is currently only supported for compute capability >= 9.0")
        else:
            # TODO (yiakwy) : add FP8 support
            raise NotImplementedError

    output = flash_attn_func(q, k, v, dropout_p=self.dropout.p, causal=is_causal)
    output = revert_mold_flash_attn_input(output)

    if output_attentions:
        raise Exception("Does not support output attention weights inside flash attention.")

    if output.dtype != torch.float32:
        # TODO (yiakwy) : add support for fp16 and bf16 outputs
        # if the output dtype is not FP32 (flash attention returns the input dtype, here FP16), cast it back
        output = output.to(torch.float32, non_blocking=True)

So we need to update the error message, right?

@thepowerfuldeez

I confirm that flash-attn==2.5.6 doesn't work with torch==2.3.0a0+40ec155e58.nv24.3 nightly even though the inputs are indeed in torch.bfloat16 format!
I rolled back to torch 2.2 stable, reinstalled flash-attn, and now it works.
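
A quick way to confirm which combination is installed (a trivial sketch, not from the thread):

```python
import torch
import flash_attn

print("torch:", torch.__version__)            # e.g. a 2.2.x stable build vs. a 2.3.0 nightly
print("flash-attn:", flash_attn.__version__)
print("bf16 supported:", torch.cuda.is_bf16_supported())
```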
