diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index d8fea1f990..6af87346c8 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -19,8 +19,12 @@ jobs: strategy: matrix: include: - - name: 'cpu' - container: mosaicml/pytorch:latest + - name: 'cpu-latest' + container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04 + markers: 'not gpu' + pytest_command: 'coverage run -m pytest' + - name: 'cpu-2.0.1' + container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04 markers: 'not gpu' pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 3494ea4f75..d228802ddc 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -19,8 +19,12 @@ jobs: strategy: matrix: include: - - name: 'gpu' - container: mosaicml/pytorch:latest + - name: 'gpu-latest' + container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 + markers: 'gpu' + pytest_command: 'coverage run -m pytest' + - name: 'gpu-2.0.1' + container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 markers: 'gpu' pytest_command: 'coverage run -m pytest' name: ${{ matrix.name }} diff --git a/.github/workflows/pytest-cpu.yaml b/.github/workflows/pytest-cpu.yaml index 45f130100a..c5fe309cf3 100644 --- a/.github/workflows/pytest-cpu.yaml +++ b/.github/workflows/pytest-cpu.yaml @@ -27,7 +27,7 @@ jobs: set -ex export PATH=/composer-python:$PATH python -m pip install --upgrade 'pip<23' wheel - python -m pip install --upgrade .[all] + python -m pip install --upgrade .[dev] - name: Run Tests id: tests run: | diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a4d06135f7..ffd132c27c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -32,10 +32,11 @@ jobs: PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)" fi - # Remove the xentropy-cuda-lib dependency as PyPI does not support direct installs. The - # error message for importing FusedCrossEntropy gives instructions on how to install if a - # user tries to use it without this dependency. + # Remove the xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not support + # direct installs. The error message for importing FusedCrossEntropy/flash_attn_triton + # gives instructions on how to install if a user tries to use it without this dependency. sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py + sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py python -m pip install --upgrade build twine python -m build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cdc2f59606..00e55dad38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,6 @@ default_language_version: python: python3 +exclude: llmfoundry/models/layers/flash_attn_triton.py repos: - repo: https://github.com/google/yapf rev: v0.32.0 diff --git a/README.md b/README.md index 2980317efa..57a1a404b2 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,8 @@ Here's what you need to get started with our LLM stack: # Installation +This assumes you already have PyTorch and CMake installed. + To get started, clone this repo and install the requirements: diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index a6c055c99b..78c1056c7d 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -43,10 +43,12 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss): """ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): + trust_remote_code = om_model_config.get('trust_remote_code', True) + use_auth_token = om_model_config.get('use_auth_token', False) config = AutoConfig.from_pretrained( om_model_config.pretrained_model_name_or_path, - trust_remote_code=om_model_config.get('trust_remote_code', True), - use_auth_token=om_model_config.get('use_auth_token', False), + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, ) # set config overrides @@ -87,19 +89,24 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): if om_model_config.pretrained: model = AutoModelForCausalLM.from_pretrained( om_model_config.pretrained_model_name_or_path, - trust_remote_code=om_model_config.get( - 'trust_remote_code', True), - use_auth_token=om_model_config.get('use_auth_token', False), + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, config=config) else: - model = AutoModelForCausalLM.from_config(config) + model = AutoModelForCausalLM.from_config( + config, + trust_remote_code=trust_remote_code, + ) elif init_device == 'meta': if om_model_config.pretrained: raise ValueError( 'Setting cfg.pretrained=True is not supported when init_device="meta".' ) with init_empty_weights(include_buffers=False): - model = AutoModelForCausalLM.from_config(config) + model = AutoModelForCausalLM.from_config( + config, + trust_remote_code=trust_remote_code, + ) else: raise ValueError( f'init_device="{init_device}" must be either "cpu" or "meta".') diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index d7d830a596..2e5b886571 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn from einops import rearrange +from packaging import version from torch import nn from llmfoundry.models.layers.norm import LPLayerNorm @@ -207,11 +208,27 @@ def triton_flash_attn_fn( multiquery=False, ): try: - from flash_attn import flash_attn_triton # type: ignore + from llmfoundry.models.layers.flash_attn_triton import flash_attn_func except: - raise RuntimeError( - 'Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202' - ) + _installed = False + if version.parse(torch.__version__) < version.parse('2.0.0'): + _installed = True + # if torch1.13.1 revert to using triton flash attn from HazyResearch + # with flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202 + try: + from flash_attn.flash_attn_triton import flash_attn_func + except: + _installed = False + if not _installed: + # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+ + # default recommendation is to install this variant + raise RuntimeError( + 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.' + ) check_valid_inputs(query, key, value) @@ -257,9 +274,8 @@ def triton_flash_attn_fn( value = value.expand(*value.shape[:2], n_heads, value.size(-1)) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_triton.flash_attn_func(query, key, value, - attn_bias, reset_is_causal, - softmax_scale) + attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, + softmax_scale) output = attn_output.view(*attn_output.shape[:2], -1) diff --git a/llmfoundry/models/layers/flash_attn_triton.py b/llmfoundry/models/layers/flash_attn_triton.py new file mode 100644 index 0000000000..9276d0f917 --- /dev/null +++ b/llmfoundry/models/layers/flash_attn_triton.py @@ -0,0 +1,835 @@ +""" +Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py +update imports to use 'triton_pre_mlir' + +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math + +import torch + +import triton_pre_mlir as triton +import triton_pre_mlir.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, K, V, Bias, Out, + Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_ob, stride_oh, stride_om, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + k = tl.load(k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != 'none': + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + v = tl.load(v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, DO, Delta, + stride_ob, stride_oh, stride_om, + stride_dob, stride_doh, stride_dom, + nheads, seqlen_q, seqlen_q_rounded, headdim, + BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != 'none': + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_n[None, :] < seqlen_k), + other=0.0).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == 'none': + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, + eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last") + else: + dq = tl.load(dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last") + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == 'matrix': + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != 'none': + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + + +def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, 'FlashAttention only support head dimensions up to 128' + assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' + assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, k, v, bias, o, + lse, tmp, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + o.stride(0), o.stride(2), o.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, do, delta, + o.stride(0), o.stride(2), o.stride(1), + do.stride(0), do.stride(2), do.stride(1), + nheads, seqlen_q, seqlen_q_rounded, d, + BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads) + _bwd_kernel[grid]( + q, k, v, bias, + do, dq_accum, dk, dv, + lse, delta, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + do.stride(0), do.stride(2), do.stride(1), + dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). + ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, + softmax_scale=softmax_scale + ) + ctx.save_for_backward(qkv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dqkv = torch.empty_like(qkv) + _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, + dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dqkv, None, None, None + + +flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): + """ + q: (batch, seqlen_q, nheads, headdim) + kv: (batch, seqlen_k, 2, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, kv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, kv, o, lse, bias = ctx.saved_tensors + if len(ctx.needs_input_grad) >= 3: + assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, + dq, dkv[:, :, 0], dkv[:, :, 1], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dkv, None, None, None + + +flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): + """ + q: (batch_size, seqlen_q, nheads, headdim) + k, v: (batch_size, seqlen_k, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, k, v, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dk, dv, None, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1a8aeee940..746a343101 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -42,6 +42,11 @@ from llmfoundry.models.utils.param_init_fns import ( # type: ignore MODEL_INIT_REGISTRY, generic_param_init_fn_) +try: + from llmfoundry.models.layers.flash_attn_triton import flash_attn_func +except: + pass + Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 3e87e1cb44..287320b8ed 100644 --- a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -260,8 +260,11 @@ def visit(self, node: ast.AST): return super().visit(node) -def convert_to_relative_import(module_name: str) -> str: +def convert_to_relative_import( + module_name: str, original_parent_module_name: Optional[str]) -> str: parts = module_name.split('.') + if parts[-1] == original_parent_module_name: + return '.' return '.' + parts[-1] @@ -275,6 +278,10 @@ def process_file(file_path: str, folder_path: str) -> List[str]: with open(file_path, 'r') as f: source = f.read() + parent_module_name = None + if os.path.basename(file_path) == '__init__.py': + parent_module_name = os.path.basename(os.path.dirname(file_path)) + tree = ast.parse(source) new_files_to_process = [] nodes_to_remove = [] @@ -283,7 +290,8 @@ def process_file(file_path: str, folder_path: str) -> List[str]: if isinstance(node, ast.ImportFrom) and node.module.startswith('llmfoundry'): module_path = find_module_file(node.module) - node.module = convert_to_relative_import(node.module) + node.module = convert_to_relative_import(node.module, + parent_module_name) # recursively process any llmfoundry files new_files_to_process.append(module_path) # remove any imports from composer or omegaconf diff --git a/scripts/train/README.md b/scripts/train/README.md index 38c0da2fe1..d4496eb781 100644 --- a/scripts/train/README.md +++ b/scripts/train/README.md @@ -2,13 +2,7 @@ ## Installation -If you haven't already, make sure to install the requirements: - -```bash -git clone https://github.com/mosaicml/llm-foundry.git -cd llm-foundry -pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU -``` +If you haven't already, make sure to [install the requirements](../../README.md#Installation). ## Dataset preparation To run pretraining, you'll need to make yourself a copy of a pretraining dataset. Check out the `llm-foundry/data_prep` folder for detailed instructions. diff --git a/setup.py b/setup.py index d3e525dc5c..c7f4dffd13 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ install_requires = [ 'composer[nlp,wandb]>=0.14.1,<0.15', 'mosaicml-streaming>=0.4.1,<0.5', - 'torch==1.13.1', + 'torch>=1.13.1,<=2.0.1', 'datasets==2.10.1', 'sentencepiece==0.1.97', 'einops==0.5.0', @@ -58,6 +58,9 @@ 'mosaicml-cli>=0.3,<1', 'onnx==1.13.1', 'onnxruntime==1.14.1', + 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below + # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI + 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python', ] extra_deps = {} @@ -78,7 +81,6 @@ extra_deps['gpu'] = [ 'flash-attn==v1.0.3.post0', - 'triton==2.0.0.dev20221202', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v0.2.8#subdirectory=csrc/xentropy', ] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 2c304e30d7..789a3362f1 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -177,7 +177,7 @@ def test_denoising_dataloader(decoder_only_format, pretokenize, packing_ratio): 'sequence_mask_ratios': 0.25, }, 'drop_last': False, - 'num_workers': 0, + 'num_workers': 4, } cfg = om.create(cfg) device_batch_size = 2 @@ -237,7 +237,7 @@ def test_finetuning_dataloader(decoder_only_format, allow_pad_trimming, 'shuffle': True, }, 'drop_last': False, - 'num_workers': 0, + 'num_workers': 4, 'pin_memory': False, 'prefetch_factor': 2, 'persistent_workers': False, diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py new file mode 100644 index 0000000000..5675b65ecd --- /dev/null +++ b/tests/test_hf_conversion_script.py @@ -0,0 +1,123 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +from composer import Trainer + +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM + +# Add repo root to path so we can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) +import shutil +from argparse import Namespace +from typing import cast + +import pytest +import torch +import transformers +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from scripts.inference.convert_composer_to_hf import main + + +def delete_transformers_cache(): + hf_cache_home = os.path.expanduser( + os.getenv( + 'HF_HOME', + os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), + 'huggingface'))) + HF_MODULES_CACHE = os.getenv('HF_MODULES_CACHE', + os.path.join(hf_cache_home, 'modules')) + if os.path.exists(HF_MODULES_CACHE) and os.path.isdir(HF_MODULES_CACHE): + shutil.rmtree(HF_MODULES_CACHE) + + +def get_config( + conf_path='scripts/train/yamls/pretrain/testing.yaml') -> DictConfig: + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + with open(conf_path) as f: + test_cfg = om.load(f) + return cast(DictConfig, test_cfg) + + +def test_convert_and_generate_torch(tmp_path): + delete_transformers_cache() + + cfg = get_config() + cfg['model']['init_device'] = 'cpu' + cfg['model']['attn_config']['attn_impl'] = 'torch' + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'EleutherAI/gpt-neox-20b') + model = ComposerMPTCausalLM(cfg['model'], tokenizer) + trainer = Trainer(model=model) + trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) + + args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'), + hf_output_path=os.path.join(tmp_path, 'hf-output-folder'), + output_precision='fp32', + local_checkpoint_save_location=None, + hf_repo_for_upload=None, + test_uploaded_model=False) + main(args) + + config = transformers.AutoConfig.from_pretrained(os.path.join( + tmp_path, 'hf-output-folder'), + trust_remote_code=True) + config.attn_config['attn_impl'] = 'torch' + model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'hf-output-folder'), + config=config, + trust_remote_code=True) + tokenizer = transformers.AutoTokenizer.from_pretrained( + os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) + + output = model.generate(tokenizer('hello', + return_tensors='pt')['input_ids'], + max_new_tokens=1) + assert output.shape == (1, 2) + + delete_transformers_cache() + + +@pytest.mark.gpu +def test_convert_and_generate_triton(tmp_path): + delete_transformers_cache() + + cfg = get_config() + cfg['model']['init_device'] = 'cpu' + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'EleutherAI/gpt-neox-20b') + model = ComposerMPTCausalLM(cfg['model'], tokenizer) + trainer = Trainer(model=model) + trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) + + args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'), + hf_output_path=os.path.join(tmp_path, 'hf-output-folder'), + output_precision='fp32', + local_checkpoint_save_location=None, + hf_repo_for_upload=None, + test_uploaded_model=False) + main(args) + + config = transformers.AutoConfig.from_pretrained(os.path.join( + tmp_path, 'hf-output-folder'), + trust_remote_code=True) + config.attn_config['attn_impl'] = 'triton' + model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'hf-output-folder'), + config=config, + trust_remote_code=True) + model.to(device='cuda', dtype=torch.bfloat16) + tokenizer = transformers.AutoTokenizer.from_pretrained( + os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) + + output = model.generate(tokenizer( + 'hello', return_tensors='pt')['input_ids'].to(device='cuda'), + max_new_tokens=1) + assert output.shape == (1, 2) + + delete_transformers_cache() diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py new file mode 100644 index 0000000000..59739e23b3 --- /dev/null +++ b/tests/test_hf_mpt_gen.py @@ -0,0 +1,64 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from composer.core.precision import get_precision_context +from composer.utils import dist, get_device, reproducibility +from omegaconf import OmegaConf as om + +from llmfoundry import COMPOSER_MODEL_REGISTRY +from llmfoundry.utils import build_tokenizer + + +@pytest.mark.gpu +@pytest.mark.parametrize('device', ['cpu', 'gpu']) +@pytest.mark.parametrize('attn_impl', ['triton', 'torch']) +def test_init_hfhub_mpt(device, attn_impl): + if device == 'cpu' and attn_impl == 'triton': + pytest.skip(f'{attn_impl=} not implemented for {device=}.') + device = get_device(device) + + with open('scripts/train/yamls/pretrain/testing.yaml') as f: + test_cfg = om.load(f) + + reproducibility.seed_all(test_cfg.get('seed', 42)) + + attn_uses_sequence_id = True if test_cfg.get('eos_token_id', + None) is not None else False + test_cfg.model = { + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': 'mosaicml/mpt-7b', + 'pretrained': False, + 'config_overrides': { + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 2, + 'expansion_ratio': 2, + 'attn_config': { + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': attn_uses_sequence_id, + }, + }, + } + + # build tokenizer + tokenizer = build_tokenizer(test_cfg.tokenizer) + + # build model + model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, + tokenizer) + test_cfg.n_params = sum(p.numel() for p in model.parameters()) + + model.eval() + model = device.module_to_device(model) + + with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + _ = model.generate( + device.tensor_to_device( + tokenizer('hello', return_tensors='pt')['input_ids']), + max_new_tokens=10, + ) + + +def test_init_hfhub_mpt_cpu(): + test_init_hfhub_mpt(device='cpu', attn_impl='torch') diff --git a/tests/test_model.py b/tests/test_model.py index 06ef20c568..a1ce2aa12d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -351,7 +351,9 @@ def test_loss_fn(): except: pytest.skip('Fused cross entropy was not installed') - reproducibility.seed_all(1111) + # run numerical test in pure fp32 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: @@ -364,6 +366,8 @@ def test_loss_fn(): 'init_std': 0.02, } + reproducibility.seed_all(test_cfg.get('global_seed', 42)) + tokenizer = build_tokenizer(test_cfg.tokenizer) model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, @@ -383,7 +387,7 @@ def test_loss_fn(): eps=test_cfg.optimizer.eps, weight_decay=test_cfg.optimizer.weight_decay) - for i in range(25): + for i in range(15): batch = gen_random_batch(2, test_cfg) output_1 = model_1(batch) output_2 = model_2(batch) diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 84616a7572..1957fe2443 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -29,13 +29,15 @@ def test_onnx_export(tmp_path): AutoConfig.register('mpt', MPTConfig) AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM) + batch_size, vocab_size, max_seq_len = 1, 50368, 128 + hf_config = MPTConfig( init_device='cpu', - d_model=128, + d_model=64, n_heads=4, n_layers=2, expansion_ratio=2, - max_seq_len=2048, + max_seq_len=max_seq_len, emb_pdrop=0.0, resid_pdrop=0.0, attn_config={ @@ -43,18 +45,14 @@ def test_onnx_export(tmp_path): 'alibi': True, }, use_cache=True, - vocab_size=50368, + vocab_size=vocab_size, norm_type='layernorm', ) mpt = MPTForCausalLM(hf_config) mpt.eval() print('Creating random batch...') - sample_input = gen_random_batch( - 1, - 50368, - 2048, - ) + sample_input = gen_random_batch(batch_size, vocab_size, max_seq_len) with torch.no_grad(): mpt(**sample_input)