Grouped Query Attention + Refactor Attn #492

Merged
merged 31 commits on Aug 11, 2023
Changes from 15 commits
Commits (31)
f5411bb
Add grouped query attention implementation + refactor
Jul 26, 2023
35d410a
add config description and support in blocks, unpadding and fix test
Jul 27, 2023
5d26ff6
add tests and clone attn_config
sashaDoubov Jul 27, 2023
d07de71
Merge branch 'main' into generalized_attn
sashaDoubov Jul 27, 2023
952244a
fix comment
sashaDoubov Jul 27, 2023
b950eb0
fix bad pop of attn_impl
sashaDoubov Jul 27, 2023
f5af9d8
change init in test_model to be cpu and then move to gpu for stability
sashaDoubov Jul 27, 2023
985adc2
comment out test temporarily
sashaDoubov Jul 27, 2023
ff4c12c
undo commenting out test
sashaDoubov Jul 31, 2023
e8d61d5
Merge branch 'main' into generalized_attn
sashaDoubov Aug 1, 2023
9f792bf
Merge branch 'main' of github.com:mosaicml/llm-foundry into HEAD
sashaDoubov Aug 3, 2023
d84cfa9
Fixing pyright issues
sashaDoubov Aug 3, 2023
916d8dd
Merge branch 'main' of github.com:mosaicml/llm-foundry into generaliz…
sashaDoubov Aug 3, 2023
a663dde
fix return types
sashaDoubov Aug 3, 2023
f873835
fix typo
sashaDoubov Aug 3, 2023
a53d366
rename to GQA
sashaDoubov Aug 3, 2023
89bf249
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 3, 2023
083c033
update with pre-commit
sashaDoubov Aug 3, 2023
cf3f039
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 9, 2023
c0e979e
Update llmfoundry/models/mpt/modeling_mpt.py
sashaDoubov Aug 9, 2023
273a9aa
Merge branch 'main' of github.com:mosaicml/llm-foundry into generaliz…
sashaDoubov Aug 9, 2023
1fa9a19
fix asserts and add deprecation
Aug 10, 2023
3dcbc45
fix docstring and change mutable construct
Aug 10, 2023
24e2257
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
adbdce0
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
886b709
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
b3d19ca
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
1367c2a
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
ab78c51
Update llmfoundry/models/layers/attention.py
sashaDoubov Aug 10, 2023
a71693c
change to set and remove extraneous
Aug 10, 2023
85751ba
add default optional + deprecation
sashaDoubov Aug 11, 2023
232 changes: 104 additions & 128 deletions llmfoundry/models/layers/attention.py
@@ -36,6 +36,7 @@ def scaled_multihead_dot_product_attention(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: int,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
@@ -44,11 +45,9 @@ def scaled_multihead_dot_product_attention(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
torch.Tensor]]]:
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
kv_n_heads = 1 if multiquery else n_heads
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)

@@ -68,6 +67,11 @@ def scaled_multihead_dot_product_attention(
b, _, s_q, d = q.shape
s_k = k.size(-1)

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
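For readers tracing the torch path, the grouped-query branch simply repeats each KV head along the head dimension (dim=1) so the usual batched matmul runs unchanged. Below is a minimal standalone sketch of that bookkeeping; the head counts and sizes are invented for illustration and are not taken from this PR.

import math
import torch

b, s, d_head = 2, 16, 64
n_heads, kv_n_heads = 8, 2                    # 4 query heads share each KV head

q = torch.randn(b, n_heads, s, d_head)
k = torch.randn(b, kv_n_heads, d_head, s)     # keys pre-transposed, as in the torch path
v = torch.randn(b, kv_n_heads, s, d_head)

# grouped query case: replicate each KV head once per query head in its group
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)

attn_weight = torch.softmax(q.matmul(k) / math.sqrt(d_head), dim=-1)
out = attn_weight.matmul(v)
assert out.shape == (b, n_heads, s, d_head)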

@@ -143,6 +147,7 @@ def flash_attn_fn(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: int,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
@@ -151,7 +156,6 @@ def flash_attn_fn(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -189,16 +193,13 @@ def flash_attn_fn(

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
key, key_padding_mask)
key_unpad = rearrange(key_unpad,
'nnz (h d) -> nnz h d',
h=1 if multiquery else n_heads)
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
value_unpad = rearrange(value_unpad,
'nnz (h d) -> nnz h d',
h=1 if multiquery else n_heads)
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

if multiquery:
# multi-query case
if kv_n_heads == 1:
# Expanding a tensor does not allocate new memory, but only creates a new
# view on the existing tensor where a dimension of size one is expanded
# to a larger size by setting the stride to 0.
@@ -209,6 +210,14 @@
key_unpad.size(-1))
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
value_unpad.size(-1))
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads.
# We repeat each kv head by the group size so we can reuse the underlying MHA kernels;
# the repetition is done along the head dimension (dim=1).
key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
dim=1)

dropout_p = dropout_p if training else 0.0
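The comments above draw a distinction worth keeping in mind: the multi-query case can use expand, which only creates a zero-stride view, while the grouped-query case uses repeat_interleave, which materializes copies. A small sketch of that difference, with invented sizes:

import torch

nnz, d_head, n_heads, kv_n_heads = 128, 64, 8, 2

# multi-query: one KV head broadcast to all query heads as a zero-stride view (no copy)
k_mqa = torch.randn(nnz, 1, d_head)
k_mqa = k_mqa.expand(nnz, n_heads, d_head)
assert k_mqa.stride(1) == 0                   # expanded dim has stride 0; no new memory

# grouped query: each of the kv_n_heads heads is copied for its group of query heads
k_gqa = torch.randn(nnz, kv_n_heads, d_head)
k_gqa = k_gqa.repeat_interleave(n_heads // kv_n_heads, dim=1)
assert k_gqa.shape == (nnz, n_heads, d_head)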

@@ -238,6 +247,7 @@ def triton_flash_attn_fn(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: int,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
@@ -246,7 +256,6 @@ def triton_flash_attn_fn(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
torch.Tensor]]]:
try:
@@ -318,16 +327,22 @@ def triton_flash_attn_fn(
torch.finfo(query.dtype).min)

query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
value = rearrange(value,
'b s (h d) -> b s h d',
h=1 if multiquery else n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)

if multiquery:
# multi-query case
if kv_n_heads == 1:
# necessary to repeat instead of expand tensor because
# output contains NaN in edge cases such as with head dimension = 8
key = key.repeat(1, 1, n_heads, 1)
value = value.repeat(1, 1, n_heads, 1)
# grouped query case
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads.
# We repeat each kv head by the group size so we can reuse the underlying MHA kernels;
# here the repetition is done along dim=2, unlike the flash and torch implementations.
key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
@@ -338,8 +353,8 @@
return output, None, past_key_value
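Unlike the flash and torch paths, the triton path keeps heads on dim 2 ('b s h d'), so the grouped-query repeat happens along that axis. A quick shape check under assumed sizes:

import torch

b, s, d_head, n_heads, kv_n_heads = 2, 16, 64, 8, 2

key = torch.randn(b, s, kv_n_heads, d_head)   # 'b s h d' layout expected by the triton kernel
key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
assert key.shape == (b, s, n_heads, d_head)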


class MultiheadAttention(nn.Module):
"""Multi-head self attention.
class GeneralizedAttention(nn.Module):
"""A generalization of Multi-head, Multi-Query, Multi-Grouped Attention.

Using torch or triton attention implementation enables user to also use
additive bias.
@@ -349,6 +364,7 @@ def __init__(
self,
d_model: int,
n_heads: int,
kv_n_heads: int,
attn_impl: str = 'triton',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
@@ -367,6 +383,14 @@ def __init__(

self.d_model = d_model
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads

self.head_dim = d_model // n_heads

assert self.kv_n_heads > 0, 'kv_n_heads should be greater than zero'
assert self.kv_n_heads <= self.n_heads, 'The number of KV heads should be less than or equal to Q heads'
assert self.n_heads % self.kv_n_heads == 0, 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads'

self.softmax_scale = softmax_scale
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
@@ -377,17 +401,21 @@ def __init__(
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
3 * self.d_model,
self.d_model + 2 * self.kv_n_heads * self.head_dim,
**fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = (d_model, 2 * d_model)
fuse_splits = [
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits) # type: ignore

if self.qk_ln:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
self.q_ln = norm_class(self.d_model, device=device)
self.k_ln = norm_class(self.d_model, device=device)
self.k_ln = norm_class(self.kv_n_heads * self.head_dim,
device=device)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
@@ -432,7 +460,11 @@ def forward(
if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.chunk(3, dim=2)
query, key, value = qkv.split([
self.d_model, self.kv_n_heads * self.head_dim,
self.kv_n_heads * self.head_dim
],
dim=2)

key_padding_mask = attention_mask
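With grouped-query support, the fused Wqkv projection emits d_model features for the query plus kv_n_heads * head_dim each for key and value, and forward() splits along the feature dimension to match. The sketch below illustrates that sizing with a plain nn.Linear standing in for FC_CLASS_REGISTRY[fc_type]; the concrete numbers are assumptions, not values from the PR.

import torch
import torch.nn as nn

d_model, n_heads, kv_n_heads = 512, 8, 2
head_dim = d_model // n_heads

# query keeps d_model features; key and value each get kv_n_heads * head_dim
Wqkv = nn.Linear(d_model, d_model + 2 * kv_n_heads * head_dim)

x = torch.randn(4, 16, d_model)               # (batch, seq, d_model)
qkv = Wqkv(x)
query, key, value = qkv.split(
    [d_model, kv_n_heads * head_dim, kv_n_heads * head_dim], dim=2)
assert query.shape == (4, 16, d_model)
assert key.shape == value.shape == (4, 16, kv_n_heads * head_dim)

# per-head boundaries used for shape-based init of the fused layer
fuse_splits = [i * head_dim for i in range(1, n_heads + 2 * kv_n_heads)]
assert fuse_splits[-1] == d_model + 2 * kv_n_heads * head_dim - head_dim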

@@ -447,6 +479,7 @@ def forward(
key,
value,
self.n_heads,
self.kv_n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
@@ -460,8 +493,8 @@ def forward(
return self.out_proj(context), attn_weights, past_key_value


class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
class MultiheadAttention(GeneralizedAttention):
"""Multi-head self attention.

Using torch or triton attention implementation enables user to also use
additive bias.
@@ -481,113 +514,55 @@ def __init__(
verbose: int = 0,
device: Optional[str] = None,
):
super().__init__()

self.attn_impl = attn_impl
self.clip_qkv = clip_qkv
self.qk_ln = qk_ln

self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.softmax_scale = softmax_scale
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.head_dim)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
if fc_type != 'te':
fc_kwargs['device'] = device
# NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll
# want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but
# Wkv shouldn't be TensorParallel
# - vchiley
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
d_model,
d_model + 2 * self.head_dim,
**fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = (d_model, d_model + self.head_dim)
self.Wqkv._fused = (0, fuse_splits) # type: ignore

if self.qk_ln:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
self.q_ln = norm_class(d_model, device=device)
self.k_ln = norm_class(self.head_dim, device=device)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'triton':
self.attn_fn = triton_flash_attn_fn
if verbose:
warnings.warn(
'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
'it uses more memory. When training larger models this can trigger ' +\
'alloc retries which hurts performance. If encountered, we recommend ' +\
'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
)
elif self.attn_impl == 'torch':
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available() and verbose:
warnings.warn(
'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
'`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
'we recommend using `attn_impl: triton`.'
)
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')
super().__init__(
d_model=d_model,
n_heads=n_heads,
kv_n_heads=n_heads,  # for MHA, the number of KV heads equals n_heads
attn_impl=attn_impl,
clip_qkv=clip_qkv,
qk_ln=qk_ln,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
verbose=verbose,
device=device)


class MultiQueryAttention(GeneralizedAttention):
"""Multi-Query self attention.

self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
)
self.out_proj._is_residual = True # type: ignore
Using torch or triton attention implementation enables user to also use
additive bias.
"""

def forward(
def __init__(
self,
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = True,
needs_weights: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.split(
[self.d_model, self.head_dim, self.head_dim], dim=2)

key_padding_mask = attention_mask

if self.qk_ln:
# Applying layernorm to qk
dtype = query.dtype
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
multiquery=True,
)

return self.out_proj(context), attn_weights, past_key_value
d_model: int,
n_heads: int,
attn_impl: str = 'triton',
clip_qkv: Optional[float] = None,
qk_ln: bool = False,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
super().__init__(
d_model=d_model,
n_heads=n_heads,
kv_n_heads=1, # for MQA, 1 head
attn_impl=attn_impl,
clip_qkv=clip_qkv,
qk_ln=qk_ln,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
verbose=verbose,
device=device)


def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
@@ -678,4 +653,5 @@ def build_alibi_bias(
ATTN_CLASS_REGISTRY = {
'multihead_attention': MultiheadAttention,
'multiquery_attention': MultiQueryAttention,
'grouped_query_attention': GeneralizedAttention
}
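A hypothetical usage sketch of the extended registry, assuming llm-foundry at this revision is importable; the hyperparameter values below are made up for illustration.

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY

attn_cls = ATTN_CLASS_REGISTRY['grouped_query_attention']  # GeneralizedAttention in this diff
attn = attn_cls(
    d_model=512,
    n_heads=8,
    kv_n_heads=2,       # each group of 4 query heads shares one KV head
    attn_impl='torch',  # avoids the flash-attn dependency for this sketch
)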