ALiBi with causal=True: unexpected bias? #186

Comments
for CLM, the two masks you showed are mathematically equivalent (since softmax(x) = softmax(x + c) for any constant c). Attention is pretty BW limited; passing in the smaller broadcast bias moves less data. This is noted in this thread as well: ofirpress/attention_with_linear_biases#5
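A quick numerical check of the shift-invariance argument (a sketch with toy scores and seq_len=5; the variable names are mine and the per-head slope is ignored): under a causal mask, the broadcast bias and the per-query relative bias differ only by a constant per query row, so the attention weights come out identical.

```python
import torch

L = 5
# Broadcast bias (one row, same for every query), as built with causal=True.
broadcast_bias = -(L - 1 - torch.arange(L)).float()                # [-4, -3, -2, -1, 0]
# Per-query relative-distance bias, as pictured in the ALiBi paper.
q_pos = torch.arange(L).unsqueeze(1)
k_pos = torch.arange(L).unsqueeze(0)
relative_bias = -(q_pos - k_pos).clamp(min=0).float()              # lower triangle: -(i - j)

causal_mask = torch.tril(torch.ones(L, L)).bool()
scores = torch.randn(L, L)                                         # toy q·k attention scores

def masked_softmax(s, bias):
    s = s + bias
    s = s.masked_fill(~causal_mask, float("-inf"))
    return s.softmax(dim=-1)

# For query i the two biases differ only by the constant i - (L - 1) over the unmasked keys,
# and softmax(x) == softmax(x + c), so the resulting attention weights match.
print(torch.allclose(masked_softmax(scores, broadcast_bias),
                     masked_softmax(scores, relative_bias)))       # True
```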
@vchiley Thanks for the detailed and fast response! Yes, I had missed that they are actually mathematically equivalent. What does BW limited mean? I would only expect a memory change when using a smaller broadcasted bias.

For my use case, I plan to experiment with a mixture of denoisers, similar to UL2 and PaLM 2. I plan to have batches that can contain both PLM and CLM examples. If it's more efficient, I can also have a full batch with CLM and another with PLM, but I need to test how much faster it is and whether gradients/performance are affected. So for that, I set …

I think the symmetric implementation in this repo looks good to me since I plan to use a decoder-only model; thanks for the thread reference.
BW limited == Bandwidth limited
With …
I haven't done granular layer-wise profiling, but in my case using … This is using the torch attn impl; I am still waiting for the triton update on the flash-attention side (the latest nvidia NGC docker image I use has pytorch 2.0). I tried the manual-edit approach from Dao-AILab/flash-attention#232 but it didn't work. Eventually I hope to test triton attn + torch.compile + prefix lm.

I also tried FSDP DDP but, interestingly, it was much slower; that's another discussion though (maybe an amp vs fp16 difference - deepspeed could fit bs=16 but fsdp only 12 on an A100-80GB). Also, the models are loaded from HF, so it is not using the original code from the llm-foundry repo - I don't know if you have further optimizations which might be missing in the HF remote code.

I am using a simple flag in the same script to test both approaches; otherwise everything else is the same:

```python
def create_prefix_masks(batch_size, sequence_length, prefix_lengths, device):
    # 1s mark the bidirectional prefix, 0s the causal continuation.
    prefix_mask = torch.zeros(batch_size, sequence_length, device=device)
    for pm, n in zip(prefix_mask, prefix_lengths):
        pm[:n].add_(1)
    return prefix_mask

# Forward pass.
labels = batch.pop("labels")
if args.prefix_lm:
    logging.info("Create dummy prefix mask with 0 prefix length - CLM")
    bs = len(batch["input_ids"])
    prefix_mask = create_prefix_masks(
        bs, args.block_size, prefix_lengths=[0] * bs, device=batch["input_ids"].device
    )
    outputs = model(**batch, prefix_mask=prefix_mask.bool())
else:
    outputs = model(**batch)
```
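If the per-example loop in create_prefix_masks ever shows up in profiling, a vectorized sketch could look like the following (my own; the name and drop-in signature are assumptions, and it returns a bool mask directly so the later .bool() call becomes redundant):

```python
import torch

def create_prefix_masks_vectorized(batch_size, sequence_length, prefix_lengths, device):
    # True where position < prefix_length, i.e. inside the bidirectional prefix.
    positions = torch.arange(sequence_length, device=device).unsqueeze(0)  # (1, S)
    lengths = torch.as_tensor(prefix_lengths, device=device).unsqueeze(1)  # (B, 1)
    return positions < lengths  # (B, S) bool; batch_size kept only for signature compatibility
```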
The effect I was talking about is only true for …

What is "FSDP DDP"?
btw this PR, once merged, will enable torch2 to be used with the triton kernel.
hmm, I see. I will try the triton impl once it is merged. Is it still based on the hazy implementation? Since the PR adds a new file, it was difficult to see the diff for the changes. This is the diff I could find; I guess you are also pinning to https://github.com/vchiley/triton/tree/main/python. Not sure what the problem with triton 2.0.0 is; I will also give this fix a try: triton-lang/triton#1098.

FSDP: Fully Sharded Data Parallel; I used the integration from HF accelerate. DDP is just its distributed data parallel mode, meaning the model is copied to each core/GPU. Since it is only 1B params, we don't need to shard the optimizer, gradients, or parameters.
Just checked out the latest code and it seems like the pinned triton version works with pytorch 2.0. Although I didn't see much of a speedup (just ~5%), and it is not compatible with …

```python
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

def forward(mod, inp):
    return mod(**inp)

def forward_backward(mod, inp):
    loss = mod(**inp).logits.sum()
    loss.backward()
    return loss

q = torch.randn(4, 16, 8 * 128).cuda().half()
k = torch.randn(4, 16, 8 * 128).cuda().half()
v = torch.randn(4, 16, 8 * 128).cuda().half()
```

```python
%%timeit
out = triton_flash_attn_fn(q, k, v, n_heads=8)
```

129 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

```python
%%timeit
out = scaled_multihead_dot_product_attention(q, k, v, n_heads=8)
```

133 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
### Test models
```python
# eager model with triton attn
config = MPTConfig()
config.attn_config['attn_impl'] = "triton"
config.attn_config['qk_ln'] = True
config.attn_config['prefix_lm'] = True
model = MPTForCausalLM(config)
model.cuda().half();

bs = 4
seqlen = 256
x = torch.randint(0, config.vocab_size, (bs, seqlen)).cuda()
prefix_mask = torch.zeros(bs, seqlen).cuda().half()
inp = {"input_ids": x, "prefix_mask": prefix_mask}

print("eager/triton[forward]:", timed(lambda: forward(model, inp))[1])
print("eager/triton[forward_backward]:", timed(lambda: forward_backward(model, inp))[1])
```

```
eager/triton[forward]: 0.02111692810058594
eager/triton[forward_backward]: 0.062268417358398435
```

```python
# eager model with torch attn
config = MPTConfig()
config.attn_config['attn_impl'] = 'torch'
config.attn_config['qk_ln'] = True
config.attn_config['prefix_lm'] = True
model = MPTForCausalLM(config)
model.cuda().half();

print("eager/torch[forward]:", timed(lambda: forward(model, inp))[1])
print("eager/torch[forward_backward]:", timed(lambda: forward_backward(model, inp))[1])
```

```
eager/torch[forward]: 0.02424831962585449
eager/torch[forward_backward]: 0.06596710205078125
```

Did you benchmark torch vs triton? How much improvement did you see?
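A side note on the timed() numbers above: each is a single run, and for triton the first call also pays JIT compilation/autotuning cost, so a warm-up pass plus averaging usually gives steadier comparisons. A minimal sketch (my own helper, reusing the forward and inp defined above):

```python
import torch

def timed_avg(fn, warmup=3, iters=10):
    # Warm-up absorbs one-time costs (CUDA init, triton JIT compilation, autotuning).
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000 / iters  # mean seconds per call

# e.g. print("eager/triton[forward]:", timed_avg(lambda: forward(model, inp))[1])
```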
In our experience, …
Thanks for sharing. I would be very interested to see your throughput (e.g. seq/sec/gpu) and wall-time numbers for training MPT 1B and 7B models, along with the hardware specs. Maybe this is a pytorch 2.0 difference; I am not sure if they optimized a few things and whether that is causing this small difference between torch and triton.
The tables here list the throughput of a large set of networks (using …).

If, using seq len 1024, you can get a 3B model using …
This is very useful, thanks for sharing. So far I've only tested with 2 A100-80GB GPUs / pytorch 2.0 / torch attn impl / deepspeed (zero disabled) DDP fp16, using a 1B model and 1024 seqlen. With this setup I got ~14-15 seq/sec/gpu, but looking at your table you are able to get ~21 seq/sec/gpu, which is pretty impressive. I will run some tests similar to your setup.

With triton I am able to increase the batch size 16 -> 24, which increased the throughput to 18.6 seq/sec/gpu. If I can fix the apex installation (which got broken after installing llm-foundry from source - it installed a different cuda extension for the pinned triton) and use fused adam / fused cross entropy, it can increase further and hopefully get close to your 21 seq/sec/gpu. Hopefully flash attn triton can get updated for the latest triton version so we can combine it with …
Note: those tables use the base model config here (without alibi or prefix mask) |
Looking at the alibi bias generation function, it seems to me that when causal=True a fixed relative position bias (let's ignore the head constant for now) is broadcast to all queries in the q-k attention scores. Let's assume seq_len=5.

When we have full=False, e.g. causal=True and prefix_lm=False, the alibi bias we get before the head constant is applied is a single row broadcast to every query (first pattern in the sketch below). Reading the ALiBi paper, my expectation would be an attention bias with per-query relative offsets (second pattern in the sketch below).

When full=True, e.g. causal=False and prefix_lm=True, we get the symmetric alibi bias (third pattern in the sketch below). I think this alibi bias is suitable for any type of model, e.g. CLM (causal LM), PLM (prefix LM), even BERT-style mask filling.