diff --git a/README.md b/README.md index 59bb310..9dab7aa 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,15 @@ The current release supports: - Efficient CUDA kernel implementation for fast inference (support context and decoding stage). - Examples on 4-bit inference of an instruction-tuned model (Vicuna) and multi-modal LM (LLaVA). -![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./tinychat/figures/4090_example.gif) +![TinyChat on Orin: W4A16 is 3.2x faster than FP16](./tinychat/figures/orin_example.gif) -Check out [TinyChat](tinychat), which delievers 2.3x faster inference performance for the **LLaMA-2** chatbot on RTX 4090! +Check out [TinyChat](tinychat), which delievers **30 tokens/second** inference performance (**3.2x faster** than FP16) for the **LLaMA-2** chatbot on the resource-constrained NVIDIA Jetson Orin! It also offers a turn-key solution for **on-device inference** of LLMs on **resource-constrained edge platforms**. With TinyChat, it is now possible to run **large** models on **small** and **low-power** devices even without Internet connection. ## News +- [2023/09] ⚡ Check out our latest [**TinyChat**](tinychat), which is ~2x faster than the first release on Orin! - [2023/09] ⚡ Check out [**AutoAWQ**](https://github.com/casper-hansen/AutoAWQ), a third-party implementation to make AWQ easier to expand to new models, improve inference speed, and integrate into Huggingface. - [2023/07] 🔥 We released **TinyChat**, an efficient and lightweight chatbot interface based on AWQ. TinyChat enables efficient LLM inference on both cloud and edge GPUs. LLama-2-chat models are supported! Check out our implementation [here](tinychat). - [2023/07] 🔥 We added AWQ support and pre-computed search results for Llama-2 models (7B & 13B). Checkout our model zoo [here](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo)! diff --git a/awq/kernels/csrc/attention/README.md b/awq/kernels/csrc/attention/README.md new file mode 100644 index 0000000..ec0aae5 --- /dev/null +++ b/awq/kernels/csrc/attention/README.md @@ -0,0 +1,8 @@ +# Attention kernel from FasterTransformer + +This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from +FasterTransformer v5.2.1 for benchmarking purpose. + +```sh +cd csrc/ft_attention && pip install . +``` diff --git a/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh b/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh new file mode 100644 index 0000000..f5641f6 --- /dev/null +++ b/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh @@ -0,0 +1,257 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include + +namespace fastertransformer { + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x);; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; t.x = x; t.y = y; return t; +} + +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); +#else + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace fastertransformer diff --git a/awq/kernels/csrc/attention/cuda_bf16_wrapper.h b/awq/kernels/csrc/attention/cuda_bf16_wrapper.h new file mode 100644 index 0000000..efb6e79 --- /dev/null +++ b/awq/kernels/csrc/attention/cuda_bf16_wrapper.h @@ -0,0 +1,23 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu new file mode 100644 index 0000000..e5a6690 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu @@ -0,0 +1,154 @@ +// Adapted from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention.h" +#include "decoder_masked_multihead_attention_utils.h" +#include "cuda_bf16_wrapper.h" +#include +#include +#include + +#include "decoder_masked_multihead_attention_template.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + auto kernel = mmha::masked_multihead_attention_kernel; \ + if (smem_sz >= 48 * 1024) { \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_heads, params.batch_size); \ + kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#undef MMHA_LAUNCH_KERNEL + +template +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 32: + mmha_launch_kernel(params, stream); + break; + case 48: + mmha_launch_kernel(params, stream); + break; + case 64: + mmha_launch_kernel(params, stream); + break; + case 80: + mmha_launch_kernel(params, stream); + break; + case 96: + mmha_launch_kernel(params, stream); + break; + case 112: + mmha_launch_kernel(params, stream); + break; + case 128: + mmha_launch_kernel(params, stream); + break; + case 160: + mmha_launch_kernel(params, stream); + break; + case 192: + mmha_launch_kernel(params, stream); + break; + case 224: + mmha_launch_kernel(params, stream); + break; + case 256: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h new file mode 100644 index 0000000..292fa34 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h @@ -0,0 +1,184 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template +struct Multihead_attention_params_base { + + // The output buffer. Dimensions B x D. + T* out = nullptr; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q = nullptr, *q_bias = nullptr; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k = nullptr, *k_bias = nullptr; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v = nullptr, *v_bias = nullptr; + + // The cache for the Ks. The size must be at least B x L x D. + T* k_cache = nullptr; + // The cache for the Vs. The size must be at least B x L x D. + T* v_cache = nullptr; + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // Stride to handle the case when KQV is a single buffer + int stride = 0; + + // The batch size. + int batch_size = 0; + // The beam width + int beam_width = 0; + // The sequence length. + int memory_max_len = 0; + // The number of heads (H). + int num_heads = 0; + // The number of heads for KV cache. + int num_kv_heads = 0; + // The hidden dimension per head (Dh). + int hidden_size_per_head = 0; + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + bool neox_rotary_style = false; + float rotary_base = 0.0f; + // The maximum length of input sentences. + int max_input_length = 0; + // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? + int timestep = 0; + // The current timestep of each sentences (support different timestep for different sentences) + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh = 0.0f; + + // Used when we have some input context like gpt + const int* total_padding_tokens = nullptr; + + const bool* masked_tokens = nullptr; + const int* prefix_prompt_lengths = nullptr; + int max_prefix_prompt_length = 0; + + const T* relative_attention_bias = nullptr; + int relative_attention_bias_stride = 0; + // The slope per head of linear position bias to attention score (H). + const float* linear_bias_slopes = nullptr; + + const T* ia3_key_weights = nullptr; + const T* ia3_value_weights = nullptr; + const int* ia3_tasks = nullptr; + + const float* qkv_scale_out = nullptr; + const float* attention_out_scale = nullptr; + int int8_mode = 0; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + // will need it here till if constexpr in c++17 + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +using Masked_multihead_attention_params = Multihead_attention_params; + +template +using Cross_multihead_attention_params = Multihead_attention_params; + +template +struct outputCrossAttentionParam { + // max decoder output length + int max_decoder_seq_len = 0; + T* cross_attention_out = nullptr; + bool is_return_cross_attentions = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp new file mode 100644 index 0000000..84eda81 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp @@ -0,0 +1,1608 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "decoder_masked_multihead_attention.h" +#include "decoder_masked_multihead_attention_utils.h" +#include "cuda_bf16_wrapper.h" +#include "cuda_bf16_fallbacks.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +#define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_ { +}; + +template<> +struct Qk_vec_ { + using Type = float; +}; +template<> +struct Qk_vec_ { + using Type = float2; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint2; +}; +template<> +struct Qk_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_ { +}; + +template<> +struct K_vec_ { + using Type = float; +}; +template<> +struct K_vec_ { + using Type = float2; +}; +template<> +struct K_vec_ { + using Type = float4; +}; +template<> +struct K_vec_ { + using Type = uint32_t; +}; +template<> +struct K_vec_ { + using Type = uint2; +}; +template<> +struct K_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_ { +}; + +template<> +struct V_vec_ { + using Type = float; +}; +template<> +struct V_vec_ { + using Type = float2; +}; +template<> +struct V_vec_ { + using Type = float4; +}; +template<> +struct V_vec_ { + using Type = uint32_t; +}; +template<> +struct V_vec_ { + using Type = uint2; +}; +template<> +struct V_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off +inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const Multihead_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TDOD + logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : + div_up(max_timesteps + 1, 4) * 4 * sizeof(T); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); + } + + // The max. + return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool DO_CROSS_ATTENTION> +__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) +{ + + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += + (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + } + T* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + T* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + + // Use alignment for safely casting the shared buffers as Qk_vec. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + // This is one of the reasons we should have a separate kernel for cross attention + __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + // The head. + const int num_kv_heads = params.num_kv_heads; + const int kv_rep = (params.num_heads / num_kv_heads); + const int hi = blockIdx.x; + const int hi_kv = hi / kv_rep; + + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv; + // The thread in the block. + const int tidx = threadIdx.x; + + const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); + // Every kv_rep threads have the same kv_cache values. So only the first one writes back. + const int write_kv_cache = handle_kv && (hi % kv_rep == 0); + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + // int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + const int q_base_offset = bi * params.stride + hi * Dh; + const int k_base_offset = bi * params.stride + hi_kv * Dh; + const int v_base_offset = k_base_offset; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : + (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + // int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + int q_offset = q_base_offset + tidx * QK_VEC_SIZE; + int k_offset = k_base_offset + tidx * QK_VEC_SIZE; + int v_offset = k_offset; + + // The offset in the bias buffer. + // int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; + int v_bias_offset = k_bias_offset; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = *reinterpret_cast(¶ms.q[q_offset]); + } + } + + Qk_vec k; + zero(k); + if (DO_CROSS_ATTENTION) { + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + *reinterpret_cast(¶ms.k_cache[offset]) : + k; + } + else { + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = *reinterpret_cast(¶ms.k[k_offset]); + } + } + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec q_bias; + zero(q_bias); + q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : + q_bias; + + Qk_vec k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? + *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + *reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Store Dh values of k_bias into smem, since will need to add later + // if params.timestep == 0 + if (DO_CROSS_ATTENTION && params.timestep == 0) { + *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + } + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (write_kv_cache) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = k; + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + // TODO (Haotian): check whether we should replace hi with hi_kv, + // although params.relative_attention_bias is usually not used. + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // Add alibi positional encoding + // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0; + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec = typename K_vec_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; + if (DO_CROSS_ATTENTION && params.timestep == 0) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbhi_kv * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const bool has_beams = params.cache_indir != nullptr; + const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // The keys loaded from the key cache. + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (has_beams) { + const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); + } + else { + k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); + } + } + // add bias and update k_cache + if (DO_CROSS_ATTENTION && params.timestep == 0) { + k[ii] = add(k[ii], k_bias_vec[ii]); + + if (do_ia3) { + k[ii] = mul( + k[ii], + *reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE])); + } + + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { + *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + // Add alibi positional encoding + // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0; + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + const size_t cross_attention_out_offset = + params.is_return_cross_attentions ? + bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : + 0; + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + if (params.is_return_cross_attentions) { + params.cross_attention_out[cross_attention_out_offset + ti] = logit; + } + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec = typename V_vec_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbhi_kv * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (handle_kv) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); + } + if (DO_CROSS_ATTENTION) { + *reinterpret_cast(&bias_smem[vi]) = v_bias; + } + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. + __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; + // Load the values from the cache. + V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, *reinterpret_cast(&bias_smem[vi])); + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + *reinterpret_cast(&v_cache[ti * Dh]) = v; + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti - first_step]; + + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec v; + if (DO_CROSS_ATTENTION) { + v = *reinterpret_cast(&v_cache[tlength * Dh]); + } + else { + // Trigger the loads from the V buffer. + const auto v_offset = v_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); + } + else { + v = *reinterpret_cast(¶ms.v[v_offset]); + } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + } + + // Compute the V values with bias. + v = add(v, v_bias); + if (write_kv_cache) { + + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else + // out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[tlength - first_step], v, out); +#endif + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + out = mul(*params.attention_out_scale, out); + *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = + cast_to_int8(out); + } + else { + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); + } +#else + // TODO: support int8_mode? + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h new file mode 100644 index 0000000..fe46257 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h @@ -0,0 +1,1786 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include "cuda_bf16_fallbacks.cuh" +#include + +using namespace fastertransformer; + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float4_ { + float2 x; + float2 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct num_elems; +template<> +struct num_elems { + static constexpr int value = 1; +}; +template<> +struct num_elems { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; + +template<> +struct num_elems { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; + +#ifdef ENABLE_BF16 +template<> +struct num_elems<__nv_bfloat162> { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct packed_type; +template +struct packed_type { + using type = T; +}; +template<> +struct packed_type { + using type = int16_t; +}; +template<> +struct packed_type { + using type = int32_t; +}; +template<> +struct packed_type { + using type = int64_t; +}; + +template<> +struct packed_type { + using type = float2; +}; +template<> +struct packed_type { + using type = float4; +}; +template<> +struct packed_type { + using type = Float8_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, float b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(float2 a, float2 b) +{ + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 add(float4 a, float4 b) +{ + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t add(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t add(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 add(uint2 a, uint2 b) +{ + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 add(uint4 a, uint4 b) +{ + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float_to_half(float f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? + float zero = 0.f; + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#endif + return tmp.u16[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float2_to_half2(float2 f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float half_to_float(uint16_t h) +{ + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 half2_to_float2(uint32_t v) +{ + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, uint16_t b) +{ + return a + half_to_float(b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float add(float a, __nv_bfloat16 b) +{ + return a + __bfloat162float(b); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(uint32_t a, float2 fb) +{ + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(uint2 a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(uint4 a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t h0_h0(uint16_t a) +{ + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(float a, float b, float c) +{ + return a * b + c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float2 a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float4 a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) +{ + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) +{ + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) +{ + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) +{ + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) +{ + return fma(h0_h0(a), b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) +{ + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) +{ + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) +{ + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) +{ + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) +{ + return fma(h0_h0(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) +{ + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) +{ + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(bf162bf162(a), b, c); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) +{ + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) +{ + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) +{ + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) +{ + return fma(bf162bf162(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ Acc mul(A a, B b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(float a, float b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float2 a, float2 b) +{ + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float a, float2 b) +{ + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float4 a, float4 b) +{ + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float a, float4 b) +{ + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(float a, Float8_ b) +{ + Float8_ c; + c.x = make_float2(a * b.x.x, a * b.x.y); + c.y = make_float2(a * b.y.x, a * b.y.y); + c.z = make_float2(a * b.z.x, a * b.z.y); + c.w = make_float2(a * b.w.x, a * b.w.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) +{ + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) +{ + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(uint16_t a, float b) +{ + return half_to_float(a) * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmul(a, b); +#else + return bf16hmul(a, b); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ + float fa = (float)a; + float fb = (float)b; + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, float b) +{ + return __bfloat162float(a) * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float v) +{ + return v; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float2 v) +{ + return v.x + v.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float4 v) +{ + return v.x + v.y + v.z + v.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float sum(__nv_bfloat162 v) +{ + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_4_t v) +{ + return sum(v.x) + sum(v.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_8_t v) +{ + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint16_t v) +{ + return half_to_float(v); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint32_t v) +{ + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint2 v) +{ + uint32_t c = add(v.x, v.y); + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint4 v) +{ +#if 1 + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); +#else + uint32_t c = add(v.x, v.y); + uint32_t d = add(v.z, v.w); + c = add(c, d); +#endif + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float4_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float8_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t& dst) +{ + dst = uint16_t(0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void zero(T& dst) +{ + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base) +{ + const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); + return {cos(inv_freq), sin(inv_freq)}; +} + +inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) +{ + float2 rot_v; + rot_v.x = coef.x * v.x - coef.y * v.y; + rot_v.y = coef.x * v.y + coef.y * v.x; + return rot_v; +} + +inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) +{ + float2 fv = half2_to_float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return float2_to_half2(rot_fv); +} + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) +{ + float2 fv = bf1622float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); +} +#endif + +inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q_.x = rotary_embedding_transform(q_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q_.y = rotary_embedding_transform(q_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + Float4_& k_ = *reinterpret_cast(&k); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q_.x = rotary_embedding_transform(q_.x, coef0); + k_.x = rotary_embedding_transform(k_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q_.y = rotary_embedding_transform(q_.y, coef1); + k_.y = rotary_embedding_transform(k_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} + +#ifdef ENABLE_BF16 +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void +apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // ENABLE_BF16 + +template +__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + + vec = tmp_3.u32x2; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + tmp_3.u16[4] = tmp_1.u16[2]; + tmp_3.u16[5] = tmp_2.u16[2]; + tmp_3.u16[6] = tmp_1.u16[3]; + tmp_3.u16[7] = tmp_2.u16[3]; + + vec = tmp_3.u32x4; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + __nv_bfloat16 bf16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; +} + +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + __nv_bfloat16 bf16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; + vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; + vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; +} +#endif // ENABLE_BF16 + +template<> +__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.z = smem[transpose_idx + 1]; + vec.y = smem[smem_pitch + transpose_idx]; + vec.w = smem[smem_pitch + transpose_idx + 1]; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + half u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} +#endif + +template<> +__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} + +template +__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u32x4 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + tmp_1.u16[2] = tmp_3.u16[4]; + tmp_2.u16[2] = tmp_3.u16[5]; + tmp_1.u16[3] = tmp_3.u16[6]; + tmp_2.u16[3] = tmp_3.u16[7]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u32x2 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = vec; + + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +template<> +__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[transpose_idx + 1] = vec.z; + smem[smem_pitch + transpose_idx] = vec.y; + smem[smem_pitch + transpose_idx + 1] = vec.w; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + half u16[2]; + } tmp; + + tmp.u32 = vec; + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} + +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); +} + +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); +} +#endif + +template<> +__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} + +} // namespace mmha diff --git a/awq/kernels/csrc/attention/ft_attention.cpp b/awq/kernels/csrc/attention/ft_attention.cpp new file mode 100644 index 0000000..bef83ec --- /dev/null +++ b/awq/kernels/csrc/attention/ft_attention.cpp @@ -0,0 +1,182 @@ +// Adapted from NVIDIA/FasterTransformer and FlashAttention + +#include +#include "ATen/cuda/CUDAContext.h" +#include + +#include "ft_attention.h" +#include "decoder_masked_multihead_attention.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \ + if (TYPE == at::ScalarType::Half) { \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + } else if (TYPE == at::ScalarType::BFloat16) { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (TYPE == at::ScalarType::Float) { \ + using scalar_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ + } + +template +void masked_multihead_attention(const Masked_multihead_attention_params& params, + const cudaStream_t& stream); + +template +void cross_multihead_attention(const Masked_multihead_attention_params& params, + const cudaStream_t& stream); + +template +struct SATypeConverter { + using Type = T; +}; + +template<> +struct SATypeConverter { + using Type = uint16_t; +}; + +template<> +struct SATypeConverter { + using Type = __nv_bfloat16; +}; + +template +void set_params(Masked_multihead_attention_params ¶ms, + const size_t batch_size, + const size_t nheads, + const size_t nheads_kv, + const size_t memory_max_seqlen, + const size_t headdim, + const int timestep, + const int rotary_embedding_dim, + const float rotary_base, + const bool neox_rotary_style, + const int qkv_batch_stride, + T *q_ptr, + T *k_ptr, + T *v_ptr, + T *k_cache_ptr, + T *v_cache_ptr, + int *length_per_sample, + float *alibi_slopes_ptr, + T *out_ptr) { + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.q_bias = nullptr; + params.k_bias = nullptr; + params.v_bias = nullptr; + params.k_cache = k_cache_ptr; + params.v_cache = v_cache_ptr; + params.linear_bias_slopes = alibi_slopes_ptr; + params.out = out_ptr; + params.cache_indir = nullptr; + params.stride = qkv_batch_stride; + params.batch_size = batch_size; + params.beam_width = 1; + params.memory_max_len = memory_max_seqlen; + params.num_heads = nheads; + params.num_kv_heads = nheads_kv; + params.hidden_size_per_head = headdim; + params.rotary_embedding_dim = rotary_embedding_dim; + params.rotary_base = rotary_base; + params.neox_rotary_style = neox_rotary_style; + params.timestep = timestep; + params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); + params.total_padding_tokens = nullptr; + params.masked_tokens = nullptr; + params.prefix_prompt_lengths = nullptr; + params.max_prefix_prompt_length = 0; + params.relative_attention_bias = nullptr; + params.relative_attention_bias_stride = 0; + params.cross_attention_out = nullptr; + params.max_decoder_seq_len = 0; + params.is_return_cross_attentions = false; + params.finished = nullptr; + params.memory_length_per_sample = nullptr; + params.length_per_sample = length_per_sample; +} + +torch::Tensor single_query_attention(const torch::Tensor q, + const torch::Tensor k, + const torch::Tensor v, + torch::Tensor k_cache, + torch::Tensor v_cache, + c10::optional length_per_sample_, + c10::optional alibi_slopes_, + const int timestep, + const int rotary_embedding_dim, + const float rotary_base, + // neox_rotary_style = not interleaved + const bool neox_rotary_style) { + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); + int batch_size = v_cache.size(0); + int nheads = q.size(1); + int nheads_kv = v_cache.size(1); + int memory_max_seqlen = v_cache.size(2); + int headdim = v_cache.size(3); + CHECK_SHAPE(q, batch_size, nheads, headdim); + CHECK_SHAPE(k, batch_size, nheads_kv, headdim); + CHECK_SHAPE(v, batch_size, nheads_kv, headdim); + CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); + // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 + int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; + CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); + TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); + TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); + TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); + // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0)); + CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); + + if (length_per_sample_.has_value()) { + auto length_per_sample = length_per_sample_.value(); + CHECK_DEVICE(length_per_sample); + CHECK_SHAPE(length_per_sample, batch_size); + CHECK_CONTIGUOUS(length_per_sample); + TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); + } + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + CHECK_SHAPE(alibi_slopes, nheads); + CHECK_CONTIGUOUS(alibi_slopes); + TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + torch::Tensor out = torch::empty_like(q); + + DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { + using DataType = typename SATypeConverter::Type; + Masked_multihead_attention_params params; + set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, + timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0), + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + length_per_sample_.has_value() + ? length_per_sample_.value().data_ptr() : nullptr, + alibi_slopes_.has_value() + ? alibi_slopes_.value().data_ptr(): nullptr, + reinterpret_cast(out.data_ptr())); + auto stream = at::cuda::getCurrentCUDAStream(); + masked_multihead_attention(params, stream); + }); + return out; +} \ No newline at end of file diff --git a/awq/kernels/csrc/attention/ft_attention.h b/awq/kernels/csrc/attention/ft_attention.h new file mode 100644 index 0000000..df67d53 --- /dev/null +++ b/awq/kernels/csrc/attention/ft_attention.h @@ -0,0 +1,15 @@ +#pragma once +#include + + +torch::Tensor single_query_attention(const torch::Tensor q, + const torch::Tensor k, + const torch::Tensor v, + torch::Tensor k_cache, + torch::Tensor v_cache, + c10::optional length_per_sample_, + c10::optional alibi_slopes_, + const int timestep, + const int rotary_embedding_dim = 0, + const float rotary_base = 10000.0f, + const bool neox_rotary_style=true); \ No newline at end of file diff --git a/awq/kernels/csrc/attention/setup.py b/awq/kernels/csrc/attention/setup.py new file mode 100644 index 0000000..dc479f2 --- /dev/null +++ b/awq/kernels/csrc/attention/setup.py @@ -0,0 +1,152 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +from packaging.version import parse, Version + +from setuptools import setup, find_packages +import subprocess + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("--ft_attention") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("ft_attention is only supported on CUDA 11 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +ext_modules.append( + CUDAExtension( + name="ft_attention", + sources=[ + "ft_attention.cpp", + "decoder_masked_multihead_attention.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-DENABLE_BF16", # TODO + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[this_dir], + ) +) + +setup( + name="ft_attention", + version="0.1", + description="Attention for single query from FasterTransformer", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, +) diff --git a/awq/kernels/csrc/pybind.cpp b/awq/kernels/csrc/pybind.cpp index cda4a6c..220911a 100644 --- a/awq/kernels/csrc/pybind.cpp +++ b/awq/kernels/csrc/pybind.cpp @@ -1,12 +1,19 @@ #include #include +#include "attention/ft_attention.h" #include "layernorm/layernorm.h" #include "quantization/gemm_cuda.h" +#include "quantization/gemv_cuda.h" #include "position_embedding/pos_encoding.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel"); m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); + m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel."); m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key"); + m.def("single_query_attention", &single_query_attention, "Attention with a single query", + py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), + py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0, + py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); } diff --git a/awq/kernels/csrc/quantization/gemv_cuda.cu b/awq/kernels/csrc/quantization/gemv_cuda.cu new file mode 100644 index 0000000..3a55e66 --- /dev/null +++ b/awq/kernels/csrc/quantization/gemv_cuda.cu @@ -0,0 +1,247 @@ +// Inspired by https://github.com/ankan-ban/llama_cu_awq +/* + +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} +} + +*/ + +#include +#include +#include +#include "gemv_cuda.h" +#define VECTORIZE_FACTOR 8 +#define Q_VECTORIZE_FACTOR 8 +#define PACK_FACTOR 8 +#define WARP_SIZE 32 + + +// Reduce sum within the warp using the tree reduction algorithm. +__device__ __forceinline__ float warp_reduce_sum(float sum) { + #pragma unroll + for(int i = 4; i >= 0; i--){ + sum += __shfl_down_sync(0xffffffff, sum, 1<(zeros + oc_idx * zeros_w + packed_group_idx * 2); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups. + float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs + #pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); + #pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +} + + +/* +Computes GEMV (group_size = 128). + +Args: + inputs: vector of shape [batch_size, IC]; + weight: matrix of shape [OC, IC / 8]; + output: vector of shape [OC]; + zeros: matrix of shape [OC, IC / group_size / 8]; + scaling_factors: matrix of shape [OC, IC / group_size]; + +Notes: + One cannot infer group_size from the shape of scaling factors. + the second dimension is rounded up to a multiple of PACK_FACTOR. +*/ +__global__ void gemv_kernel_g128( + const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs, + const int IC, const int OC){ + const int group_size = 128; + float psum = 0; + const int batch_idx = blockIdx.z; + const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; + const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; + half* outputs = _outputs + batch_idx * OC; + const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); + const int weight_w = IC / PACK_FACTOR; + // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address + const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); + // consistent with input shape + const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); + // tile size: 4 OC x 1024 IC per iter + for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ + // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. + uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. + float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs + #pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); + #pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +} + + +/* +Computes GEMV (PyTorch interface). + +Args: + _in_feats: tensor of shape [B, IC]; + _kernel: int tensor of shape [OC, IC // 8]; + _zeros: int tensor of shape [OC, IC // G // 8]; + _scaling_factors: tensor of shape [OC, IC // G]; + blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; + blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; + +Returns: + out_feats: tensor of shape [B, OC]; +*/ +torch::Tensor gemv_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int group_size) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + // int kernel_volume = _out_in_map.size(1); + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + // auto out_in_map = _out_in_map.data_ptr(); + auto options = + torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + // kernel is [OC, IC] + at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + int blockDim_z = num_out_feats; + dim3 num_blocks(1, num_out_channels / 4, num_out_feats); + dim3 num_threads(32, 4); + if (group_size == 64) + { + gemv_kernel_g64<<>>( + // pointers + in_feats, kernel, zeros, scaling_factors, out_feats, + // constants + num_in_channels, num_out_channels + ); + } + else if (group_size == 128) + { + gemv_kernel_g128<<>>( + // pointers + in_feats, kernel, zeros, scaling_factors, out_feats, + // constants + num_in_channels, num_out_channels + ); + } + return _out_feats; +;} + diff --git a/awq/kernels/csrc/quantization/gemv_cuda.h b/awq/kernels/csrc/quantization/gemv_cuda.h new file mode 100644 index 0000000..748abc5 --- /dev/null +++ b/awq/kernels/csrc/quantization/gemv_cuda.h @@ -0,0 +1,9 @@ +#pragma once +#include + +torch::Tensor gemv_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int group_size); diff --git a/awq/kernels/setup.py b/awq/kernels/setup.py index 78e938e..88e3095 100644 --- a/awq/kernels/setup.py +++ b/awq/kernels/setup.py @@ -1,9 +1,31 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension + extra_compile_args = { - "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"], - "nvcc": ["-O3", "-std=c++17"], + "cxx": [ + "-g", + "-O3", + "-fopenmp", + "-lgomp", + "-std=c++17", + "-DENABLE_BF16" + ], + "nvcc": [ + "-O3", + "-std=c++17", + "-DENABLE_BF16", # TODO + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--threads=8" + ], } setup( @@ -15,8 +37,11 @@ sources=[ "csrc/pybind.cpp", "csrc/quantization/gemm_cuda_gen.cu", + "csrc/quantization/gemv_cuda.cu", "csrc/layernorm/layernorm.cu", - "csrc/position_embedding/pos_encoding_kernels.cu" + "csrc/position_embedding/pos_encoding_kernels.cu", + "csrc/attention/ft_attention.cpp", + "csrc/attention/decoder_masked_multihead_attention.cu" ], extra_compile_args=extra_compile_args, ), diff --git a/awq/quantize/pre_quant.py b/awq/quantize/pre_quant.py index cf24d8c..5af81a5 100644 --- a/awq/quantize/pre_quant.py +++ b/awq/quantize/pre_quant.py @@ -20,7 +20,7 @@ def get_named_linears(module): def get_blocks(model): - if isinstance(model, LlamaForCausalLM): + if model.__class__.__name__ == 'LlamaForCausalLM': layers = model.model.layers elif isinstance(model, OPTForCausalLM): layers = model.model.decoder.layers diff --git a/awq/quantize/qmodule.py b/awq/quantize/qmodule.py index e40f107..e77555a 100644 --- a/awq/quantize/qmodule.py +++ b/awq/quantize/qmodule.py @@ -4,6 +4,25 @@ import awq_inference_engine # with CUDA kernels +def make_divisible(c, divisor): + return (c + divisor - 1) // divisor + + +def calculate_zeros_width(in_features, group_size=128, pack_num=8): + if group_size >= 128: + size_multiplier = 1 + elif group_size == 64: + size_multiplier = 2 + elif group_size == 32: + size_multiplier = 4 + else: + raise NotImplementedError + + base_width = make_divisible(in_features // group_size, pack_num) + base_width = make_divisible(base_width, size_multiplier) * size_multiplier + return base_width + + class ScaledActivation(nn.Module): def __init__(self, module, scales): super().__init__() @@ -25,13 +44,15 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.out_features = out_features self.w_bit = w_bit self.group_size = group_size if group_size != -1 else in_features + self.split_k_iters = 8 # quick sanity check (make sure aligment) assert self.in_features % self.group_size == 0 assert out_features % (32 // self.w_bit) == 0 - - self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev)) + pack_num = (32 // self.w_bit) + # TODO (Haotian): a function for buffer shape calculation + self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev)) + self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev)) + self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev)) if bias: self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev)) else: @@ -46,24 +67,31 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze # need scales and zeros info for real quantization assert scales is not None and zeros is not None scale_zeros = zeros * scales - - awq_linear.scales = scales.clone().half() - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() pack_num = 32 // awq_linear.w_bit + qscales = torch.zeros( + (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num), + dtype=torch.float16, + device=scales.device + ) + qscales[:, :scales.shape[1]] = scales + # awq_linear.scales = scales.clone().half() + awq_linear.scales = qscales + if linear.bias is not None: + awq_linear.bias = linear.bias.clone().half() intweight = [] for idx in range(awq_linear.in_features): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None]) + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() + # intweight = intweight.t().contiguous() intweight = intweight.to(dtype=torch.int32) qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device) for col in range(intweight.shape[1] // pack_num): if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] else: raise NotImplementedError("Only 4-bit are supported for now.") for i in range(pack_num): @@ -72,25 +100,35 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze awq_linear.qweight = qweight zeros = zeros.to(dtype=torch.int32) - qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device) + qzeros = torch.zeros( + (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), + dtype=torch.int32, + device=zeros.device, + ) - for col in range(zeros.shape[1] // pack_num): + for col in range((zeros.shape[1] + pack_num - 1) // pack_num): if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] else: raise NotImplementedError("Only 4-bit are supported for now.") for i in range(pack_num): + if col * pack_num + order_map[i] >= zeros.shape[1]: + continue qzero_col = zeros[:, col * pack_num + order_map[i]] qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) awq_linear.qzeros = qzeros - return awq_linear @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features, ) - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) + # out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.split_k_iters) + # print(x.shape, self.qweight.shape, self.scales.shape, self.qzeros.shape, self.group_size) + out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size) out = out + self.bias if self.bias is not None else out + #print(out) + #assert 0 return out.reshape(out_shape) def extra_repr(self) -> str: diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index 2670607..0437fc0 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -127,8 +127,8 @@ def real_quantize_model_weight( else: module.cuda() module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config) - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() + # scales = scales.t().contiguous() + # zeros = zeros.t().contiguous() q_linear = WQLinear.from_linear( module, w_bit, q_config['q_group_size'], False, scales, zeros) module.cpu() diff --git a/tinychat/README.md b/tinychat/README.md index e0edc40..7657618 100644 --- a/tinychat/README.md +++ b/tinychat/README.md @@ -29,7 +29,7 @@ The current release supports: ## Examples -Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is 2.3x faster on RTX 4090 and 1.4x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.) +Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is up to 3.7x faster on RTX 4090 and 3.3x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.) * TinyChat on RTX 4090: @@ -44,7 +44,7 @@ Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit infe We benchmark TinyChat on A6000 (server-class GPU), 4090 (desktop GPU) and Orin (edge GPU). -We use the default implementation from Huggingface for the FP16 baseline. The INT4 implementation applies AWQ and utilizes our fast W4A16 GPU kernel. Please notice that the end-to-end runtime for INT4 TinyChat could be further improved if we reduce the framework overhead from Huggingface (e.g. utilizing the implementation from TGI). We are working on a new release with even faster inference performance, please stay tuned! +We use the default implementation from Huggingface for the FP16 baseline. The INT4 implementation applies AWQ and utilizes our fast W4A16 GPU kernel. We also apply additional optimization techniques in the latest release. For example, we fuse all the operations in MHA/GQA/MQA into a single kernel, and fuse positional embedding kernels into the attention kernel. We also pre-allocate key-value caches to avoid the online memory allocation overhead from Huggingface. The latency reported in all tables are per-token latency for the generation stage. @@ -52,37 +52,40 @@ The latency reported in all tables are per-token latency for the generation stag | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 27.14 | 12.44 | 2.18x | -| LLaMA-2-13B | 47.28 | 20.28 | 2.33x | -| Vicuna-7B | 26.06 | 12.43 | 2.10x | -| Vicuna-13B | 44.91 | 17.30 | 2.60x | -| MPT-7B | 22.79 | 16.87 | 1.35x | -| MPT-30B | OOM | 31.57 | -- | -| Falcon-7B | 39.44 | 27.34 | 1.44x | +| LLaMA-2-7B | 27.14 | 8.71 | 3.12x | +| LLaMA-2-13B | 47.28 | 14.64 | 3.23x | +| Vicuna-7B | 26.06 | 8.39 | 3.11x | +| Vicuna-13B | 44.91 | 13.46 | 3.34x | +| MPT-7B | 22.79 | 7.99 | 2.85x | +| MPT-30B | OOM | 28.15 | -- | +| Falcon-7B | 39.44 | 11.71 | 3.37x | ### 4090 Results | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 19.97 | 8.66 | 2.31x | -| LLaMA-2-13B | OOM | 13.54 | -- | -| Vicuna-7B | 19.09 | 8.61 | 2.22x | -| Vicuna-13B | OOM | 12.17 | -- | -| MPT-7B | 17.09 | 12.58 | 1.36x | -| MPT-30B | OOM | 23.54 | -- | -| Falcon-7B | 29.91 | 19.84 | 1.51x | +| LLaMA-2-7B | 19.97 | 6.02* | 3.31x | +| LLaMA-2-13B | OOM | 10.35 | -- | +| Vicuna-7B | 19.09 | 5.33 | 3.58x | +| Vicuna-13B | OOM | 9.17 | -- | +| MPT-7B | 17.09 | 6.18 | 2.77x | +| MPT-30B | OOM | 20.60 | -- | +| Falcon-7B | 29.91 | 8.02 | 3.73x | + +*: The reason why LLaMA-2-7B is slower than Vicuna-7B is because we need a longer prompt (with > 500 tokens) to prevent the model from talking with itself. If we use the benchmarking strategy from exLLaMA (i.e. only 4 context tokens), our speed is around 195 tokens / second. ### Orin Results | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 104.71 | 75.11 | 1.39x | -| LLaMA-2-13B | OOM | 136.81 | -- | -| Vicuna-7B | 93.12 | 65.34 | 1.43x | -| Vicuna-13B | OOM | 115.4 | -- | -| MPT-7B | 89.85 | 67.36 | 1.33x | -| Falcon-7B | 147.84 | 102.74 | 1.44x | +| LLaMA-2-7B | 104.71 | 33.07* | 3.17x | +| LLaMA-2-13B | OOM | 58.20 | -- | +| Vicuna-7B | 93.12 | 30.73 | 3.03x | +| Vicuna-13B | OOM | 54.98 | -- | +| MPT-7B | 89.85 | 31.22 | 2.88x | +| Falcon-7B | 147.84 | 45.10 | 3.28x | +*: We can similarly achieve 33 tokens / second on Orin if we use the benchmarking strategy from exLLaMA. ## Usage @@ -153,11 +156,21 @@ python demo.py --model_type llama \ --precision W16A16 ``` +5. (Optional) Run the benchmark script: + +```bash +cd tinychat +python benchmark.py --model_type llama \ + --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \ + --q_group_size 128 +``` +Note: The kv caches in the current implementation are pre-allocated. So if you run out of memory, it might be the case that the kv cache is too large. To solve the problem, you may pass in `--max_seq_len [a smaller number]`. ## Reference -TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat). +TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat), [llama_cu_awq](https://github.com/ankan-ban/llama_cu_awq). + diff --git a/tinychat/benchmark.py b/tinychat/benchmark.py new file mode 100644 index 0000000..5ff87ce --- /dev/null +++ b/tinychat/benchmark.py @@ -0,0 +1,123 @@ +# Usage: +# Please first install awq/kernels +# then directly run CUDA_VISIBLE_DEVICES=0 python benchmark.py +import argparse +import torch +import time +import numpy as np +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils +import tinychat.utils.constants +from tinychat.utils.load_quant import load_awq_model +from awq.quantize.quantizer import real_quantize_model_weight +from tinychat.utils.tune import tune_all_wqlinears, device_warmup +from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + + +def skip(*args, **kwargs): + pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", type=str, default="LLaMa", help="type of the model" + ) + parser.add_argument( + "--model_path", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b", + help="path to the model", + ) + parser.add_argument("--q_group_size", type=int, default=128) + parser.add_argument( + "--verbose", + default=False, + action="store_true", + help="Wheter to print more information.", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=2048, + help="maximum sequence length for kv cache" + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=1, + help="maximum batch size for kv cache" + ) + args = parser.parse_args() + + tinychat.utils.constants.max_batch_size = args.max_batch_size + tinychat.utils.constants.max_seq_len = args.max_seq_len + from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM + + modeling_utils._init_weights = False + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + device = "cuda:0" + # exLLaMA benchmarking parameters. + context_length = 4 + gen_length = 200 + input_ids = [1 for _ in range(context_length)] + + model_type_dict = { + "llama": LlamaForCausalLM, + "falcon": FalconForCausalLM, + "mpt": MPTForCausalLM, + } + + config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) + assert args.model_type.lower() in [ + "llama", + "falcon", + "mpt", + ], "We only support llama & falcon & mpt now" + model = model_type_dict[args.model_type.lower()](config).half() + real_quantize_model_weight( + model, w_bit=4, q_config=dict(q_group_size=args.q_group_size, zero_point=True), init_only=True + ) + model = model.to(device) + + # tune_all_wqlinears(model) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + device_warmup(device) + + print("huggingface ckpt loaded") + print(model) + + time_lis = [] + + start_pos = 0 + + print("Benchmarking...") + with torch.inference_mode(): + for i in range(gen_length): + torch.cuda.synchronize() + t_st = time.time() + + if i == 0: + inputs = torch.as_tensor([input_ids], device=device) + else: + inputs = torch.as_tensor([[token]], device=device) + out = model(inputs, start_pos=start_pos) + start_pos += out.shape[1] + + torch.cuda.synchronize() + t_ed = time.time() + time_lis.append(t_ed - t_st) + token = out[:, -1].max(1)[1].unsqueeze(1) + if args.verbose: + print(i, np.median(time_lis)) + + print(f"Speed: {1 / np.median(time_lis)} tokens per second.") + + +if __name__ == "__main__": + main() diff --git a/tinychat/demo.py b/tinychat/demo.py index 966adac..93d095b 100644 --- a/tinychat/demo.py +++ b/tinychat/demo.py @@ -6,38 +6,46 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils from attributedict.collections import AttributeDict from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator +import tinychat.utils.constants from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids +from tinychat.utils.tune import device_warmup, tune_all_wqlinears import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" # opt_params in TinyLLMEngine -gen_params = AttributeDict([ - ("seed", -1), # RNG seed - ("n_threads", 1), # TODO: fix this - ("n_predict", 512), # new tokens to predict - ("n_parts", -1), # amount of model parts (-1: determine from model dimensions) - ("n_ctx", 512), # context size - ("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS) - ("n_keep", 0), # number of tokens to keep from initial prompt - ("n_vocab", 50272), # vocabulary size - - # sampling parameters - ("logit_bias", dict()), # logit bias for specific tokens: - ("top_k", 40), # <= 0 to use vocab size - ("top_p", 0.95), # 1.0 = disabled - ("tfs_z", 1.00), # 1.0 = disabled - ("typical_p", 1.00), # 1.0 = disabled - ("temp", 0.70), # 1.0 = disabled - ("repeat_penalty", 1.10), # 1.0 = disabled - ("repeat_last_n", 64), # last n tokens to penalize (0 = disable penalty, -1 = context size) - ("frequency_penalty", 0.00),# 0.0 = disabled - ("presence_penalty", 0.00), # 0.0 = disabled - ("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - ("mirostat_tau", 5.00), # target entropy - ("mirostat_eta", 0.10), # learning rate - ]) +gen_params = AttributeDict( + [ + ("seed", -1), # RNG seed + ("n_threads", 1), # TODO: fix this + ("n_predict", 512), # new tokens to predict + ("n_parts", -1), # amount of model parts (-1: determine from model dimensions) + ("n_ctx", 512), # context size + ("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS) + ("n_keep", 0), # number of tokens to keep from initial prompt + ("n_vocab", 50272), # vocabulary size + # sampling parameters + ("logit_bias", dict()), # logit bias for specific tokens: + ("top_k", 40), # <= 0 to use vocab size + ("top_p", 0.95), # 1.0 = disabled + ("tfs_z", 1.00), # 1.0 = disabled + ("typical_p", 1.00), # 1.0 = disabled + ("temp", 0.70), # 1.0 = disabled + ("repeat_penalty", 1.10), # 1.0 = disabled + ( + "repeat_last_n", + 64, + ), # last n tokens to penalize (0 = disable penalty, -1 = context size) + ("frequency_penalty", 0.00), # 0.0 = disabled + ("presence_penalty", 0.00), # 0.0 = disabled + ("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + ("mirostat_tau", 5.00), # target entropy + ("mirostat_eta", 0.10), # learning rate + ] +) + def stream_output(output_stream): print(f"ASSISTANT: ", end="", flush=True) @@ -57,40 +65,77 @@ def stream_output(output_stream): total_tokens = timing["total_tokens"] generation_time_list = timing["generation_time_list"] generation_tokens = len(generation_time_list) - average_speed = (context_time + np.sum(generation_time_list)) / (context_tokens + generation_tokens) + average_speed = (context_time + np.sum(generation_time_list)) / ( + context_tokens + generation_tokens + ) print("=" * 50) print("Speed of Inference") print("-" * 50) # print(f"Context Stage : {context_time/context_tokens * 1000:.2f} ms/token") - print(f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token") + print( + f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token" + ) # print(f"Average Speed : {average_speed * 1000:.2f} ms/token") print("=" * 50) # print("token num:", total_tokens) # print("Model total Time = ", (context_time + np.sum(generation_time_list))*1000, "ms" ) return " ".join(output_text) -def device_warmup(device:str): - warm_up = torch.randn((4096,4096)).to(device) - torch.mm(warm_up,warm_up) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default='LLaMa', help='type of the model') - parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model') - parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision') - parser.add_argument('--device' , type=str, default='cuda') - parser.add_argument('--q_group_size', type=int, default=128) - parser.add_argument('--load_quant', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt', help='path to the pre-quanted 4-bit weights') + parser.add_argument( + "--model_type", type=str, default="LLaMa", help="type of the model" + ) + parser.add_argument( + "--model_path", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b", + help="path to the model", + ) + parser.add_argument( + "--precision", type=str, default="W4A16", help="compute precision" + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--q_group_size", type=int, default=128) + parser.add_argument( + "--load_quant", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt", + help="path to the pre-quanted 4-bit weights", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=2048, + help="maximum sequence length for kv cache" + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=1, + help="maximum batch size for kv cache" + ) args = parser.parse_args() - assert args.model_type.lower() in ["llama", "falcon", "mpt"], "We only support llama & falcon & mpt now" + assert args.model_type.lower() in [ + "llama", + "falcon", + "mpt", + ], "We only support llama & falcon & mpt now" assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now" gen_params.n_predict = 512 gen_params.n_vocab = 32000 + tinychat.utils.constants.max_batch_size = args.max_batch_size + tinychat.utils.constants.max_seq_len = args.max_seq_len + # TODO (Haotian): a more elegant implementation here. + # We need to update these global variables before models use them. + from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM def skip(*args, **kwargs): pass + torch.nn.init.kaiming_uniform_ = skip torch.nn.init.kaiming_normal_ = skip torch.nn.init.uniform_ = skip @@ -99,38 +144,63 @@ def skip(*args, **kwargs): config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) if "mpt" in config.__class__.__name__.lower(): # config.init_device="meta" - tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer_name, trust_remote_code=True + ) else: - tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, use_fast=False, trust_remote_code=True + ) modeling_utils._init_weights = False torch.set_default_dtype(torch.half) model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model_type_dict = { + "llama": LlamaForCausalLM, + "falcon": FalconForCausalLM, + "mpt": MPTForCausalLM, + } + if args.precision == "W4A16": if args.model_type.lower() == "llama": - model = load_awq_llama_fast(model, args.load_quant, 4, args.q_group_size, args.device) + model = model_type_dict["llama"](config).half() + model = load_awq_llama_fast( + model, args.load_quant, 4, args.q_group_size, args.device + ) else: - model = load_awq_model(model, args.load_quant, 4, args.q_group_size, args.device) + model = ( + model_type_dict[args.model_type.lower()](config).half() + ) + model = load_awq_model( + model, args.load_quant, 4, args.q_group_size, args.device + ) else: - model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device) + loaded_model = AutoModelForCausalLM.from_pretrained( + args.model_path, + config=config, + torch_dtype=torch.float16, + trust_remote_code=True, + ) + model = model_type_dict[args.model_type.lower()](config).half().to(args.device) + model.load_state_dict(loaded_model.state_dict()) # device warm up device_warmup(args.device) + # autotune split_k_iters + # tune_all_wqlinears(model) - if args.model_type.lower() == 'falcon': - stream_generator = FalconStreamGenerator - else: - stream_generator = StreamGenerator + # TODO (Haotian): Verify if the StreamGenerator still works for the unmodified falcon impl. + stream_generator = StreamGenerator # Optimize AWQ quantized model - if args.precision == "W4A16" and args.model_type.lower() == 'llama': + if args.precision == "W4A16" and args.model_type.lower() == "llama": from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + make_quant_attn(model, args.device) make_quant_norm(model) make_fused_mlp(model) - model_prompter = get_prompter(args.model_type, args.model_path) - stop_token_ids = get_stop_token_ids(args.model_type, args.model_path) + stop_token_ids = get_stop_token_ids(args.model_type, args.model_path) count = 0 while True: # Get input from the user @@ -139,7 +209,14 @@ def skip(*args, **kwargs): print("EXIT...") break model_prompter.insert_prompt(input_prompt) - output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=args.device, stop_token_ids = stop_token_ids) - outputs = stream_output(output_stream) + output_stream = stream_generator( + model, + tokenizer, + model_prompter.model_input, + gen_params, + device=args.device, + stop_token_ids=stop_token_ids, + ) + outputs = stream_output(output_stream) model_prompter.update_template(outputs) count += 1 diff --git a/tinychat/figures/4090_example.gif b/tinychat/figures/4090_example.gif index 64857e6..6a36090 100644 Binary files a/tinychat/figures/4090_example.gif and b/tinychat/figures/4090_example.gif differ diff --git a/tinychat/figures/orin_example.gif b/tinychat/figures/orin_example.gif index bd0090e..8e16469 100644 Binary files a/tinychat/figures/orin_example.gif and b/tinychat/figures/orin_example.gif differ diff --git a/tinychat/models/__init__.py b/tinychat/models/__init__.py new file mode 100644 index 0000000..0644a42 --- /dev/null +++ b/tinychat/models/__init__.py @@ -0,0 +1,3 @@ +from .falcon import FalconForCausalLM +from .llama import LlamaForCausalLM +from .mpt import MPTForCausalLM diff --git a/tinychat/models/falcon.py b/tinychat/models/falcon.py new file mode 100644 index 0000000..902adc3 --- /dev/null +++ b/tinychat/models/falcon.py @@ -0,0 +1,296 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from typing import Optional, Tuple +from dataclasses import dataclass +import math + +import torch +from torch import nn +import torch.nn.functional as F +import awq_inference_engine + +import tinychat.utils.constants + +max_batch_size = tinychat.utils.constants.max_batch_size +max_seq_len = tinychat.utils.constants.max_seq_len + +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in torch < 1.8.0 + + +class RotaryEmbedding(nn.Module): + """Implementation of RotaryEmbedding from GPT-NeoX. + This implementation is design to operate on queries and keys that are compatible with + [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). + """ + + def __init__( + self, + head_dim: int, + base=10000, + ): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = None + self.batch_size_cached = None + self.cos_cached: torch.Tensor | None = None + self.sin_cached: torch.Tensor | None = None + + def cos_sin( + self, + seq_len: int, + device="cuda", + dtype=torch.bfloat16, + ) -> torch.Tensor: + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + return self.cos_cached, self.sin_cached + + def forward(self, _q, _k): + batch, seq_len, num_heads, head_dim = _q.shape + q = _q.permute(0, 2, 1, 3).contiguous().reshape(-1, seq_len, head_dim) + k = _k.permute(0, 2, 1, 3).contiguous().reshape(-1, seq_len, head_dim) + cos, sin = self.cos_sin(seq_len, q.device, q.dtype) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +class FalconAttentionFused(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.n_local_heads = args.n_head + self.head_dim = args.hidden_size // args.n_head + + self.query_key_value = nn.Linear( + args.hidden_size, + args.n_head * self.head_dim + 2 * self.head_dim, + bias=False, + ) + + self.dense = nn.Linear( + args.n_head * self.head_dim, + args.hidden_size, + bias=False, + ) + + # following fastertransformer definition + + self.cache_v = ( + torch.zeros( + ( + max_batch_size, + 1, + max_seq_len, + self.head_dim, + ) + ) + .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 + self.cache_k = ( + torch.zeros( + ( + max_batch_size, + 1, + self.head_dim // 8, + max_seq_len, + 8, + ) + ) + .cuda() + .half() + ) # added to half + + self.rotary_emb = RotaryEmbedding(self.head_dim) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + + xqkv = self.query_key_value(x) + xqkv = xqkv.view(bsz, seqlen, self.n_local_heads + 2, self.head_dim) + xq = xqkv[:, :, :-2] + xk = xqkv[:, :, [-2]] + xv = xqkv[:, :, [-1]] + + if seqlen > 1: + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, 1, self.head_dim) + xv = xv.view(bsz, seqlen, 1, self.head_dim) + + xq, xk = self.rotary_emb(xq, xk) + xq = ( + xq.reshape(bsz, self.n_local_heads, seqlen, self.head_dim) + .permute(0, 2, 1, 3) + .contiguous() + ) + xk = ( + xk.reshape(bsz, 1, seqlen, self.head_dim) + .permute(0, 2, 1, 3) + .contiguous() + ) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + values_store = xv.transpose(2, 1) + keys_store = ( + xk.reshape(bsz, seqlen, 1, self.head_dim // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) + + self.cache_v[:bsz, :, start_pos : start_pos + seqlen, :] = values_store + self.cache_k[:bsz, :, :, start_pos : start_pos + seqlen, :] = keys_store + + keys = xk + values = xv + + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + else: + # xq = xq[:, 0, :, :] + # xk = xk[:, 0, :, :] + # xv = xv[:, 0, :, :] + xq = xq.view(bsz, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, 1, self.head_dim) + xv = xv.view(bsz, 1, self.head_dim) + + output = awq_inference_engine.single_query_attention( + xq, + xk, + xv, + self.cache_k, + self.cache_v, + None, + # alibi position encodings + None, + start_pos, + self.head_dim, + 10000, + True, + ) + output = output.reshape(bsz, 1, -1) + + return self.dense(output) + + +class FalconMLP(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + self.dense_h_to_4h = nn.Linear(dim, 4 * dim, bias=False) + self.act = nn.GELU() + self.dense_4h_to_h = nn.Linear(4 * dim, dim, bias=False) + + def forward(self, x): + x = self.act(self.dense_h_to_4h(x)) + x = self.dense_4h_to_h(x) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args): + super().__init__() + self.n_heads = args.n_head + self.dim = args.hidden_size + self.head_dim = args.hidden_size // args.n_head + self.self_attention = FalconAttentionFused(args) + self.mlp = FalconMLP(dim=args.hidden_size) + self.layer_id = layer_id + self.input_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_epsilon + ) + # self.post_attention_layernorm = nn.LayerNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + mask: Optional[torch.Tensor], + ): + layernorm_output = self.input_layernorm(x) + h_attn = x + self.self_attention.forward(layernorm_output, start_pos, mask) + h_mlp = self.mlp(layernorm_output) + out = h_attn + h_mlp + return out + + +class Transformer(nn.Module): + def __init__(self, params): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layer + + self.word_embeddings = nn.Embedding(params.vocab_size, params.hidden_size) + + self.h = torch.nn.ModuleList() + for layer_id in range(params.n_layer): + self.h.append(TransformerBlock(layer_id, params)) + + self.ln_f = nn.LayerNorm(params.hidden_size, eps=params.layer_norm_epsilon) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.word_embeddings(tokens) + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + for layer in self.h: + h = layer(h, start_pos, mask) + h = self.ln_f(h) + return h + + +class FalconForCausalLM(nn.Module): + def __init__(self, params): + super().__init__() + self.config = params + self.transformer = Transformer(params) + self.lm_head = nn.Linear(params.hidden_size, params.vocab_size, bias=False) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + h = self.transformer(tokens, start_pos) + output = self.lm_head(h) # only compute last logits + return output.float() diff --git a/tinychat/models/llama.py b/tinychat/models/llama.py new file mode 100644 index 0000000..03a69e6 --- /dev/null +++ b/tinychat/models/llama.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from typing import Optional, Tuple +from dataclasses import dataclass +import math + +import torch +from torch import nn +import torch.nn.functional as F +import awq_inference_engine +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +# from flash_attn.flash_attn_interface import flash_attn_unpadded_func + +import tinychat.utils.constants + +max_batch_size = tinychat.utils.constants.max_batch_size +multiple_of = tinychat.utils.constants.llama_multiple_of +max_seq_len = tinychat.utils.constants.max_seq_len + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = torch.empty_like(x) + awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.eps) + return output + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # k_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_ = torch.view_as_complex( + xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous() + ) + xk_ = torch.view_as_complex( + xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous() + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class LlamaAttentionFused(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.n_local_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + kv_max_seq_len = min(max_seq_len, args.max_position_embeddings) + + self.q_proj = nn.Linear( + args.hidden_size, + args.num_attention_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + args.hidden_size, + args.num_attention_heads * self.head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + args.hidden_size, + args.num_attention_heads * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear( + args.num_attention_heads * self.head_dim, + args.hidden_size, + bias=False, + ) + + # following fastertransformer definition + + self.cache_v = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + # args.max_position_embeddings, + kv_max_seq_len, + self.head_dim, + ) + ) + .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 + self.cache_k = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + self.head_dim // 8, + # args.max_position_embeddings, + kv_max_seq_len, + 8, + ) + ) + .cuda() + .half() + ) # added to half + + # dummy + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=2048, device="cuda:0" + ) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + # xqkv = self.qkv_proj(x) + # xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim) + # xq = xqkv[:, :, 0] + # xk = xqkv[:, :, 1] + # xv = xqkv[:, :, 2] + + xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + if seqlen > 1: + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + values_store = xv.transpose(2, 1) + keys_store = ( + xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) + + self.cache_v[:bsz, :, start_pos : start_pos + seqlen, :] = values_store + self.cache_k[:bsz, :, :, start_pos : start_pos + seqlen, :] = keys_store + + keys = xk + values = xv + + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + else: + # xq = xq[:, 0, :, :] + # xk = xk[:, 0, :, :] + # xv = xv[:, 0, :, :] + xq = xq.view(bsz, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, self.n_local_heads, self.head_dim) + + output = awq_inference_engine.single_query_attention( + xq, + xk, + xv, + self.cache_k, + self.cache_v, + None, + # alibi position encodings + None, + start_pos, + self.head_dim, + 10000, + True, + ) + output = output.reshape(bsz, 1, -1) + + return self.o_proj(output) + + +class LlamaMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args): + super().__init__() + self.n_heads = args.num_attention_heads + self.dim = args.hidden_size + self.head_dim = args.hidden_size // args.num_attention_heads + self.self_attn = LlamaAttentionFused(args) + self.mlp = LlamaMLP( + dim=args.hidden_size, + hidden_dim=4 * args.hidden_size, + multiple_of=multiple_of, + ) + self.layer_id = layer_id + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.self_attn.forward( + self.input_layernorm(x), start_pos, freqs_cis, mask + ) + out = h + self.mlp.forward(self.post_attention_layernorm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.num_hidden_layers + + self.embed_tokens = nn.Embedding(params.vocab_size, params.hidden_size) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.num_hidden_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.hidden_size, eps=params.rms_norm_eps) + + self.freqs_cis = precompute_freqs_cis( + self.params.hidden_size // self.params.num_attention_heads, + self.params.max_position_embeddings * 2, + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.embed_tokens(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + return h + + +class LlamaForCausalLM(nn.Module): + def __init__(self, params): + super().__init__() + self.params = params + self.model = Transformer(params) + self.lm_head = nn.Linear(params.hidden_size, params.vocab_size, bias=False) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + h = self.model(tokens, start_pos) + output = self.lm_head(h) # only compute last logits + return output.float() diff --git a/tinychat/models/mpt.py b/tinychat/models/mpt.py new file mode 100644 index 0000000..1e6be26 --- /dev/null +++ b/tinychat/models/mpt.py @@ -0,0 +1,301 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from typing import Optional, Tuple +from dataclasses import dataclass +import math + +import torch +from torch import nn +import torch.nn.functional as F +import awq_inference_engine +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + +# from flash_attn.flash_attn_interface import flash_attn_unpadded_func + +import tinychat.utils.constants +max_batch_size = tinychat.utils.constants.max_batch_size +global_max_seq_len = tinychat.utils.constants.max_seq_len + +def gen_slopes(n_heads, alibi_bias_max=8): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + if _n_heads != n_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] + return slopes.view(1, n_heads, 1, 1) + + +def build_alibi_bias( + n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32 +): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) + if full: + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view( + 1, 1, seq_len, 1 + ) + alibi_bias = alibi_bias.abs().mul(-1) + slopes = gen_slopes(n_heads, alibi_bias_max) + alibi_bias = alibi_bias * slopes + slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1) + return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype) + + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + + +class LPLayerNorm(torch.nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + downcast_bias = ( + _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + ) + with torch.autocast(enabled=False, device_type=module_device.type): + return torch.nn.functional.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) + + +class SharedEmbedding(nn.Embedding): + def forward(self, input: torch.Tensor, unembed: bool = False) -> torch.Tensor: + if unembed: + return F.linear(input, self.weight) + return super().forward(input) + + +class MPTAttentionFused(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.n_local_heads = args.n_heads + self.head_dim = args.d_model // args.n_heads + args.max_seq_len = min(args.max_seq_len, global_max_seq_len) + + self.Wqkv = nn.Linear( + args.d_model, + args.n_heads * self.head_dim * 3, + bias=False, + ) + + self.out_proj = nn.Linear( + args.n_heads * self.head_dim, + args.d_model, + bias=False, + ) + + # following fastertransformer definition + + self.cache_v = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + args.max_seq_len, + self.head_dim, + ) + ) + .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 + self.cache_k = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + self.head_dim // 8, + args.max_seq_len, + 8, + ) + ) + .cuda() + .half() + ) # added to half + + alibi_slopes, alibi_bias = build_alibi_bias( + self.n_local_heads, args.max_seq_len + ) + # TODO (Haotian): fix device + self.alibi_slopes = alibi_slopes.float().to("cuda:0") + self.alibi_bias = alibi_bias.to("cuda:0") + + def forward( + self, + x: torch.Tensor, + start_pos: int, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xqkv = self.Wqkv(x) + xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim) + xq = xqkv[:, :, 0] + xk = xqkv[:, :, 1] + xv = xqkv[:, :, 2] + + if seqlen > 1: + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + values_store = xv.transpose(2, 1) + keys_store = ( + xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) + + self.cache_v[:bsz, :, start_pos : start_pos + seqlen, :] = values_store + self.cache_k[:bsz, :, :, start_pos : start_pos + seqlen, :] = keys_store + + keys = xk + values = xv + + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + scores += self.alibi_bias[..., :seqlen] + if mask is not None: + scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + else: + # xq = xq[:, 0, :, :] + # xk = xk[:, 0, :, :] + # xv = xv[:, 0, :, :] + xq = xq.view(bsz, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, self.n_local_heads, self.head_dim) + output = awq_inference_engine.single_query_attention( + xq, + xk, + xv, + self.cache_k, + self.cache_v, + None, + # with alibi embedding + self.alibi_slopes.float(), + start_pos, + # rotary embed dim = 0 => no rotary embedding + 0, + 10000, + True, + ) + output = output.reshape(bsz, 1, -1) + + return self.out_proj(output) + + +class MPTMLP(nn.Module): + def __init__(self, d_model: int, expansion_ratio: int): + super().__init__() + self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, bias=False) + self.act = nn.GELU(approximate="none") + self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, bias=False) + self.down_proj._is_residual = True + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + + +class MPTBlock(nn.Module): + def __init__(self, layer_id: int, args): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.d_model + self.head_dim = args.d_model // args.n_heads + self.attn = MPTAttentionFused(args) + self.ffn = MPTMLP(d_model=args.d_model, expansion_ratio=4) + self.layer_id = layer_id + self.norm_1 = LPLayerNorm(args.d_model, eps=1e-6) + self.norm_2 = LPLayerNorm(args.d_model, eps=1e-6) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + mask: Optional[torch.Tensor], + ): + h = x + self.attn.forward(self.norm_1(x), start_pos, mask) + out = h + self.ffn.forward(self.norm_2(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.wte = SharedEmbedding(params.vocab_size, params.d_model) + + self.blocks = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.blocks.append(MPTBlock(layer_id, params)) + + self.norm_f = LPLayerNorm(params.d_model, eps=1e-6) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.wte(tokens) + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + for layer in self.blocks: + h = layer(h, start_pos, mask) + h = self.norm_f(h) + return h + + +class MPTForCausalLM(nn.Module): + def __init__(self, params): + super().__init__() + self.config = params + self.transformer = Transformer(params) + if params.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + module.register_parameter("bias", None) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + h = self.transformer(tokens, start_pos) + output = self.transformer.wte(h, unembed=True) # only compute last logits + return output.float() diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py index 2615ce7..094f9f3 100644 --- a/tinychat/modules/fused_attn.py +++ b/tinychat/modules/fused_attn.py @@ -2,10 +2,18 @@ import torch import torch.nn as nn from torch.nn import functional as F -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb - +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) +from typing import Optional from awq.quantize.qmodule import WQLinear import awq_inference_engine +from tinychat.models.llama import apply_rotary_emb + + +max_batch_size: int = 1 class QuantLlamaRotaryEmbedding(nn.Module): @@ -15,29 +23,35 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - + cos = freqs.cos() sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) - + # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("cos_sin_cache", cache.half(), persistent=False) - + def forward( self, query: torch.Tensor, @@ -58,30 +72,36 @@ def forward( ) return query, key + class QuantLlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( - self, - hidden_size, - num_heads, - qkv_proj, - o_proj, - dev - ): + def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, dev): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.head_dim = hidden_size // num_heads if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) self.qkv_proj = qkv_proj self.o_proj = o_proj - self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device = dev) + self.rotary_emb = QuantLlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=2048, device=dev + ) - def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): + def forward( + self, + hidden_states, + past_key_value=None, + attention_mask=None, + position_ids=None, + output_attentions=False, + use_cache=False, + ): """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -91,19 +111,27 @@ def forward(self, hidden_states, past_key_value=None, attention_mask=None, posit # This updates the query and key states in-place, saving VRAM. query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2) - query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) - + query_states, key_states = self.rotary_emb( + query_states, key_states, position_ids + ) + del qkv_states - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) is_causal = past_key_value is None kv_seq_len = q_len if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - + value_states = value_states.to("cuda:0") if past_key_value is not None: @@ -121,7 +149,9 @@ def forward(self, hidden_states, past_key_value=None, attention_mask=None, posit past_key_value = (key_states, value_states) if use_cache else None # with torch.backends.cuda.sdp_kernel(enable_math=False): - attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, is_causal=is_causal + ) del query_states, key_states, value_states attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) @@ -130,45 +160,178 @@ def forward(self, hidden_states, past_key_value=None, attention_mask=None, posit return attn_output, None, past_key_value +class QuantLlamaAttentionFused(nn.Module): + def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, args): + super().__init__() + self.hidden_size = hidden_size + self.n_local_heads = num_heads + self.head_dim = self.hidden_size // num_heads + self.qkv_proj = qkv_layer + self.o_proj = o_proj + + # following fastertransformer definition + self.cache_v = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + args.max_position_embeddings, + self.head_dim, + ) + ) + .to(dev) + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 + self.cache_k = ( + torch.zeros( + ( + max_batch_size, + self.n_local_heads, + self.head_dim // 8, + args.max_position_embeddings, + 8, + ) + ) + .to(dev) + .half() + ) # added to half + + # dummy + self.rotary_emb = QuantLlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=2048, device="cuda:0" + ) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + bsz, seqlen, _ = x.shape + xqkv = self.qkv_proj(x) + xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim) + xq = xqkv[:, :, 0] + xk = xqkv[:, :, 1] + xv = xqkv[:, :, 2] + + if seqlen > 1: + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + values_store = xv.transpose(2, 1) + keys_store = ( + xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) + + self.cache_v[:bsz, :, start_pos : start_pos + seqlen, :] = values_store + self.cache_k[:bsz, :, :, start_pos : start_pos + seqlen, :] = keys_store + + keys = xk + values = xv + + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + else: + xq = xq[:, 0, :, :] + xk = xk[:, 0, :, :] + xv = xv[:, 0, :, :] + output = awq_inference_engine.single_query_attention( + xq, + xk, + xv, + self.cache_k, + self.cache_v, + None, + None, + start_pos, + self.head_dim, + 10000, + True, + ) + output = output.reshape(bsz, 1, -1) + + return self.o_proj(output) + + def make_quant_attn(model, dev): """ Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. """ - + model = model.cpu() for name, m in model.named_modules(): - if not isinstance(m, LlamaAttention): + if not m.__class__.__name__ in ["LlamaAttention", "LlamaAttentionFused"]: continue q_proj = m.q_proj k_proj = m.k_proj v_proj = m.v_proj - qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) + qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) g_idx = None - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + bias = ( + torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) + if q_proj.bias is not None + else None + ) - qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device) + qkv_layer = WQLinear( + q_proj.w_bit, + q_proj.group_size, + q_proj.in_features, + q_proj.out_features + k_proj.out_features + v_proj.out_features, + q_proj.bias is not None, + q_proj.qweight.device, + ) qkv_layer.qweight = qweights qkv_layer.qzeros = qzeros qkv_layer.scales = scales qkv_layer.bias = bias + qkv_layer.split_k_iters = q_proj.split_k_iters # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch. - attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev) - - if '.' in name: - parent_name = name.rsplit('.', 1)[0] - child_name = name[len(parent_name) + 1:] + if isinstance(m, LlamaAttention): + attn = QuantLlamaAttention( + m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev + ) + else: + attn = QuantLlamaAttentionFused( + m.args.hidden_size, + m.args.num_attention_heads, + qkv_layer, + m.o_proj, + dev, + m.args, + ) + if "." in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] parent = model.get_submodule(parent_name) else: - parent_name = '' + parent_name = "" parent = model child_name = name - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") - + # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") setattr(parent, child_name, attn) + model = model.to(dev) diff --git a/tinychat/modules/fused_mlp.py b/tinychat/modules/fused_mlp.py index 89daf75..5070393 100644 --- a/tinychat/modules/fused_mlp.py +++ b/tinychat/modules/fused_mlp.py @@ -9,7 +9,6 @@ class QuantLlamaMLP(nn.Module): - def __init__( self, gate_proj, @@ -17,45 +16,72 @@ def __init__( up_proj, ): super().__init__() - self.register_buffer('gate_proj_qweight', gate_proj.qweight) - self.register_buffer('gate_proj_scales', gate_proj.scales) - self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) - self.register_buffer('up_proj_qweight', up_proj.qweight) - self.register_buffer('up_proj_scales', up_proj.scales) - self.register_buffer('up_proj_qzeros', up_proj.qzeros) + self.register_buffer("gate_proj_qweight", gate_proj.qweight) + self.register_buffer("gate_proj_scales", gate_proj.scales) + self.register_buffer("gate_proj_qzeros", gate_proj.qzeros) + self.register_buffer("up_proj_qweight", up_proj.qweight) + self.register_buffer("up_proj_scales", up_proj.scales) + self.register_buffer("up_proj_qzeros", up_proj.qzeros) self.in_features = gate_proj.in_features self.intermediate_size = gate_proj.out_features self.out_features = down_proj.out_features self.w_bit = gate_proj.w_bit self.down_proj = down_proj + self.split_k_iters = down_proj.split_k_iters def forward(self, x): return self.down_proj(self.our_llama_mlp(x)) - + def our_llama_mlp(self, x): - out_shape = x.shape[:-1] + (self.intermediate_size, ) + out_shape = x.shape[:-1] + (self.intermediate_size,) x = x.reshape(-1, x.shape[-1]) + """ gate_output = awq_inference_engine.gemm_forward_cuda( - x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8 + x, + self.gate_proj_qweight, + self.gate_proj_scales, + self.gate_proj_qzeros, + self.split_k_iters, + ) + """ + gate_output = awq_inference_engine.gemv_forward_cuda( + x, + self.gate_proj_qweight, + self.gate_proj_scales, + self.gate_proj_qzeros, + self.down_proj.group_size, ) gate_output = F.silu(gate_output) + """ up_output = awq_inference_engine.gemm_forward_cuda( - x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8 + x, + self.up_proj_qweight, + self.up_proj_scales, + self.up_proj_qzeros, + self.split_k_iters, + ) + """ + up_output = awq_inference_engine.gemv_forward_cuda( + x, + self.up_proj_qweight, + self.up_proj_scales, + self.up_proj_qzeros, + self.down_proj.group_size, ) c = gate_output * up_output c = c.reshape(out_shape) return c -def make_fused_mlp(m, parent_name=''): +def make_fused_mlp(m, parent_name=""): if not hasattr(make_fused_mlp, "called"): # print("[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.") make_fused_mlp.called = True """ Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. """ - if isinstance(m, LlamaMLP): + if m.__class__.__name__ in ["LlamaMLP"]: return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) for name, child in m.named_children(): diff --git a/tinychat/modules/fused_norm.py b/tinychat/modules/fused_norm.py index 50f49c3..e8e1f0d 100644 --- a/tinychat/modules/fused_norm.py +++ b/tinychat/modules/fused_norm.py @@ -3,6 +3,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm import awq_inference_engine + class FTLlamaRMSNorm(nn.Module): def __init__(self, weight, eps=1e-6): """ @@ -14,10 +15,12 @@ def __init__(self, weight, eps=1e-6): def forward(self, x): output = torch.empty_like(x) - awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) - return output - - + awq_inference_engine.layernorm_forward_cuda( + x, self.weight, output, self.variance_epsilon + ) + return output + + def make_quant_norm(model): """ Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules @@ -29,15 +32,15 @@ def make_quant_norm(model): norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon) - if '.' in name: - parent_name = name.rsplit('.', 1)[0] - child_name = name[len(parent_name) + 1:] + if "." in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] parent = model.get_submodule(parent_name) else: - parent_name = '' + parent_name = "" parent = model child_name = name - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") setattr(parent, child_name, norm) diff --git a/tinychat/stream_generators/__init__.py b/tinychat/stream_generators/__init__.py index a199230..060c64f 100644 --- a/tinychat/stream_generators/__init__.py +++ b/tinychat/stream_generators/__init__.py @@ -1,2 +1,2 @@ from .falcon_stream_gen import * -from .stream_gen import * \ No newline at end of file +from .stream_gen import * diff --git a/tinychat/stream_generators/falcon_stream_gen.py b/tinychat/stream_generators/falcon_stream_gen.py index 645810a..15e66a9 100644 --- a/tinychat/stream_generators/falcon_stream_gen.py +++ b/tinychat/stream_generators/falcon_stream_gen.py @@ -9,6 +9,7 @@ transformers.logging.set_verbosity_error() + def is_partial_stop(output: str, stop_str: str): """Check whether the output contains a partial stop str.""" for i in range(0, min(len(output), len(stop_str))): @@ -21,15 +22,15 @@ def is_partial_stop(output: str, stop_str: str): def FalconStreamGenerator( model, tokenizer, - input : str, - gen_params : dict, + input: str, + gen_params: dict, device: str = "cuda:0", - context_len = 2048, - stream_interval = 2, - judge_sent_end = False, + context_len=2048, + stream_interval=2, + judge_sent_end=False, echo: bool = False, stop_str: str = "\nUser", - stop_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + stop_token_ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], ): prompt = input len_prompt = len(prompt) @@ -54,14 +55,14 @@ def FalconStreamGenerator( streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) generation_config = GenerationConfig( - max_new_tokens = max_new_tokens, - do_sample = gen_params.temp >= 1e-5, - temperature = gen_params.temp, - repetition_penalty = gen_params.repeat_penalty, - no_repeat_ngram_size = 10, - top_p = gen_params.top_p, - top_k = top_k, - eos_token_id = stop_token_ids, + max_new_tokens=max_new_tokens, + do_sample=gen_params.temp >= 1e-5, + temperature=gen_params.temp, + repetition_penalty=gen_params.repeat_penalty, + no_repeat_ngram_size=10, + top_p=gen_params.top_p, + top_k=top_k, + eos_token_id=stop_token_ids, ) generation_kwargs = dict( @@ -142,4 +143,4 @@ def FalconStreamGenerator( # clean gc.collect() - torch.cuda.empty_cache() \ No newline at end of file + torch.cuda.empty_cache() diff --git a/tinychat/stream_generators/stream_gen.py b/tinychat/stream_generators/stream_gen.py index 656e39c..424ef10 100644 --- a/tinychat/stream_generators/stream_gen.py +++ b/tinychat/stream_generators/stream_gen.py @@ -15,6 +15,7 @@ total_tokens = 0 generation_time_list = [] + def prepare_logits_processor( temperature: float, repetition_penalty: float, top_p: float, top_k: int ) -> LogitsProcessorList: @@ -32,14 +33,15 @@ def prepare_logits_processor( @torch.inference_mode() -def StreamGenerator(model, - tokenizer, - input : str, - gen_params : dict, - device: str = "cuda:0", - stream_interval: int = 2, - echo: bool = False, - stop_token_ids = [] +def StreamGenerator( + model, + tokenizer, + input: str, + gen_params: dict, + device: str = "cuda:0", + stream_interval: int = 2, + echo: bool = False, + stop_token_ids=[], ): input_ids = tokenizer(input).input_ids input_echo_len = len(input_ids) @@ -54,26 +56,43 @@ def StreamGenerator(model, logits_processor = prepare_logits_processor( gen_params.temp, gen_params.repeat_penalty, gen_params.top_p, top_k ) - + past_key_values = out = None stop_token_ids.append(tokenizer.eos_token_id) max_new_tokens = gen_params.n_predict + start_pos = 0 for i in range(max_new_tokens): torch.cuda.synchronize() t_st = time.time() - if i == 0: # Context Stage - out = model(torch.as_tensor([input_ids], device=device), use_cache=True) - logits = out.logits - past_key_values = out.past_key_values + if i == 0: + inputs = torch.as_tensor([input_ids], device=device) else: - out = model( - input_ids=torch.as_tensor([[token]], device=device), - use_cache=True, - past_key_values=past_key_values, - ) - logits = out.logits - past_key_values = out.past_key_values + inputs = torch.as_tensor([[token]], device=device) + + if ( + "llama" not in model.__class__.__name__.lower() + and "mpt" not in model.__class__.__name__.lower() + and "falcon" not in model.__class__.__name__.lower() + ): + if i == 0: # Context Stage + out = model(inputs, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + else: + out = model( + input_ids=inputs, + use_cache=True, + past_key_values=past_key_values, + ) + logits = out.logits + past_key_values = out.past_key_values + else: + out = model(inputs, start_pos=start_pos) + start_pos += out.shape[1] + logits = out + torch.cuda.synchronize() + t_ed = time.time() # Processing the logits if logits_processor: @@ -91,9 +110,6 @@ def StreamGenerator(model, token = int(torch.multinomial(probs, num_samples=1)) output_ids.append(token) - torch.cuda.synchronize() - t_ed = time.time() - global context_time global context_tokens global total_tokens @@ -103,13 +119,12 @@ def StreamGenerator(model, context_tokens = logits.shape[1] generation_time_list = [] else: - generation_time_list.append(t_ed-t_st) - + generation_time_list.append(t_ed - t_st) + if token in stop_token_ids: stopped = True else: - stopped = False - + stopped = False if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: if echo: @@ -151,7 +166,7 @@ def StreamGenerator(model, else: finish_reason = None - total_tokens = (context_tokens + len(generation_time_list)) + total_tokens = context_tokens + len(generation_time_list) yield { "text": output, "usage": { @@ -160,16 +175,16 @@ def StreamGenerator(model, "total_tokens": input_echo_len + i, }, "finish_reason": finish_reason, - "timing":{ + "timing": { "context_tokens": context_tokens, "context_time": context_time, "total_tokens": total_tokens, "generation_time_list": generation_time_list, - } + }, } del past_key_values, out gc.collect() torch.cuda.empty_cache() - # return context_tokens, context_time, total_tokens, generation_time_list \ No newline at end of file + # return context_tokens, context_time, total_tokens, generation_time_list diff --git a/tinychat/utils/__init__.py b/tinychat/utils/__init__.py new file mode 100644 index 0000000..406492e --- /dev/null +++ b/tinychat/utils/__init__.py @@ -0,0 +1,3 @@ +import tinychat.utils.constants as constants + +constants.init() \ No newline at end of file diff --git a/tinychat/utils/constants.py b/tinychat/utils/constants.py new file mode 100644 index 0000000..10f0ce5 --- /dev/null +++ b/tinychat/utils/constants.py @@ -0,0 +1,7 @@ +import torch + +def init(): + global max_seq_len, max_batch_size, llama_multiple_of + max_seq_len = 2048 + max_batch_size = 1 + llama_multiple_of = 256 \ No newline at end of file diff --git a/tinychat/utils/load_quant.py b/tinychat/utils/load_quant.py index 80f99ea..7c86983 100644 --- a/tinychat/utils/load_quant.py +++ b/tinychat/utils/load_quant.py @@ -6,60 +6,91 @@ from awq.quantize.qmodule import WQLinear from tqdm import tqdm + def load_awq_model(model, checkpoint, w_bit, group_size, device): q_config = {"zero_point": True, "q_group_size": group_size} - real_quantize_model_weight(model, w_bit, q_config, init_only = True) + real_quantize_model_weight(model, w_bit, q_config, init_only=True) pbar = tqdm(range(1)) - pbar.set_description('Loading checkpoint') + pbar.set_description("Loading checkpoint") for i in pbar: if hasattr(model.config, "tie_encoder_decoder"): model.config.tie_encoder_decoder = False if hasattr(model.config, "tie_word_embeddings"): model.config.tie_word_embeddings = False model = load_checkpoint_and_dispatch( - model, checkpoint, + model, + checkpoint, no_split_module_classes=[ - "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"] + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], ).to(device) return model -def make_quant_linear(module, names, w_bit, groupsize, device, name=''): +def make_quant_linear(module, names, w_bit, groupsize, device, name=""): if isinstance(module, WQLinear): return for attr in dir(module): tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr + name1 = name + "." + attr if name != "" else attr if name1 in names: delattr(module, attr) - setattr(module, attr, WQLinear(w_bit, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None, device)) + setattr( + module, + attr, + WQLinear( + w_bit, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + device, + ), + ) for name1, child in module.named_children(): - make_quant_linear(child, names, w_bit, groupsize, device, name + '.' + name1 if name != '' else name1) + make_quant_linear( + child, + names, + w_bit, + groupsize, + device, + name + "." + name1 if name != "" else name1, + ) + -def find_layers(module, layers=[nn.Linear], name=''): +def find_layers(module, layers=[nn.Linear], name=""): if type(module) in layers: return {name: module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) return res def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device): layers = find_layers(model) - for name in ['lm_head']: + for name in ["lm_head"]: if name in layers: del layers[name] make_quant_linear(model, layers, w_bit, group_size, device) del layers pbar = tqdm(range(1)) - pbar.set_description('Loading checkpoint') + pbar.set_description("Loading checkpoint") for i in pbar: - if checkpoint.endswith('.safetensors'): + if checkpoint.endswith(".safetensors"): from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) else: model.load_state_dict(torch.load(checkpoint)) - return model.to(device) \ No newline at end of file + return model.to(device) diff --git a/tinychat/utils/prompt_templates.py b/tinychat/utils/prompt_templates.py index 4e18762..a004cc6 100644 --- a/tinychat/utils/prompt_templates.py +++ b/tinychat/utils/prompt_templates.py @@ -1,12 +1,21 @@ from typing import List + class BasePrompter: - def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None): - self.system_inst = system_inst # System Instruction - self.role1 = role1 # The name of USER - self.role2 = role2 # The name of AI-Assistant - self.sen_spliter = sen_spliter # How to split system/user/assistant outputs - self.qa_spliter = qa_spliter # How to split Q&A rounds + def __init__( + self, + system_inst, + role1, + role2, + sen_spliter="\n", + qa_spliter="\n", + decorator: List[str] = None, + ): + self.system_inst = system_inst # System Instruction + self.role1 = role1 # The name of USER + self.role2 = role2 # The name of AI-Assistant + self.sen_spliter = sen_spliter # How to split system/user/assistant outputs + self.qa_spliter = qa_spliter # How to split Q&A rounds self.decorator = decorator if self.decorator == None: self.starter = "" @@ -15,27 +24,66 @@ def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = " self.starter = self.decorator[0] self.stopper = self.decorator[1] if self.system_inst == None: - self.template = self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \ - + self.starter + self.role2 + ":" + self.template = ( + self.starter + + self.role1 + + ": {prompt}" + + self.stopper + + self.sen_spliter + + self.starter + + self.role2 + + ":" + ) else: - self.template = self.starter + self.system_inst + self.stopper + self.sen_spliter \ - + self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \ - + self.starter + self.role2 + ":" + self.template = ( + self.starter + + self.system_inst + + self.stopper + + self.sen_spliter + + self.starter + + self.role1 + + ": {prompt}" + + self.stopper + + self.sen_spliter + + self.starter + + self.role2 + + ":" + ) self.model_input = None - + def insert_prompt(self, input_prompt): - self.model_input = self.template.format(prompt=input_prompt) + self.model_input = self.template.format(prompt=input_prompt) def update_template(self, outputs): - self.template = self.model_input + " " + outputs.strip() + self.stopper + self.qa_spliter \ - + self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \ - + self.starter + self.role2 + ":" + self.template = ( + self.model_input + + " " + + outputs.strip() + + self.stopper + + self.qa_spliter + + self.starter + + self.role1 + + ": {prompt}" + + self.stopper + + self.sen_spliter + + self.starter + + self.role2 + + ":" + ) self.model_input = None - + + class OneShotBasePrompter(BasePrompter): - def __init__(self, - oneshot_example: List[str], # User prompt + Assistant responce - system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None): + def __init__( + self, + oneshot_example: List[str], # User prompt + Assistant responce + system_inst, + role1, + role2, + sen_spliter="\n", + qa_spliter="\n", + decorator: List[str] = None, + ): super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter) assert len(oneshot_example) == 2, "One-shot example must be a List of 2 strs." self.user_example = oneshot_example[0] @@ -53,6 +101,7 @@ def __init__(self): qa_spliter = "" super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter) + class Llama2Prompter(OneShotBasePrompter): def __init__(self): system_inst = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." @@ -60,19 +109,23 @@ def __init__(self): role2 = "### Assistant" sen_spliter = "\n" qa_spliter = "" - user_example="Got any creative ideas for a 10 year old's birthday?" - assistant_example = "Of course! Here are some creative ideas for a 10-year-old's birthday party:\n" \ - + "1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.\n" \ - + "2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.\n" \ - + "3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.\n" \ - + "4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.\n" \ - + "5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.\n" \ - + "6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.\n" \ - + "7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.\n" \ - + "8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.\n" \ - + "Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!" + user_example = "Got any creative ideas for a 10 year old's birthday?" + assistant_example = ( + "Of course! Here are some creative ideas for a 10-year-old's birthday party:\n" + + "1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.\n" + + "2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.\n" + + "3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.\n" + + "4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.\n" + + "5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.\n" + + "6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.\n" + + "7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.\n" + + "8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.\n" + + "Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!" + ) oneshot_example = [user_example, assistant_example] - super().__init__(oneshot_example, system_inst, role1, role2, sen_spliter, qa_spliter) + super().__init__( + oneshot_example, system_inst, role1, role2, sen_spliter, qa_spliter + ) class FalconSimplePrompter(BasePrompter): @@ -87,12 +140,14 @@ def __init__(self): class FalconPrompter(BasePrompter): def __init__(self): - system_inst = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, "\ - + "and a human user, called User. In the following interactions, User and Falcon will converse in natural language, "\ - + "and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. "\ - + "Falcon was built by the Technology Innovation Institute in Abu Dhabi. "\ - + "Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. "\ - + "It knows a lot, and always tells the truth. The conversation begins." + system_inst = ( + "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, " + + "and a human user, called User. In the following interactions, User and Falcon will converse in natural language, " + + "and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. " + + "Falcon was built by the Technology Innovation Institute in Abu Dhabi. " + + "Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. " + + "It knows a lot, and always tells the truth. The conversation begins." + ) role1 = "User" role2 = "Assistant" sen_spliter = "\n" @@ -109,13 +164,16 @@ def __init__(self): qa_spliter = "\n" super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter) + class MPTChatPrompter(BasePrompter): def __init__(self): - system_inst = "system\n" \ - + "- You are a helpful assistant chatbot trained by MosaicML.\n" \ - + "- You answer questions.\n" \ - + "- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n" \ - + "- You are more than just an information source, you are also able to write poetry, short stories, and make jokes." + system_inst = ( + "system\n" + + "- You are a helpful assistant chatbot trained by MosaicML.\n" + + "- You answer questions.\n" + + "- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n" + + "- You are more than just an information source, you are also able to write poetry, short stories, and make jokes." + ) role1 = "user" role2 = "assistant" sen_spliter = "\n" @@ -124,8 +182,7 @@ def __init__(self): super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter, decorator) - -def get_prompter(model_type, model_path = ""): +def get_prompter(model_type, model_path=""): if model_type.lower() == "llama": if "vicuna" in model_path: return VicunaPrompter() @@ -142,7 +199,8 @@ def get_prompter(model_type, model_path = ""): else: raise ValueError(f"model type {model_type} is not supported") -def get_stop_token_ids(model_type, model_path = ""): + +def get_stop_token_ids(model_type, model_path=""): if model_type.lower() == "llama": return [] elif model_type.lower() == "falcon": diff --git a/tinychat/utils/tune.py b/tinychat/utils/tune.py new file mode 100644 index 0000000..2ca231c --- /dev/null +++ b/tinychat/utils/tune.py @@ -0,0 +1,61 @@ +import numpy as np +import time +import torch +from awq.quantize.qmodule import WQLinear + + +__all__ = ["device_warmup", "tune_all_wqlinears"] + + +def device_warmup(device: str): + warm_up = torch.randn((4096, 4096)).to(device) + for i in range(100): + torch.mm(warm_up, warm_up) + + +def _time_module(module, inputs, measure_iters=1000): + time_lis = [] + # Warmup + for i in range(measure_iters): + module(inputs) + for i in range(measure_iters): + torch.cuda.synchronize() + st = time.time() + module(inputs) + torch.cuda.synchronize() + ed = time.time() + time_lis.append((ed - st)) + return np.median(time_lis) + + +def tune_wqlinear(module: WQLinear, measure_iters: int = 1000): + device_warmup(str(module.scales.device)) + inputs = torch.randn( + 1, module.in_features, device=module.scales.device, dtype=module.scales.dtype + ) + best_split_k_iter = None + best_latency = None + for split_k_iters in [1, 2, 4, 8, 16, 32]: + module.split_k_iters = split_k_iters + cur_latency = _time_module(module, inputs, measure_iters) + if best_split_k_iter is None or best_latency >= cur_latency: + best_split_k_iter = split_k_iters + best_latency = cur_latency + module.split_k_iters = best_split_k_iter + return best_split_k_iter + + +def tune_all_wqlinears(model, measure_iters: int = 1000): + tuned_results = dict() + for name, module in model.named_modules(): + if isinstance(module, WQLinear): + ic, oc = module.in_features, module.out_features + if (ic, oc) not in tuned_results: + print(f"Tuning {(ic, oc)}...") + split_k_iters = tune_wqlinear(module) + tuned_results[(ic, oc)] = split_k_iters + # write configs to model + for name, module in model.named_modules(): + if isinstance(module, WQLinear): + ic, oc = module.in_features, module.out_features + module.split_k_iters = tuned_results[(ic, oc)]