ALiBi with causal=True unexpected bias? #186

Closed

KeremTurgutlu opened this issue May 21, 2023 · 13 comments
KeremTurgutlu commented May 21, 2023

Looking at the ALiBi bias generation function, it seems to me that when causal=True a fixed relative-position bias (let's ignore the head constant for now) is broadcast to all queries in the q-k attention scores. Let's assume seq_len=5:

When we have full=False (e.g. causal=True and prefix_lm=False), we get the following ALiBi bias before the head constant is applied:

alibi_bias = torch.arange(1 - seq_len, 1).view(1, 1, 1, seq_len)
alibi_bias, alibi_bias.shape

# shape to be broadcasted is bs x heads x seqlen (query) x seqlen (key)
(tensor([[[[-4, -3, -2, -1,  0]]]]), torch.Size([1, 1, 1, 5]))

Reading the ALiBi paper, my expectation would be an attention bias looking like:

# key positions after 0 will be ignored by attn_mask
tensor([[[[ 0, -1, -2, -3, -4],
          [-1,  0, -1, -2, -3],
          [-2, -1,  0, -1, -2],
          [-3, -2, -1,  0, -1],
          [-4, -3, -2, -1,  0]]]])

tensor([[[[ 0, X, X, X, X],
          [-1,  0, X, X, X],
          [-2, -1,  0, X, X],
          [-3, -2, -1,  0, X],
          [-4, -3, -2, -1,  0]]]])

When full=True (e.g. causal=False and prefix_lm=True), we get the ALiBi bias below. I think this bias is suitable for any type of model: CLM (causal LM), PLM (prefix LM), even BERT-style mask in-filling:

tensor([[[[ 0, -1, -2, -3, -4],
          [-1,  0, -1, -2, -3],
          [-2, -1,  0, -1, -2],
          [-3, -2, -1,  0, -1],
          [-4, -3, -2, -1,  0]]]])
  1. CLM: once the causal mask is applied, positions after 0 will be ignored, e.g.:
tensor([[[[ 0, X, X, X, X],
          [-1,  0, X, X, X],
          [-2, -1,  0, X, X],
          [-3, -2, -1,  0, X],
          [-4, -3, -2, -1,  0]]]])
  2. PLM: once the prefix-LM mask is applied with a prefix window of 3:
tensor([[[[ 0, -1, -2, X, X],
          [-1,  0, -1, X, X],
          [-2, -1,  0, X, X],
          [-3, -2, -1,  0, X],
          [-4, -3, -2, -1,  0]]]])
  3. Mask in-filling: no need to apply any mask:
tensor([[[[ 0, -1, -2, -3, -4],
          [-1,  0, -1, -2, -3],
          [-2, -1,  0, -1, -2],
          [-3, -2, -1,  0, -1],
          [-4, -3, -2, -1,  0]]]])
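
A minimal sketch, assuming seq_len=5 and a prefix length of 3, of how the full bias above could be built and then masked for the CLM and PLM cases (not the repo's actual implementation):

import torch

seq_len, prefix_len = 5, 3
pos = torch.arange(seq_len)

# Symmetric relative-distance bias: entry (i, j) = -|i - j|
full_bias = -(pos.view(-1, 1) - pos.view(1, -1)).abs().float()
full_bias = full_bias.view(1, 1, seq_len, seq_len)  # bs x heads x q_len x k_len

# CLM: queries may only attend to keys at or before their own position
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
clm_bias = full_bias.masked_fill(~causal_mask, float("-inf"))

# PLM: additionally, every query may attend to the first `prefix_len` keys
prefix_mask = causal_mask.clone()
prefix_mask[:, :prefix_len] = True
plm_bias = full_bias.masked_fill(~prefix_mask, float("-inf"))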

vchiley commented May 22, 2023

full=True is the generic correct base case

for CLM, the mask you showed:

tensor([[[[ 0,  X,  X,  X,  X],
          [-1,  0,  X,  X,  X],
          [-2, -1,  0,  X,  X],
          [-3, -2, -1,  0,  X],
          [-4, -3, -2, -1,  0]]]])

and

tensor([[[[-4,  X,  X,  X,  X],
          [-4, -3,  X,  X,  X],
          [-4, -3, -2,  X,  X],
          [-4, -3, -2, -1,  X],
          [-4, -3, -2, -1,  0]]]])

are mathematically equivalent, since softmax(x) = softmax(x + c) for any constant c and the two biases differ only by a per-row (per-query) constant.
The bias tensor on the bottom is equivalent to having tensor([[[[-4, -3, -2, -1, 0]]]]) and broadcasting it along the other dimension, then applying causal masking.
Note: this only works if you are applying causal masking.

Attention is pretty BW (bandwidth) limited; passing in tensor([[[[-4, -3, -2, -1, 0]]]]) instead of the full SxS matrix saves BW and makes the kernel run faster.
If you are running the attention kernel with alibi and are using causal masking, then you should use the tensor([[[[-4, -3, -2, -1, 0]]]]) bias instead of the full SxS tensor.
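
A quick numerical sanity check of that equivalence, as a sketch (not code from the repo): adding the broadcast 1x1x1xS bias and applying the causal mask gives the same softmax output as adding the full SxS bias.

import torch

seq_len = 5
scores = torch.randn(1, 1, seq_len, seq_len)  # stand-in for the q @ k.T attention scores
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Broadcast bias: a single row, broadcast over the query dimension
broadcast_bias = torch.arange(1 - seq_len, 1).view(1, 1, 1, seq_len).float()

# Full symmetric bias: entry (i, j) = -|i - j|
pos = torch.arange(seq_len)
full_bias = -(pos.view(-1, 1) - pos.view(1, -1)).abs().float()

a = (scores + broadcast_bias).masked_fill(~causal_mask, float("-inf")).softmax(dim=-1)
b = (scores + full_bias).masked_fill(~causal_mask, float("-inf")).softmax(dim=-1)
print(torch.allclose(a, b))  # expected: True (up to floating-point error)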

This is noted in this thread as well: ofirpress/attention_with_linear_biases#5

@vchiley vchiley self-assigned this May 22, 2023

KeremTurgutlu commented May 22, 2023

@vchiley Thanks for the detailed and fast response! Yes, I had missed that they are actually mathematically equivalent. What does BW limited mean? I would only expect a memory change when using the smaller broadcast tensor([[[[-4, -3, -2, -1, 0]]]]), but wouldn't imagine a speed difference. I will definitely test and time it, so thanks for pointing it out.

For my use case, I plan to experiment with a mixture of denoisers similar to UL2 and PaLM 2. I plan to have batches that can contain both PLM and CLM examples. If it's more efficient, I could also have one full batch with CLM and another with PLM, but I need to test how much faster that is and whether gradients/performance are affected. So for that, I set prefix_lm=True. For example:

# batch size = 4, prefix_mask for the batch
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], # plm with context length = 3
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.], # plm with context length = 8
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], # plm with context length = 5
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]) # clm

I think the symmetric implementation in this repo looks good to me since I plan to use a decoder-only model. Thanks for the thread reference.


vchiley commented May 22, 2023

BW limited == bandwidth limited.
The bandwidth-to-compute ratio of GPUs means that when a CUDA kernel accesses a float from HBM2 memory, that float should be used about 100 times or else the tensor cores will be starved for data to process. This is what is formally known as being BW limited.
If attn_bias.shape == 1xHx1xS, when a bias term is accessed it can be reused in B*S floating point operations.
If attn_bias.shape == BxHxSxS, when a bias term is accessed it can be used in only 1 floating point operation.
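
To put rough numbers on that size difference, a back-of-envelope sketch assuming B=32, H=16, S=2048 and fp16 (not figures from this thread):

B, H, S, bytes_per_el = 32, 16, 2048, 2  # assumed batch, heads, seq len, fp16

full_bias_bytes = B * H * S * S * bytes_per_el       # BxHxSxS bias
broadcast_bias_bytes = 1 * H * 1 * S * bytes_per_el  # 1xHx1xS bias

print(f"full bias:      {full_bias_bytes / 2**30:.1f} GiB")      # ~4.0 GiB streamed from HBM
print(f"broadcast bias: {broadcast_bias_bytes / 2**10:.1f} KiB")  # ~64 KiB, easily cached and reused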

but wouldn't imagine a speed difference

With attn_bias.shape == BxHxSxS, doing some benchmarking with a 1B model, I remember going from 50%+ MFU down to 42% MFU (vs attn_bias.shape == 1xHx1xS with broadcasting).
There is a speed difference.


KeremTurgutlu commented May 23, 2023

I haven't done granular layer-wise profiling, but in my case using a prefix mask (bs x L) vs causal is almost the same, and prefix_lm is even a bit better. I logged the moving average of throughput for both cases using DeepSpeed fp16 DDP and an MPT-1B model on 2 x A100-80GB with seqlen = 1024.

This is using the torch attn impl; I'm still waiting for the triton update on the flash attention side (the latest NVIDIA NGC docker image I use has PyTorch 2.0). I tried the manual edit approach Dao-AILab/flash-attention#232 but it didn't work. Eventually I hope to test triton attn + torch.compile + prefix LM.

I also tried FSDP DDP, but interestingly it was much slower; that's another discussion though (maybe an amp vs fp16 difference: DeepSpeed could fit bs=16 but FSDP only 12 on an A100-80GB).

Also, the models are loaded from HF, so it is not using the original code from the llm-foundry repo. I don't know if you have further optimizations that might be missing in the HF remote code.

I'm using a simple flag in the same script to test both approaches; otherwise everything else is the same:

def create_prefix_masks(batch_size, sequence_length, prefix_lengths, device):
    # 1 marks prefix (bidirectional-attention) positions, 0 marks causal-only positions
    prefix_mask = torch.zeros(batch_size, sequence_length, device=device)
    for pm, n in zip(prefix_mask, prefix_lengths):
        pm[:n].add_(1)
    return prefix_mask

# Forward pass.
labels = batch.pop("labels")
        
if args.prefix_lm:
    logging.info("Create dummy prefix mask with 0 prefix length - CLM")
    bs = len(batch["input_ids"])
    prefix_mask = create_prefix_masks(bs, args.block_size, prefix_lengths=[0]*bs, device=batch["input_ids"].device)
    outputs = model(**batch, prefix_mask=prefix_mask.bool())
else:
    outputs = model(**batch)
# prefix lm
steps_per_sec: 1.02 steps_per_sec_per_gpu: 0.51 seqs_per_sec: 32.61 seqs_per_sec_per_gpu: 16.30
# causal
steps_per_sec: 0.93 steps_per_sec_per_gpu: 0.47 seqs_per_sec: 29.90 seqs_per_sec_per_gpu: 14.95
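
As a sanity check, the helper above reproduces the prefix-mask layout from the earlier example; a usage sketch assuming prefix lengths 3, 8, 5, 0:

mask = create_prefix_masks(batch_size=4, sequence_length=16,
                           prefix_lengths=[3, 8, 5, 0], device="cpu")
print(mask)  # rows match the PLM/PLM/PLM/CLM mask shown earlier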

plm.log

clm.log


vchiley commented May 23, 2023

The effect I was talking about is only true for attn_impl: triton.
attn_impl: torch will be slow with either causal or prefix_lm.

What is "FSDP DDP"?


vchiley commented May 23, 2023

btw this, once merged, will enable torch2 to be used with the triton kernel.


KeremTurgutlu commented May 24, 2023

The effect I was talking about is only true for attn_impl: triton. attn_impl: torch will be slow with either causal or prefix_lm.

Hmm, I see. I will try the triton impl once it is merged. Is it still based on the Hazy Research implementation? Since the PR adds a new file, it was difficult to see the diff for the changes. This is the diff I could find; I guess you are also pinning to https://github.com/vchiley/triton/tree/main/python. I'm not sure what the problem with triton 2.0.0 is; I will also give this fix a try: triton-lang/triton#1098.

1a2,4
> Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
> update imports to use 'triton_pre_mlir'
> 
46,47c49,50
< import triton
< import triton.language as tl
---
> import triton_pre_mlir as triton
> import triton_pre_mlir.language as tl

What is "FSDP DDP"?

FSDP: Fully Sharded Data Parallel; I used the integration from HF accelerate. DDP is just its distributed data parallel mode, meaning the model is replicated on each core/GPU. Since it is only 1B params we don't need to shard the optimizer, gradients, or parameters.


KeremTurgutlu commented May 25, 2023

Just checked out the latest code, and it seems the pinned triton version works with PyTorch 2.0. However, I didn't see much of a speedup (just ~5%), and it is not compatible with torch.compile since it now uses a custom triton installation.

import torch

def timed(fn):
    # CUDA-event timing: returns (result, elapsed time in seconds)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

def forward(mod,inp):
    return mod(**inp)

def forward_backward(mod,inp):
    loss = mod(**inp).logits.sum()
    loss.backward()
    return loss

q = torch.randn(4,16,8*128).cuda().half()
k = torch.randn(4,16,8*128).cuda().half()
v = torch.randn(4,16,8*128).cuda().half()

%%timeit
out = triton_flash_attn_fn(q,k,v,n_heads=8)

129 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
out = scaled_multihead_dot_product_attention(q,k,v,n_heads=8)

133 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

### Test models

# eager model with triton attn
config = MPTConfig()
config.attn_config['attn_impl'] = "triton"
config.attn_config['qk_ln'] = True
config.attn_config['prefix_lm'] = True
model = MPTForCausalLM(config)
model.cuda().half();

bs = 4
seqlen = 256
x = torch.randint(0,config.vocab_size,(bs,seqlen)).cuda()
prefix_mask = torch.zeros(bs, seqlen).cuda().half()
inp = {"input_ids":x, "prefix_mask":prefix_mask}

print("eager/triton[forward]:", timed(lambda: forward(model, inp))[1])
print("eager/triton[forward_backward]:", timed(lambda: forward_backward(model, inp))[1])

eager/triton[forward]: 0.02111692810058594
eager/triton[forward_backward]: 0.062268417358398435

# eager model with torch attn
config = MPTConfig()
config.attn_config['attn_impl'] = 'torch'
config.attn_config['qk_ln'] = True
config.attn_config['prefix_lm'] = True
model = MPTForCausalLM(config)
model.cuda().half();

print("eager/torch[forward]:", timed(lambda: forward(model, inp))[1])
print("eager/torch[forward_backward]:", timed(lambda: forward_backward(model, inp))[1])

eager/torch[forward]: 0.02424831962585449
eager/torch[forward_backward]: 0.06596710205078125

Did you benchmark torch vs triton? How much improvement did you see?


vchiley commented May 26, 2023

In our experience, attn_impl: triton is more memory efficient and is much faster.

@KeremTurgutlu (Author)

Thanks for sharing. I would be very interested to see your throughput (e.g. seq/sec/GPU) and wall-time numbers for training the MPT 1B and 7B models, along with the hardware specs. Maybe this is a PyTorch 2.0 difference; I am not sure if they optimized a few things and whether that is causing this small gap between torch and triton.


vchiley commented May 27, 2023

The tables here list the throughput of a large set of networks (using attn_impl: triton).
Some of the highlights can be seen here.

If, using seq len 1024, you can get a 3B model using attn_impl: torch to produce almost 62% flop utilization, let me know how you did that 👀

attn_impl: torch uses way more memory and therefore requires using smaller micro batch sizes during training. A lower micro batch size results in much lower FLOP utilization and throughput.
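
As a rough illustration of why: a non-fused attention implementation materializes the full B x H x S x S score matrix, which a fused kernel avoids. A back-of-envelope sketch assuming B=16, H=16, S=2048 and fp16 (not measured figures):

B, H, S, bytes_per_el = 16, 16, 2048, 2  # assumed batch, heads, seq len, fp16

scores_bytes = B * H * S * S * bytes_per_el  # B*H score matrices of size S x S
print(f"~{scores_bytes / 2**30:.1f} GiB per attention layer")  # ~2 GiB, before softmax/dropout copies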


KeremTurgutlu commented May 27, 2023

This is very useful, thanks for sharing. So far I've only tested with 2 x A100-80GB GPUs / PyTorch 2.0 / torch attn impl / DeepSpeed (ZeRO disabled) DDP fp16, using the 1B model and 1024 seqlen. With this setup I got ~14-15 seq/sec/GPU, but looking at your table you are able to get ~21 seq/sec/GPU, which is pretty impressive. I will run some tests similar to your setup.

With triton I am able to increase the batch size from 16 to 24, which increased throughput to 18.6 seq/sec/GPU. If I can fix the apex installation (which broke after installing llm-foundry from source; it installed a different CUDA extension for the pinned triton) and use fused Adam / fused cross entropy, it could increase further and hopefully get close to your 21 seq/sec/GPU.

Hopefully flash attn triton can get updated for the latest triton version so we can combine it with torch.compile.


vchiley commented May 27, 2023

Note: those tables use the base model config here (without alibi or prefix mask)

@vchiley vchiley closed this as completed May 27, 2023