QBits adapt to the latest BesTLA (#1535)
Co-authored-by: changwangss <chang1.wang@intel.com>
Co-authored-by: VincyZhang <wenxin.zhang@intel.com>
3 people authored May 13, 2024
1 parent 53fd140 commit c169bec
Showing 17 changed files with 135 additions and 120 deletions.
8 changes: 8 additions & 0 deletions docs/qbits.md
@@ -65,3 +65,11 @@ qbits.woq_linear(
activation, pack_weight, bias, output, n, add_bias, compute_type, weight_type, scale_type, asym)
```
Please refer [here](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits_ut) for more examples of QBits operator usage.

## PyTorch version constraint
To use QBits, the PyTorch version must meet the ITREX requirements; the constraints are listed below:

| ITREX version | PyTorch version |
| :-----------: | :-------------: |
| v1.4 | 2.2.0+cpu |
| v1.4.1 | 2.2.0+cpu |
3 changes: 3 additions & 0 deletions intel_extension_for_transformers/qbits/CMakeLists.txt
@@ -37,12 +37,15 @@ find_package(PythonLibs 3 REQUIRED)
endif()

include(FindOpenMP)
set(BTLA_ENABLE_OPENMP ON CACHE BOOL "BesTLA enable compiling OpenMP threading")
add_subdirectory(dispatcher)
add_subdirectory(../transformers/runtime/third_party/pybind11 pybind11)

file(GLOB HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
file(GLOB qbits_src ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)

add_compile_options(-flto=auto)

# Link against LibTorch
pybind11_add_module(qbits_py ${qbits_src})
target_compile_features(qbits_py PRIVATE cxx_std_14)
@@ -35,5 +35,5 @@ endif()

set_target_properties(bestla_dispatcher PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(bestla_dispatcher PROPERTIES LINKER_LANGUAGE CXX)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla::bestla)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla)
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
@@ -20,17 +20,17 @@

template <typename Param, typename DST_T, BTLA_ISA ISA_T>
inline BTLA_CODE alphabeta_dt_cvt_process(float* tmp_dst, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param) {
const int M, const int N, const Param& _param) {
auto DOffset = M_offset * _param.ldd + N_offset;
auto dptr = reinterpret_cast<float*>(_param.D) + DOffset;
bestla::kernel::wrapper::AlphaBetaF32F32::template forward<ISA_T>(_param.alpha, tmp_dst, cachestep, _param.beta, dptr,
_param.ldd, tmp_dst, cachestep, M, N);
_param.ldd, tmp_dst, cachestep, M, N);

auto COffset = M_offset * _param.ldc + N_offset;
auto cptr = reinterpret_cast<DST_T*>(_param.C) + COffset;
if constexpr (std::is_same_v<DST_T, float>) {
return bestla::kernel::wrapper::Memcpy2D::template forward<ISA_T, float, DST_T>(tmp_dst, cptr, M, N, cachestep,
_param.ldc, NULL);
_param.ldc, NULL);
}
if constexpr (std::is_same_v<DST_T, bestla::utils::bf16>) {
return bestla::kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
@@ -47,8 +47,8 @@ class AlphaBetaProcess {
int ldc, ldd;
float alpha, beta;
};
BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
static BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
return alphabeta_dt_cvt_process<Param, DST_T, ISA_T>(cacheptr, cachestep, M_offset, N_offset, M, N, _param);
}
};
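
A note on the epilogue hunks above: alphabeta_dt_cvt_process applies the alpha/beta update against the D buffer and then casts the tile into the destination dtype, and `forward` becomes `static` so the launcher can call it without an AlphaBetaProcess instance. Below is a scalar sketch of what the epilogue computes, assuming AlphaBetaF32F32 performs the usual `alpha * tmp + beta * D` update (the real BesTLA wrappers do this tilewise and vectorized):

```cpp
// Scalar reference for the epilogue above. Assumption: AlphaBetaF32F32 computes
// tmp = alpha * tmp + beta * D, after which the result is copied/cast into C
// (float or bf16 in the diff). Strides mirror the diff: cachestep for the tile
// buffer, ldd for D, ldc for C.
template <typename DST_T>
void alphabeta_ref(float* tmp, const float* D, DST_T* C, int M, int N,
                   int cachestep, int ldd, int ldc, float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float v = alpha * tmp[i * cachestep + j] + beta * D[i * ldd + j];
      C[i * ldc + j] = static_cast<DST_T>(v);  // Memcpy2D / Fp32CvtBf16 in the real code
    }
  }
}
```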
@@ -62,7 +62,8 @@ struct woq_runtime_ctx {

static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{{"int8", BTLA_DTYPE::S8},
{"int4_clip", BTLA_DTYPE::S4_CLIP},
{"int4_fullrange", BTLA_DTYPE::S4_FULLRANGE},
{"int3_clip", BTLA_DTYPE::S3_CLIP},
{"int2_clip", BTLA_DTYPE::S2_CLIP},
{"nf4", BTLA_DTYPE::F4_NF4},
{"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
{"fp4_e2m1", BTLA_DTYPE::F4_E2M1},
@@ -26,10 +26,26 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->
inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }

class qbits_threading {
public:
static bestla::parallel::IThreading* get() {
GetCPUDevice();
static bestla::parallel::StdThreading OptmizedThreading;
static bestla::parallel::OMPThreading DefaultThreading;
if (!_cd->isHybrid()) {
return &DefaultThreading;
}
return &OptmizedThreading;
}

static void set_threads(int n_thread) { get()->set_threads(n_thread); }
};

class env_initer {
public:
env_initer() {
if (check_amx()) bestla::utils::request_perm_xtile_data();
qbits_threading::set_threads(bestla::device::CpuDevice::getInstance()->getThreads());
verbose = std::getenv("QBITS_VERBOSE") != nullptr;
FLAGS_caffe2_log_level = 0;
}
@@ -56,7 +72,7 @@ class Timer {
high_resolution_clock::time_point m_end;
};
static Timer timer;
static bestla::parallel::OMPThreading DefaultThreading(bestla::device::CpuDevice::getInstance()->getThreads());

string get_torch_dt_name(torch::Tensor* tensor);

} // namespace dispatcher_utils
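
The new qbits_threading helper replaces the file-scope DefaultThreading object removed in the second hunk: both pools are constructed lazily as function-local statics, and hybrid CPUs (mixed P/E cores, per `isHybrid()`) get the std::thread-based pool, where OpenMP's uniform work splitting is usually a poor fit. Here is a minimal sketch of the selection pattern, using hypothetical stand-in types rather than the bestla::parallel classes:

```cpp
// Sketch of the qbits_threading pattern: function-local statics give lazy,
// thread-safe construction, and hybrid CPUs get the std::thread pool instead
// of OpenMP. IThreadPool/OmpPool/StdPool are illustrative stand-ins, not the
// bestla::parallel types.
struct IThreadPool {
  virtual void set_threads(int n) = 0;
  virtual ~IThreadPool() = default;
};
struct OmpPool : IThreadPool { void set_threads(int) override { /* omp_set_num_threads */ } };
struct StdPool : IThreadPool { void set_threads(int) override { /* resize worker pool */ } };

IThreadPool* get_pool(bool is_hybrid) {
  static OmpPool omp_pool;  // default path, mirrors OMPThreading
  static StdPool std_pool;  // hybrid path, mirrors StdThreading
  return is_hybrid ? static_cast<IThreadPool*>(&std_pool) : &omp_pool;
}
```

Note in the diff that env_initer now seeds the pool size once at module load via qbits_threading::set_threads.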
@@ -1,5 +1,5 @@
set(NEURAL_SPEED_URL https://github.com/intel/neural-speed.git)
set(NEURAL_SPEED_TAG bestlav0.1)
set(NEURAL_SPEED_TAG 2f7943681e02c6e87a4c70c3925327f00194c78f)

FetchContent_Declare(
neural_speed
@@ -40,17 +40,17 @@ void do_gemm(bestla_gemm_runtime_ctx* ctx) {
packw.assign(tmpbuf);
if (ctx->matB_trans) {
launcher.mProB.packWeightTranspose(ctx->n, ctx->k, {reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->k, &packw},
&dispatcher_utils::DefaultThreading);
dispatcher_utils::qbits_threading::get());
} else {
launcher.mProB.packWeight(ctx->n, ctx->k, {reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->n, &packw},
&dispatcher_utils::DefaultThreading);
dispatcher_utils::qbits_threading::get());
}
bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k);
typename Launcher::Param args{gp,
{reinterpret_cast<DT*>(ctx->matA->data_ptr()), ctx->k},
{reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->n, &packw},
{reinterpret_cast<DT*>(ctx->matC->data_ptr()), ctx->n}};
bestla::parallel::GemmRun<Parallel>(launcher, args, &dispatcher_utils::DefaultThreading);
bestla::parallel::GemmRun<Parallel>(launcher, args, dispatcher_utils::qbits_threading::get());
bestla::utils::afree(tmpbuf);
}

@@ -28,9 +28,9 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
*(ctx->output) = torch::empty(qpackw.mSize, torch::kInt8);
qpackw.assign(ctx->output->data_ptr<int8_t>());
if (p->enable_act_shuffle)
ker.setShuffleIndices(ctx->g_idx->data_ptr<int>(), &qpackw, &dispatcher_utils::DefaultThreading);
ker.setShuffleIndices(ctx->g_idx->data_ptr<int>(), &qpackw, dispatcher_utils::qbits_threading::get());
ker.packQWeight(ctx->n, ctx->k, ctx->qweight->data_ptr<int8_t>(), ctx->n, ctx->scale->data_ptr<float>(),
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
@@ -41,8 +41,10 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
return "bf16";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S4_FULLRANGE:
return "int4_fullrange";
case BTLA_DTYPE::S3_CLIP:
return "int3_clip";
case BTLA_DTYPE::S2_CLIP:
return "int2_clip";
case BTLA_DTYPE::F4_NF4:
return "nf4";
case BTLA_DTYPE::F4_E2M1:
@@ -66,7 +68,6 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
using bestla::gemm::CompType;
switch (cmpt) {
case CompType::COMP_INT8_US_INT32:
case CompType::COMP_INT8_US_FP32:
return "int8";
case CompType::COMP_FP32:
@@ -182,43 +183,34 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
}

void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
// TODO(zhe): elegant impl.
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer WOQ in PACKQ");

// NTILE & compute-dtype determine the padsize.
// in qbits:
// avx_vnni/avx512f_vnni/amx-int8 NTILE==48, compute-dtype=int8;
// avx2/avx512f NTILE==48, compute-dtype=fp32;
// amx-bf16 NTILE==64, compute-dtype=bf16.
if (task == WOQ_GET_PACKW_SIZE) {
if (p->compute_type == "int8")
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
if (p->compute_type == "fp32")
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
if (p->compute_type == "bf16")
return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
}

if (p->compute_type == "int8") {
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
}
if (dispatcher_utils::check_avx512_vnni() &&
p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
", ISA support vnni:", dispatcher_utils::check_avx_vnni());
", ISA support avx2:", dispatcher_utils::check_avx2());
}
if (p->compute_type == "fp32") {
if (dispatcher_utils::check_avx512f()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
}
if (dispatcher_utils::check_avx2()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx, task);
return execute_qpack<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
}
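
The rewritten bestla_packq dispatch follows one rule per compute type: take the first GEMM core whose ISA is available and whose quantization blocksize is a multiple of that core's KTILE. The commit also retunes tile shapes (AMX-INT8 from `<48, 16>` to `<64, 16>`, AVX-VNNI from `<48, 2>` to `<24, 2>`, a new AVX2-VNNI path, AVX2 fp32 from `<48, 2>` to `<24, 4>`). A condensed sketch of the eligibility walk, with an illustrative descriptor table instead of the real BesTLA core list:

```cpp
#include <cstddef>

// Illustrative kernel descriptor; isa_ok would come from the check_* helpers
// above and ktile from the core template (e.g. ICoreRowNAmxint8KBlock<64, 16>::KTILE).
struct KernelDesc {
  const char* name;
  bool isa_ok;
  int ktile;
};

// Mirrors the if-chain in bestla_packq: the first core whose ISA is present and
// whose KTILE divides the blocksize wins; nullptr maps to the TORCH_CHECK failure.
const KernelDesc* pick_kernel(const KernelDesc* table, std::size_t n, int blocksize) {
  for (std::size_t i = 0; i < n; ++i) {
    if (table[i].isa_ok && blocksize % table[i].ktile == 0) return &table[i];
  }
  return nullptr;
}
```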
