Upgrade qbits backend (#460)
* add avx_vnni support

* disable template combine when using avx_vnni kernel and src_dt==bf16

* fix windows compile

* better error logging

* fix avx2 f4 padding

Signed-off-by: Wang,Zhe <zhe1.wang@intel.com>

---------

Signed-off-by: Wang,Zhe <zhe1.wang@intel.com>
zhewang1-intc committed Oct 13, 2023
1 parent 98e5f9a commit 45e03b9
Showing 2 changed files with 71 additions and 28 deletions.
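For orientation before the diff: one recurring pattern in this change is an extra ISA fallback, with per-channel int8 GEMM now trying AMX first, then AVX512_VNNI, then the new AVX_VNNI core. A minimal C++ sketch of that priority order follows; the enum and the has_* flags are illustrative stand-ins, not the jblas CpuDevice API used in the real code.

// Sketch only: mirrors the fallback order in parse_gemm_core_online below;
// has_amx / has_avx512_vnni / has_avx_vnni are hypothetical stand-ins for the
// jblas CpuDevice feature queries (check_amx(), check_avx512_vnni(), check_avx_vnni()).
enum class PerChannelInt8Core { AMX_16x48, AVX512_VNNI_8x48, AVX_VNNI_2x48, Unsupported };

inline PerChannelInt8Core pick_per_channel_core(bool has_amx, bool has_avx512_vnni, bool has_avx_vnni) {
  if (has_amx) return PerChannelInt8Core::AMX_16x48;               // widest core wins
  if (has_avx512_vnni) return PerChannelInt8Core::AVX512_VNNI_8x48;
  if (has_avx_vnni) return PerChannelInt8Core::AVX_VNNI_2x48;      // fallback added by this commit
  return PerChannelInt8Core::Unsupported;                          // the dispatcher raises TORCH_CHECK instead
}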
@@ -50,6 +50,15 @@ static inline __m128i unpack_4bits_sse(void* srcptr) {
template <JBLAS_SIGN_INT_TYPE S4_T>
static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) {
auto dst0 = unpack_4bits_sse<S4_T>(srcptr);
+ if constexpr (S4_T == S4_FULLRANGE) {
+ auto s8 = _mm_set1_epi8(8);
+ dst0 = _mm_sub_epi8(dst0, s8);
+ }
+ _mm_storeu_si128((__m128i*)dstptr, dst0);
+ }
+
+ static inline void fp4_pad_4bit(int8_t* dstptr, int8_t* srcptr) {
+ auto dst0 = unpack_4bits_sse<S4_FULLRANGE>(srcptr);
_mm_storeu_si128((__m128i*)dstptr, dst0);
}

@@ -141,8 +150,6 @@ static inline JBLAS_CODE dequant_kblock_s8_f32(int8_t* srcptr, float* dstptr, in
kblock, NPad);
}

- constexpr void (*pad_fp4)(int8_t* dstptr, int8_t* srcptr) = &convert_s4_s8_16_sse<S4_FULLRANGE>;

template <JBLAS_SIGN_INT_TYPE S4_T>
static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src,
int ld_dst) {
@@ -269,10 +276,10 @@ static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* ds
int ld_dst, _ST* scales, int k_offset, int kblock, int NPad) {
if constexpr (std::is_same<_DST_T, float>::value) {
return decompress_kblock_bit4_fp32<_ST>(srcptr, (float*)dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset,
- kblock, NPad, &dequant_f4_N<48, float, F4_T>, pad_fp4);
+ kblock, NPad, &dequant_f4_N<48, float, F4_T>, fp4_pad_4bit);
} else if constexpr (std::is_same<_DST_T, utils::bf16>::value) {
return decompress_kblock_bit4_bf16<_ST>(srcptr, (utils::bf16*)dstptr, row, col, ld_src, ld_dst, scales, nullptr,
- k_offset, kblock, NPad, &dequant_f4_N<64, utils::bf16, F4_T>, pad_fp4);
+ k_offset, kblock, NPad, &dequant_f4_N<64, utils::bf16, F4_T>, fp4_pad_4bit);
}
return JblasNotSupport;
}
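The kernel change in this first file separates two 4-bit layouts: full-range signed int4 is stored biased, so the unpack path now subtracts 8 after expanding the nibbles, while fp4 nibbles are codes that dequant_f4_N decodes later and must be copied through unchanged (the new fp4_pad_4bit). A scalar sketch of the distinction, assuming the biased-nibble encoding described; it is not the SSE kernel itself:

// Scalar illustration only; the real kernel handles 16 nibbles at a time with SSE intrinsics.
#include <cstdint>

inline int8_t unpack_s4_fullrange_scalar(uint8_t nibble) {
  // full-range s4: assume the stored nibble is the signed value biased by +8,
  // so undo the bias (scalar analogue of _mm_sub_epi8(dst0, _mm_set1_epi8(8)))
  return static_cast<int8_t>((nibble & 0x0F) - 8);
}

inline int8_t pad_fp4_scalar(uint8_t nibble) {
  // fp4: keep the raw bit pattern; dequantization interprets it later
  return static_cast<int8_t>(nibble & 0x0F);
}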
@@ -11,18 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "../include/jblas_weightonly_dispatcher.hpp"
#include "../include/jblas_customop.hpp"
#include "../include/dispatcher_utils.hpp"
#include <ATen/core/TensorBody.h>
#include <c10/util/Exception.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <type_traits>
#include <vector>

#include "../include/dispatcher_utils.hpp"
#include "../include/jblas_customop.hpp"
#include "../include/jblas_weightonly_dispatcher.hpp"
#include "jblas/jit_blas.h"
#include "jblas/jit_blas_epilogue.h"
#include "jblas/jit_blas_gemm.h"
@@ -41,6 +41,7 @@

inline bool check_amx() { return jblas::utils::parallel::CpuDevice::getInstance()->AMX_BF16(); }
inline bool check_avx512_vnni() { return jblas::utils::parallel::CpuDevice::getInstance()->AVX512_VNNI(); }
+ inline bool check_avx_vnni() { return jblas::utils::parallel::CpuDevice::getInstance()->AVX_VNNI(); };
inline bool check_avx512f() { return jblas::utils::parallel::CpuDevice::getInstance()->AVX512F(); }
inline bool check_avx2() { return jblas::utils::parallel::CpuDevice::getInstance()->AVX2(); }
class env_initer {
@@ -67,11 +68,13 @@ concept normal_PrologueA = requires {

template <typename T>
concept perchannel_Gemmcore = std::is_same_v<T, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
- std::is_same_v<T, jblas::gemm::GemmCore_Row_NN_16x48_AMX_S8S8>;
+ std::is_same_v<T, jblas::gemm::GemmCore_Row_NN_16x48_AMX_S8S8> ||
+ std::is_same_v<T, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>;

template <typename T>
concept int8_cmptype_kblock_Gemmcore = std::is_same_v<T, jblas::gemm::kblock::GemmCore_Row_NN_16x48_AMX_INT8_KBLOCK> ||
- std::is_same_v<T, jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK>;
+ std::is_same_v<T, jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK> ||
+ std::is_same_v<T, jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK>;

static void* jblas_workspace = nullptr;
static int64_t workspace_size = 0;
@@ -160,7 +163,8 @@ void parse_paramC(qbits_config_param* p, qbits_runtime_ctx* ctx, ParamA param_a)
ctx->beta};
return do_compute<KERNEL, ParamA, ParamC>(p, ctx, param_a, param_c);
}
- if constexpr (std::is_same_v<typename KERNEL::GemmCore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>) {
+ if constexpr (std::is_same_v<typename KERNEL::GemmCore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
+ std::is_same_v<typename KERNEL::GemmCore, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>) {
ParamC param_c = {ctx->output->data_ptr(),
ctx->ldo,
param_a.Q->mZPtr,
@@ -197,7 +201,7 @@ void parse_paramA(qbits_config_param* p, qbits_runtime_ctx* ctx) {
"Qbits: workspace size should large than " + std::to_string(need_size) + " bytes");
return workspace;
} else {
- tmpbuf = aligned_alloc(64, need_size);
+ tmpbuf = malloc(need_size);
return tmpbuf;
}
};
@@ -241,7 +245,8 @@ void parse_store(qbits_config_param* p, qbits_runtime_ctx* ctx) {
if (p->dst_dt == QBITS_FP32) {
using namespace jblas::epilogue::gemm;
if constexpr (perchannel_Gemmcore<Gemmcore>) {
- if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>)
+ if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
+ std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>)
return execute_task<
TASK, Interface<Launcher<ISA, Gemmcore, PrologueA, PrologueB, ZpDequantInt32AlphaBetaStoreFp32>, Parallel>>(
p, ctx);
@@ -256,7 +261,8 @@
}
if (p->dst_dt == QBITS_BF16) {
if constexpr (perchannel_Gemmcore<Gemmcore>) {
- if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>)
+ if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
+ std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>)
return execute_task<
TASK, Interface<Launcher<ISA, Gemmcore, PrologueA, PrologueB, ZpDequantInt32AlphaBetaStoreBf16>, Parallel>>(
p, ctx);
@@ -280,7 +286,8 @@ void parse_activation(qbits_config_param* p, qbits_runtime_ctx* ctx) {
if constexpr (std::is_same_v<Gemmcore, jblas::gemm::kblock::GemmCore_Row_NN_16x48_AMX_INT8_KBLOCK>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationF32S8KBlockQuantize>(
p, ctx);
- if constexpr (std::is_same_v<Gemmcore, jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK>)
+ if constexpr (std::is_same_v<Gemmcore, jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK> ||
+ std::is_same_v<Gemmcore, jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationF32U8KBlockQuantize>(
p, ctx);
if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512F> ||
@@ -293,7 +300,8 @@ void parse_activation(qbits_config_param* p, qbits_runtime_ctx* ctx) {
if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_16x48_AMX_S8S8>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationFp32SymS8Quantize>(
p, ctx);
- if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>)
+ if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
+ std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationFp32AsymU8Quantize>(
p, ctx);
}
@@ -315,7 +323,8 @@ void parse_activation(qbits_config_param* p, qbits_runtime_ctx* ctx) {
if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_16x48_AMX_S8S8>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationBf16SymS8Quantize>(
p, ctx);
- if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI>)
+ if constexpr (std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI> ||
+ std::is_same_v<Gemmcore, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI>)
return parse_store<TASK, Interface, Launcher, Gemmcore, Parallel, ISA, PrologueB, ActivationBf16AsymU8Quantize>(
p, ctx);
}
@@ -365,9 +374,7 @@ template <QBITS_TASK TASK>
void parse_gemm_core_online(qbits_config_param* p, qbits_runtime_ctx* ctx) {
bool per_channel_quant = ctx->blocksize == -1 ? true : false;
if (per_channel_quant) {
TORCH_CHECK(p->compute_type == "int8",
"Qbits: compute type must be int8 when enable per_channel "
"quantization.");
TORCH_CHECK(p->compute_type == "int8", "Qbits: compute type must be int8 when enable per_channel quantization.");
if (check_amx())
return parse_weight<TASK, jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB,
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
@@ -378,6 +385,11 @@ void parse_gemm_core_online(qbits_config_param* p, qbits_runtime_ctx* ctx) {
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI, jblas::utils::parallel::Parallel2DGemm,
JblasAVX512_VNNI>(p, ctx);
+ if (check_avx_vnni())
+ return parse_weight<TASK, jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB,
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
+ jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, jblas::utils::parallel::Parallel2DGemm,
+ JblasAVX_VNNI>(p, ctx);
}
if (p->compute_type == "int8") {
if (check_amx()) {
@@ -399,6 +411,13 @@ void parse_gemm_core_online(qbits_config_param* p, qbits_runtime_ctx* ctx) {
jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK,
jblas::utils::parallel::Parallel2DGemmKBlockFixed, JblasAVX512_VNNI>(p, ctx);
}
+ if (check_avx_vnni() &&
+ ctx->blocksize % (jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK::KTILE * 2) == 0) {
+ return parse_weight<TASK, jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight,
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock,
+ jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed, JblasAVX_VNNI>(p, ctx);
+ }
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type: blocksize:", ctx->blocksize,
" ISA largger than vnni:", check_avx512_vnni());
}
@@ -440,6 +459,7 @@ void parse_gemm_core_offline(qbits_config_param* p, qbits_runtime_ctx* ctx) {
case jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK:
case jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK:
assert(p->compute_type == "int8");
+ // TODO(zhe): potential bug, quantize in vnni machine, compute on amx machine.
if (check_amx() && blocksize % (jblas::gemm::kblock::GemmCore_Row_NN_16x48_AMX_INT8_KBLOCK::KTILE * 2) == 0) {
return parse_weight<TASK, jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight,
jblas::wrapper::gemm_kblock::GemmSLauncherKBlockPackWeight,
@@ -453,7 +473,18 @@
jblas::utils::parallel::Parallel2DGemmKBlockFixed, JblasAVX512_VNNI>(p, ctx);
}
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type: blocksize:", blocksize,
" ISA largger than vnni:", check_avx512_vnni());
" ISA largger than avx512-vnni:", check_avx512_vnni());
break;
+ case jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK:
+ assert(p->compute_type == "int8");
+ if (check_avx_vnni() &&
+ ctx->blocksize % (jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK::KTILE * 2) == 0)
+ return parse_weight<TASK, jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight,
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock,
+ jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed, JblasAVX_VNNI>(p, ctx);
+ TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type: blocksize:", blocksize,
+ " ISA largger than avx-vnni:", check_avx_vnni());
+ break;
case jblas::gemm::GemmCoreType::AVX512F_8x48:
assert(p->compute_type == "fp32");
@@ -491,19 +522,24 @@ void parse_gemm_core_offline(qbits_config_param* p, qbits_runtime_ctx* ctx) {
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI, jblas::utils::parallel::Parallel2DGemm,
JblasAVX512_VNNI>(p, ctx);
- TORCH_CHECK(false,
- "Qbits: device ISA must lagger than AVX512_VNNI when "
- "GemmCore==Row_NN_8x48_AVX512_VNNI");
+ TORCH_CHECK(false, "Qbits: device ISA must lagger than AVX512_VNNI when GemmCore==Row_NN_8x48_AVX512_VNNI");
break;
+ case jblas::gemm::GemmCoreType::AVX_VNNI_2x48:
+ assert(p->compute_type == "int8");
+ if (check_avx_vnni())
+ return parse_weight<TASK, jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB,
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
+ jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, jblas::utils::parallel::Parallel2DGemm,
+ JblasAVX_VNNI>(p, ctx);
+ TORCH_CHECK(false, "Qbits: device ISA must lagger than AVX_VNNI when GemmCore==Row_NN_2x48_AVX_VNNI");
+ break;
case jblas::gemm::GemmCoreType::AMX_INT8_16x48:
if (check_amx())
return parse_weight<TASK, jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB,
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight,
jblas::gemm::GemmCore_Row_NN_16x48_AMX_S8S8, jblas::utils::parallel::Parallel2DGemm,
JblasAMX_INT8>(p, ctx);
- TORCH_CHECK(false,
- "Qbits: device ISA must lagger than AMX_INT8 when "
- "GemmCore==Row_NN_16x48_AMX_S8S8");
+ TORCH_CHECK(false, "Qbits: device ISA must support AMX_INT8 when GemmCore==Row_NN_16x48_AMX_S8S8");
default:
break;
}
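A recurring guard in both the online and offline int8 k-block paths above: the AVX_VNNI k-block core is only chosen when the quantization block size is a multiple of twice the core's KTILE. A small sketch of that gate, with kKTile as an assumed placeholder rather than the real GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK::KTILE constant:

// Sketch of the blocksize gate applied before selecting the k-block AVX_VNNI core.
// kKTile is an illustrative placeholder; the dispatcher reads KTILE from the GemmCore type.
constexpr int kKTile = 4;  // assumed value for illustration only

inline bool avx_vnni_kblock_ok(bool has_avx_vnni, int blocksize) {
  // blocksize == -1 signals per-channel quantization and takes a different path
  return has_avx_vnni && blocksize > 0 && blocksize % (kKTile * 2) == 0;
}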
