Qbits woq ref impl for debug (#1248)
zhewang1-intc authored Feb 4, 2024
1 parent 71f5e84 commit 18d36ef
Showing 7 changed files with 242 additions and 32 deletions.
@@ -0,0 +1,32 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "bestla_weightonly_dispatcher.hpp"
namespace woq {

enum PACKW_ACQUIRE_TYPE {
SIZE = 0,
BLOCKSIZE,
K,
N,
ACT_SHUFFLE,
G_IDX,
WEI_TYPE,
CMPT_TYPE,
SCALE_TYPE,
};

void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T);
} // namespace woq
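A short illustration (not part of this commit) of how a Python caller is expected to use these acquire types through the op registered later in this commit as torch.ops.bestlaop.acquire_woq_packw_info. It assumes the qbits extension is already loaded and that `packw` is an existing packed-weight tensor; the helper name is illustrative:

import torch

def query_packw_info(packw: torch.Tensor):
    # Scalar fields (SIZE=0, BLOCKSIZE=1, K=2, N=3, ACT_SHUFFLE=4) are returned as 1-element int64 tensors.
    n = torch.ops.bestlaop.acquire_woq_packw_info(packw, 3)[0].item()          # N
    blocksize = torch.ops.bestlaop.acquire_woq_packw_info(packw, 1)[0].item()  # BLOCKSIZE
    # String fields (WEI_TYPE=6, CMPT_TYPE=7, SCALE_TYPE=8) are returned as int32 tensors of ASCII codes.
    wei_type = ''.join(chr(c) for c in torch.ops.bestlaop.acquire_woq_packw_info(packw, 6).tolist())
    return n, blocksize, wei_type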
@@ -69,6 +69,5 @@ static std::map<std::string, BTLA_DTYPE> scale2bestladt_map{
{"fp32", BTLA_DTYPE::F32}, {"bf16", BTLA_DTYPE::BF16}, {"fp8_e8m0", BTLA_DTYPE::F8_E8M0}};

void dispatch_woq_task(woq_config_param* p, woq_runtime_ctx* ctx, WOQ_TASK task);
-void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
void set_woq_workspace(torch::Tensor* workspace);
} // namespace woq
@@ -1,5 +1,5 @@
#include "bestla/bestla_prologue_b.h"
#include "../include/bestla_weightonly_dispatcher.hpp"
#include "../include/bestla_packq_impl.hpp"

namespace woq {
template <class GemmCore, BTLA_ISA ISA>
@@ -17,6 +17,101 @@ void execute_qpack(woq_packq_param* p, woq_packq_ctx* ctx) {
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
switch (dtype) {
case BTLA_DTYPE::F32:
return "fp32";
case BTLA_DTYPE::BF16:
return "bf16";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S4_FULLRANGE:
return "int4_fullrange";
case BTLA_DTYPE::F4_NF4:
return "nf4";
case BTLA_DTYPE::F4_E2M1:
return "fp4_e2m1";
case BTLA_DTYPE::F4_BNB:
return "fp4_e2m1_bnb";
case BTLA_DTYPE::S8:
return "int8";
case BTLA_DTYPE::F8_E5M2:
return "fp8_e5m2";
case BTLA_DTYPE::F8_E4M3:
return "fp8_e4m3";
case BTLA_DTYPE::F8_E8M0:
return "fp8_e8m0";
default:
TORCH_CHECK(false, "QBits: unrecognized data type.")
break;
}
}

std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
using bestla::gemm::CompType;
switch (cmpt) {
case CompType::COMP_INT8_US_INT32:
case CompType::COMP_INT8_US_FP32:
return "int8";
case CompType::COMP_FP32:
return "fp32";
case CompType::COMP_BF16_FP32:
return "bf16";
default:
TORCH_CHECK(false, "QBits: unrecognized compute type.");
break;
}
}

std::vector<int> get_ascii_vec(std::string str) {
std::vector<int32_t> ret;
for (char c : str) ret.push_back(static_cast<int32_t>(c));
return ret;
}

torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
torch::Tensor output;
auto packw_ptr = dynamic_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(
bestla::storage::gemm::PackedWeightParser::deserialBuffer(packw.data_ptr()));
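  // Scalar attributes (SIZE, BLOCKSIZE, K, N, ACT_SHUFFLE) are reported through this 1-element int64 tensor.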
output = torch::empty(1, torch::kInt64);
switch (ACQ_T) {
case SIZE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mSize));
case BLOCKSIZE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mBlockSize));
case K:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mK));
case N:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mN));
case ACT_SHUFFLE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->ShfIndice() != nullptr ? 1 : 0));
case G_IDX: {
auto tensor_size = packw_ptr->mShuffleIndices.size<int>();
      TORCH_CHECK(packw_ptr->ShfIndice() != nullptr, "QBits: g_idx tensor was not packed.");
output = torch::empty(tensor_size, torch::kInt32);
memcpy(output.data_ptr(), packw_ptr->ShfIndice(), tensor_size * sizeof(int));
} break;
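    // String-valued attributes are returned as int32 tensors of ASCII codes; the Python side rebuilds the string with chr().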
case WEI_TYPE:
case SCALE_TYPE: {
BTLA_DTYPE acquire_dt = ACQ_T == WEI_TYPE ? packw_ptr->mDType : packw_ptr->SDtype();
auto ascii_vec = get_ascii_vec(get_dtype_str(acquire_dt));
output = torch::empty(ascii_vec.size(), torch::kInt32);
memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
} break;
case CMPT_TYPE: {
auto CType = bestla::gemm::CoreAttr::get_mask_val(packw_ptr->mCoreId, bestla::gemm::CoreAttr::COMP_MASK,
bestla::gemm::CoreAttr::COMP_SHIFT);
auto ascii_vec = get_ascii_vec(get_cmpt_str(static_cast<bestla::gemm::CompType>(CType)));
output = torch::empty(ascii_vec.size(), torch::kInt32);
memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
} break;
default:
TORCH_CHECK(false, "QBits: unsupported acquire_type");
break;
}
return output;
}

void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx) {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
"Qbits: only support Integer WOQ in PACKQ");
@@ -49,7 +49,6 @@ void woq_dequantize(woq_config_param* p, woq_runtime_ctx* ctx) {
using PrologueB = typename Launcher::PrologueB;
using WType = typename Launcher::PrologueB::StorageWeight;
static PrologueB kernel;
-  // TODO(zhe): using unified StorageWeightKBlockNInteger after sync with neural-speed(with NFloat ProB feature).
if (ctx->transpose) {
kernel.unpackTransposeWeight(ctx->deseries_wei->mN, ctx->deseries_wei->mK,
dynamic_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(ctx->deseries_wei),
12 changes: 9 additions & 3 deletions intel_extension_for_transformers/llm/operator/csrc/qbits.cpp
@@ -14,6 +14,7 @@
#include "dispatcher/include/dispatcher_utils.hpp"
#include "dispatcher/include/bestla_gemm_dispatcher.hpp"
#include "dispatcher/include/bestla_weightonly_dispatcher.hpp"
#include "dispatcher/include/bestla_packq_impl.hpp"
#include "include/dropout.hpp"
#include <ATen/core/TensorBody.h>
#include <c10/core/ScalarType.h>
@@ -45,8 +46,8 @@ static void inline init_woq_config_param(woq::woq_config_param* p, woq::woq_runt
case woq::WOQ_QUANTIZE:
case woq::WOQ_DEQUANTIZE:
p->src_dt = dispatcher_utils::QBITS_FP32;
-      p->dst_dt = dispatcher_utils::QBITS_FP32; // bestla doesn't care about dst_dt in quantize/dequant task,so set fp32
-      // as default.
      p->dst_dt = dispatcher_utils::QBITS_FP32; // bestla doesn't care about dst_dt in quantize/dequant task, so set
      // fp32 as default.
break;
case woq::WOQ_LINEAR:
p->src_dt = get_qbits_dt(ctx->activation);
@@ -122,7 +123,7 @@ static void set_woq_workspace(const torch::Tensor& workspace) {
}

static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB, const torch::Tensor& matC,
                                   bool matB_trans) {
TORCH_CHECK(matA.dim() == 2 && matB.dim() == 2 && matC.dim() == 2,
"Qbits: only support 2-dim input-tensor in bestla gemm op.");
bestla_gemm::bestla_gemm_runtime_ctx ctx;
@@ -138,6 +139,10 @@ static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB,
return bestla_gemm::dispatch_bestla_gemm(&ctx);
}

static torch::Tensor acquire_woq_packw_info(torch::Tensor& packw, int64_t acquire_type) {
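  // Map the int64 coming from Python onto woq::PACKW_ACQUIRE_TYPE and forward the query to the dispatcher.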
return woq::get_packw_info(packw, static_cast<woq::PACKW_ACQUIRE_TYPE>(acquire_type));
}

static torch::Tensor qbits_dropout_fwd(torch::Tensor& output, double p) { return dropout_fwd(output, p); }

static void qbits_dropout_bwd(torch::Tensor& grad, torch::Tensor& scale) { dropout_bwd(grad, scale); }
@@ -149,6 +154,7 @@ TORCH_LIBRARY(bestlaop, m) {
m.def("woq_packq", &woq_packq);
m.def("set_woq_workspace", &set_woq_workspace);
m.def("matmul", &bestlaop_gemm);
m.def("acquire_woq_packw_info", &acquire_woq_packw_info);
}

TORCH_LIBRARY(qbits_customop, m) {
@@ -16,6 +16,7 @@
# limitations under the License.

from ut_utils import *
from enum import Enum


def convert_idx(g_idx, k, blocksize):
@@ -27,6 +28,18 @@ def convert_idx(g_idx, k, blocksize):
return ret_idx


class acquire_type(Enum):
SIZE = 0
BLOCKSIZE = 1
K = 2
N = 3
ACT_SHUFFLE = 4
G_IDX = 5
WEI_TYPE = 6
CMPT_TYPE = 7
SCALE_TYPE = 8


@pytest.mark.parametrize("m", [256])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [512])
@@ -66,3 +79,19 @@ def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_t
assert (abs(ref_dst - tar_dst).max() < 8)
else:
assert (abs(ref_dst - tar_dst).max() < 10)
packw_size = torch.ops.bestlaop.acquire_woq_packw_info(
packw, acquire_type.SIZE.value)[0].item()
    assert packw_size == packw.size()[0]
packw_wei_type = torch.ops.bestlaop.acquire_woq_packw_info(
packw, acquire_type.WEI_TYPE.value)
packw_wei_type_str = ''.join(chr(ascii_code)
for ascii_code in packw_wei_type.tolist())
    assert packw_wei_type_str == weight_type
enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
packw, acquire_type.ACT_SHUFFLE.value)[0] != 0
assert (enable_act_shuffle)
    acquire_g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, acquire_type.G_IDX.value)
    assert (abs(acquire_g_idx - cvt_idx).max() == 0)
@@ -16,11 +16,49 @@
# limitations under the License.


import os
import operator
import torch
from functools import reduce
from torch import Tensor
from typing import Tuple, Optional, List
from enum import Enum


class qbits_acquire_type(Enum):
SIZE = 0
BLOCKSIZE = 1
K = 2
N = 3
ACT_SHUFFLE = 4
G_IDX = 5
WEI_TYPE = 6
CMPT_TYPE = 7
SCALE_TYPE = 8


def qbits_woq_linear_ref_impl(activation, packw, bias, compute_type, weight_type, scale_type):
assert (activation.is_contiguous())
assert (packw.is_contiguous())
activation = activation.to(torch.float32)
n = torch.ops.bestlaop.acquire_woq_packw_info(
packw, qbits_acquire_type.N.value)[0].item()
k = activation.shape[1]
revert_wei = torch.empty(k, n, dtype=torch.float)
torch.ops.bestlaop.woq_dequantize(
packw, revert_wei, False, compute_type, weight_type, scale_type)
enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
packw, qbits_acquire_type.ACT_SHUFFLE.value)[0] != 0
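    # When the weight was packed with activation shuffling, apply the same column permutation (g_idx) to the activation before the dense matmul.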
if enable_act_shuffle:
g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
packw, qbits_acquire_type.G_IDX.value)
activation = torch.index_select(activation, 1, g_idx)
out = torch.matmul(activation, revert_wei)
if bias is not None:
assert (bias.is_contiguous())
assert (bias.dtype == torch.float32)
out += bias
return out


def prod(iterable):
@@ -64,18 +102,23 @@ def forward(

# 2. Matmul
# output = torch.nn.functional.linear(A, B_dequant, bias)
-        torch.ops.bestlaop.woq_linear(
-            A,
-            B.data,
-            bias,
-            out,
-            out.shape[-1],
-            bias is not None,
-            compute_dtype,
-            weight_dtype,
-            scale_dtype,
-            False,
-        )
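        # Debug hook: if the QBITS_DEBUG environment variable is set (to anything other than the literal 'NULL'), route the matmul through the Python reference implementation instead of the fused bestla kernel.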
qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
if qbits_debug_flag == 'NULL':
torch.ops.bestlaop.woq_linear(
A,
B.data,
bias,
out,
out.shape[-1],
bias is not None,
compute_dtype,
weight_dtype,
scale_dtype,
False,
)
else:
out = qbits_woq_linear_ref_impl(
A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)
output = out

# 3. Save state
@@ -101,7 +144,8 @@
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
bias_grad = None if ctx.bias is None else torch.zeros_like(
ctx.bias)
return (
torch.zeros_like(ctx.A),
torch.zeros_like(ctx.B),
@@ -114,7 +158,8 @@ def backward(ctx, grad_output):
A, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None

-        B_dequant = torch.zeros(grad_output.shape[-1], A.shape[-1], dtype=torch.float)
B_dequant = torch.zeros(
grad_output.shape[-1], A.shape[-1], dtype=torch.float)

torch.ops.bestlaop.woq_dequantize(
B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype
@@ -149,17 +194,22 @@ def matmul_kbit(
A, B, out, bias, compute_dtype, weight_dtype, scale_dtype
)
else:
-        torch.ops.bestlaop.woq_linear(
-            A,
-            B.data,
-            bias,
-            out,
-            out.shape[-1],
-            bias is not None,
-            compute_dtype,
-            weight_dtype,
-            scale_dtype,
-            False,
-        )
qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
if qbits_debug_flag == 'NULL':
torch.ops.bestlaop.woq_linear(
A,
B.data,
bias,
out,
out.shape[-1],
bias is not None,
compute_dtype,
weight_dtype,
scale_dtype,
False,
)
else:
out = qbits_woq_linear_ref_impl(
A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)

return out
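
A minimal usage sketch of the new debug path (not part of this commit). It assumes the qbits extension is loaded so torch.ops.bestlaop is registered, that the snippet runs inside (or imports from) the module above, and that `packw` is a packed-weight tensor previously produced by torch.ops.bestlaop.woq_packq; the dtype strings must match how `packw` was packed and are only illustrative here:

import os
import torch

os.environ["QBITS_DEBUG"] = "1"  # any value other than 'NULL' makes matmul_kbit take the reference path

# Match the activation's inner dimension to the packed weight's K (values per qbits_acquire_type above).
k = torch.ops.bestlaop.acquire_woq_packw_info(packw, qbits_acquire_type.K.value)[0].item()
activation = torch.rand(2, k, dtype=torch.float)

# Call the pure-PyTorch reference implementation directly.
ref_out = qbits_woq_linear_ref_impl(activation, packw, None, "fp32", "int4_clip", "fp32")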
