Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
tridao and timlacroix committed Sep 26, 2023
1 parent 4c8ff91 commit 083e8f5
Showing 9 changed files with 706 additions and 255 deletions.
58 changes: 45 additions & 13 deletions csrc/flash_attn/flash_api.cpp
@@ -40,7 +40,8 @@ void set_params_fprop(Flash_fwd_params &params,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
int window_size_left,
int window_size_right) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -105,7 +106,15 @@ void set_params_fprop(Flash_fwd_params &params,
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f);

params.is_causal = is_causal;
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
params.is_causal = window_size_left < 0 && window_size_right == 0;

if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;

params.is_seqlens_k_cumulative = true;
}
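The hunk above encodes the whole causal/local relationship in set_params_fprop: causal attention is just the window (unbounded left, 0 right), and a negative size on either side means "no limit on that side", which the kernels then see as a window of seqlen_k. Below is a minimal standalone C++ sketch of that normalization; the Window struct and normalize_window helper are illustrative names, not part of the API.

#include <cassert>

struct Window { int left, right; bool is_causal; };

// Mirrors the window handling in set_params_fprop above: detect the causal special
// case, then replace unbounded (negative) sides with seqlen_k so the kernels only
// ever deal with finite windows.
Window normalize_window(int window_size_left, int window_size_right, int seqlen_k) {
    Window w;
    w.is_causal = window_size_left < 0 && window_size_right == 0;
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
    w.left = window_size_left;
    w.right = window_size_right;
    return w;
}

int main() {
    // Full attention: both sides unbounded, nothing is clamped and is_causal stays false.
    Window full = normalize_window(-1, -1, 4096);
    assert(!full.is_causal && full.left == -1 && full.right == -1);

    // Causal attention: unbounded look-back, no look-ahead.
    Window causal = normalize_window(-1, 0, 4096);
    assert(causal.is_causal && causal.left == 4096 && causal.right == 0);

    // Sliding-window (local) attention, e.g. a 4096-token look-back over an 8192-token context.
    Window local = normalize_window(4096, 0, 8192);
    assert(!local.is_causal && local.left == 4096 && local.right == 0);
    return 0;
}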

@@ -138,7 +147,8 @@ void set_params_dgrad(Flash_bwd_params &params,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
int window_size_left,
int window_size_right) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -149,7 +159,8 @@ void set_params_dgrad(Flash_bwd_params &params,
softmax_lse_d,
p_dropout,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@@ -242,6 +253,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const float p_dropout,
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {

@@ -281,10 +294,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -353,7 +367,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -421,9 +436,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -534,7 +552,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -600,8 +619,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -748,7 +771,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
@@ -804,9 +828,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state
) {
c10::optional<at::Tensor> &rng_state) {

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -969,7 +996,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
@@ -1019,6 +1047,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {
@@ -1059,10 +1089,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1125,7 +1156,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
is_causal);
window_size_left,
window_size_right);

at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
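On the decode path (seqlen_q == 1), the hunks above keep the q-transpose trick but now skip it whenever a finite window is requested: the swap additionally requires window_size_left < 0 && window_size_right < 0. As a rough illustration of what that reshape does, here is a small libtorch sketch of the shape change, under the assumption that the point of the trick is to fold the query-head groups of GQA into the sequence dimension so each KV head processes ngroups query rows instead of one.

#include <torch/torch.h>
#include <cstdint>
#include <vector>

int main() {
    const int64_t batch_size = 2, num_heads = 32, num_heads_k = 8, head_size = 128;
    const int64_t ngroups = num_heads / num_heads_k;  // 4 query heads share each KV head

    // Decode-time query: (b, seqlen_q = 1, h, d)
    auto q = torch::randn({batch_size, 1, num_heads, head_size});

    // (b, 1, (nheads_kv ngroups), d) -> (b, ngroups, nheads_kv, d), as in the diff above.
    auto q_swapped = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);

    // The "group" axis now plays the role of seqlen_q, and the head axis matches num_heads_k.
    std::vector<int64_t> expected = {batch_size, ngroups, num_heads_k, head_size};
    TORCH_CHECK(q_swapped.sizes().vec() == expected);
    return 0;
}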
3 changes: 3 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -105,6 +105,9 @@ struct Flash_fwd_params : public Qkv_params {
float rp_dropout;
float scale_softmax_rp_dropout;

// Local window size
int window_size_left, window_size_right;

// Random state.
at::PhiloxCudaState philox_args;

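flash.h now carries the window on the params struct. In terms of what the kernels enforce (see the apply_mask_local call in the backward kernel below), the two fields bound how far a query may look left and right of the diagonal that aligns the end of the query sequence with the end of the key sequence. Below is a small sketch of that visibility rule, stated per element rather than per block; it is an interpretation of the masking arithmetic in this commit, not code from the repo, and is_visible is an illustrative name.

#include <cstdio>

// Query position i may attend to key position j iff j lies within
// [diag - window_size_left, diag + window_size_right], where diag aligns the
// last query with the last key (bottom-right aligned diagonal).
bool is_visible(int i, int j, int seqlen_q, int seqlen_k,
                int window_size_left, int window_size_right) {
    const int diag = i + seqlen_k - seqlen_q;
    return j >= diag - window_size_left && j <= diag + window_size_right && j < seqlen_k;
}

int main() {
    // Sliding-window attention with a 3-token look-back and no look-ahead,
    // i.e. window_size = (3, 0), on an 8x8 attention map: query 5 sees keys 2..5.
    for (int j = 0; j < 8; ++j) {
        std::printf("%d", is_visible(/*i=*/5, j, /*seqlen_q=*/8, /*seqlen_k=*/8,
                                     /*window_size_left=*/3, /*window_size_right=*/0) ? 1 : 0);
    }
    std::printf("\n");  // prints 00111100
    return 0;
}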
66 changes: 59 additions & 7 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -422,7 +422,7 @@ inline __device__ void convert_dKV(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

using Element = typename Kernel_traits::Element;
@@ -447,6 +447,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
if (Is_local) {
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
}

const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
// We're guaranteed that m_block_min <= m_block:
int m_block_min = (!Is_causal && !Is_local)
? 0
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
// If not local, we're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
// However, if local, it is possible to have some blocks of K & V not attending to any query.
// We might need to exit early and write 0 to dK and dV for those blocks.
// Otherwise we would get wrong results for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
if (Is_local && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}

if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
tQsQ.data() = tQsQ.data() + size(sQ);
@@ -777,12 +819,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// However, it's possible that the values in acc_s are so large that they overflow
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
// So we need to mask out the elements beyond actual_seqlen_k.
if (!Is_causal) {
if (!Is_causal && !Is_local) {
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
flash::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
}
} else {
} else if (Is_causal) {
// Putting this causal masking right after acc_s is *much* slower for some reason.
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), then in the 2nd block of seqlen_q (from 128 to 255) we're not doing causal masking.
@@ -795,6 +837,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
} else if (Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, AtomLayoutMS * 16,
params.window_size_left, params.window_size_right);
}

}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
@@ -1510,7 +1562,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

const int n_block = blockIdx.x;
@@ -1519,7 +1571,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;

compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
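For the backward pass, the main change above is that a column block of K/V no longer interacts with every row block of Q: m_block_max is clamped using window_size_left, m_block_min is raised above zero using window_size_right, and a block whose range ends up empty exits early after writing zeros to dK and dV. Below is a host-side sketch of that range computation with illustrative names (the kernel does the same arithmetic with cute::ceil_div and binfo); it is a reading aid, not kernel code.

#include <algorithm>
#include <cstdio>

struct MBlockRange { int min, max; };  // row blocks [min, max) of Q touch this K/V block

MBlockRange m_block_range(int n_block, int kBlockM, int kBlockN,
                          int seqlen_q, int seqlen_k,
                          int window_size_left, int window_size_right,
                          bool is_causal, bool is_local) {
    int m_block_max = (seqlen_q + kBlockM - 1) / kBlockM;  // cute::ceil_div(seqlen_q, kBlockM)
    if (is_local) {
        m_block_max = std::min(m_block_max,
            ((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left + kBlockM - 1) / kBlockM);
    }
    int m_block_min = (!is_causal && !is_local)
        ? 0
        : std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
    // The early-exit path in the kernel corresponds to min >= max here.
    return {m_block_min, m_block_max};
}

int main() {
    // seqlen 1024 with 128x128 blocks and a (128, 0) sliding window: each K/V block
    // is needed by at most two Q row blocks, e.g. n_block 0 -> [0, 2), n_block 7 -> [7, 8).
    for (int n_block = 0; n_block < 8; ++n_block) {
        MBlockRange r = m_block_range(n_block, 128, 128, 1024, 1024,
                                      /*window_size_left=*/128, /*window_size_right=*/0,
                                      /*is_causal=*/false, /*is_local=*/true);
        std::printf("n_block %d -> m_blocks [%d, %d)\n", n_block, r.min, r.max);
    }
    return 0;
}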