Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torch.repeat instead of expand on key & value in Triton MQA to prevent NaNs with certain h_dims #442

Merged
merged 3 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,10 @@ def triton_flash_attn_fn(
h=1 if multiquery else n_heads)

if multiquery:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
# - pytorch docs
#
# hopefully the kernels can utilize this and we're jot just wasting BW here
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
# necessary to repeat instead of expand tensor because
# output contains NaN in edge cases such as with head dimension = 8
key = key.repeat(1, 1, n_heads, 1)
value = value.repeat(1, 1, n_heads, 1)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,3 +1385,46 @@ def test_hf_init(tmp_path,
updated_params = next(model.parameters()).clone().data

assert not torch.equal(original_params, updated_params)


@pytest.mark.gpu
def test_head_dim_8_triton_mqa_attn(batch_size=2):
test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
test_cfg.device = torch.cuda.current_device()

test_cfg.batch_size = batch_size

hf_config = MPTConfig(
init_device='cpu',
d_model=128,
n_heads=16,
n_layers=1,
expansion_ratio=2,
max_seq_len=128,
emb_pdrop=0.1,
resid_pdrop=0.2,
attn_config={
'attn_impl': 'triton',
'attn_type': 'multiquery_attention'
},
)
test_cfg.device = torch.cuda.current_device()

tokenizer = build_tokenizer(test_cfg.tokenizer)

mpt = MPTForCausalLM(hf_config)

model = HuggingFaceModelWithZLoss(mpt, tokenizer, shift_labels=True)

model = model.to(test_cfg.device)
batch = gen_random_batch(batch_size, test_cfg)

assert batch['input_ids'].shape == torch.Size(
[batch_size, test_cfg.max_seq_len])

model.train()

with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
output = model(batch)

assert not torch.isnan(output.logits).any()