Skip to content

Commit

Permalink
use 4x2 tiles for blklen > 32
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Jul 13, 2024
1 parent 277df8d commit a2f5c6b
Showing 1 changed file with 28 additions and 163 deletions.
191 changes: 28 additions & 163 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,140 +469,6 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16(
}
}

//
// Computes one 2x2 tile of the output: two consecutive rows of quantized A
// (QuantARowPtr and QuantARowPtr + StrideQuantA) against two consecutive
// columns of 4-bit quantized B (the column pointers and the same pointers
// offset by their respective strides).
//
// For every block along K, the int8 dot product of A and (B - zero point) is
// multiplied by the product of the A block scale and the B block scale and
// accumulated in float. The four results are written to
//   SumPtr[0], SumPtr[1], SumPtr[ldc + 0], SumPtr[ldc + 1]
// and, when BiasPtr is not null, BiasPtr[0] / BiasPtr[1] is added to the
// first / second column of the tile.
//
// Template parameter HasZeroPoint:
//   true  - B zero points are read from QuantBZeroPointColPtr, two 4-bit
//           values per byte (even block index -> low nibble, odd -> high).
//   false - a zero point of 8 is used for every block and the zero point
//           pointer is never dereferenced.
//
// NOTE(review): BlkLen / 32 truncates, so any tail of a block that is not a
// multiple of 32 elements is not processed - presumably callers only pass
// BlkLen values that are multiples of 32; confirm against the dispatch code.
//
template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16(
    size_t BlkLen,
    const std::byte* QuantARowPtr,
    const std::byte* QuantBDataColPtr,
    const float* QuantBScaleColPtr,
    const std::byte* QuantBZeroPointColPtr,
    const float* BiasPtr,
    float* SumPtr,
    size_t BlockCountK,
    size_t StrideQuantA,
    size_t StrideQuantBData,
    size_t StrideQuantBScale,
    size_t StrideQuantBZeroPoint,
    size_t ldc
)
{
    // A block is consumed in 32-element sub-blocks: 32 int8 values of A and
    // 16 bytes of B (two 4-bit values per byte).
    constexpr size_t SubBlkLen = 32;
    constexpr size_t SubBlkPackedBBytes = SubBlkLen / 2;
    const size_t SubBlkCount = BlkLen / SubBlkLen;

    // Cursors that walk along K.
    const std::byte* a_blk = QuantARowPtr;
    const std::byte* b_data = QuantBDataColPtr;
    const float* b_scale = QuantBScaleColPtr;
    const std::byte* b_zp = QuantBZeroPointColPtr;

    // Constants shared by every sub-block.
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);
    const int32x4_t zero_s32{};

    // One float accumulator per output element, named acc<row><col>.
    float32x4_t acc00 = vdupq_n_f32(0.0f);
    float32x4_t acc01 = vdupq_n_f32(0.0f);
    float32x4_t acc10 = vdupq_n_f32(0.0f);
    float32x4_t acc11 = vdupq_n_f32(0.0f);

    for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) {
        const std::byte* a_blk_row0 = a_blk;
        const std::byte* a_blk_row1 = a_blk + StrideQuantA;

        // Combined (A scale * B scale) factor for each of the four outputs,
        // broadcast once per block.
        const float b_scale_col0 = b_scale[0];
        const float b_scale_col1 = b_scale[StrideQuantBScale];
        const float a_scale_row0 = Q8BlkScale(a_blk_row0);
        const float a_scale_row1 = Q8BlkScale(a_blk_row1);

        const float32x4_t scale00 = vdupq_n_f32(a_scale_row0 * b_scale_col0);
        const float32x4_t scale01 = vdupq_n_f32(a_scale_row0 * b_scale_col1);
        const float32x4_t scale10 = vdupq_n_f32(a_scale_row1 * b_scale_col0);
        const float32x4_t scale11 = vdupq_n_f32(a_scale_row1 * b_scale_col1);

        // B zero point for this block; 8 unless explicit zero points are given.
        int8_t zp_col0 = 8;
        int8_t zp_col1 = 8;
        if constexpr (HasZeroPoint) {
            const bool use_high_nibble = (k_blk_idx & 1) != 0;
            const std::byte zp_byte_col0 = b_zp[0];
            const std::byte zp_byte_col1 = b_zp[StrideQuantBZeroPoint];

            const std::byte zp_nibble_col0 =
                use_high_nibble ? (zp_byte_col0 >> 4) : (zp_byte_col0 & std::byte{0x0F});
            const std::byte zp_nibble_col1 =
                use_high_nibble ? (zp_byte_col1 >> 4) : (zp_byte_col1 & std::byte{0x0F});

            zp_col0 = std::to_integer<int8_t>(zp_nibble_col0);
            zp_col1 = std::to_integer<int8_t>(zp_nibble_col1);
        }
        const int8x16_t zp_vec_col0 = vdupq_n_s8(zp_col0);
        const int8x16_t zp_vec_col1 = vdupq_n_s8(zp_col1);

        const int8_t* a_data_row0 = Q8BlkData(a_blk_row0);
        const int8_t* a_data_row1 = Q8BlkData(a_blk_row1);

        for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlkCount; ++sub_blk_idx) {
            const size_t a_offset = sub_blk_idx * SubBlkLen;

            // A: 32 int8 values per row, as two 16-lane vectors.
            const int8x16_t a_row0_lo = vld1q_s8(a_data_row0 + a_offset);
            const int8x16_t a_row0_hi = vld1q_s8(a_data_row0 + a_offset + 16);
            const int8x16_t a_row1_lo = vld1q_s8(a_data_row1 + a_offset);
            const int8x16_t a_row1_hi = vld1q_s8(a_data_row1 + a_offset + 16);

            // B: 16 packed bytes per column.
            const uint8_t* b_packed_ptr_col0 = reinterpret_cast<const uint8_t*>(b_data);
            const uint8_t* b_packed_ptr_col1 = b_packed_ptr_col0 + StrideQuantBData;
            const uint8x16_t b_packed_col0 = vld1q_u8(b_packed_ptr_col0);
            const uint8x16_t b_packed_col1 = vld1q_u8(b_packed_ptr_col1);

            // Unpack: the low nibbles pair with the first 16 A values, the
            // high nibbles with the next 16. Then remove the zero point.
            const int8x16_t b_col0_lo =
                vsubq_s8(vreinterpretq_s8_u8(vandq_u8(b_packed_col0, nibble_mask)), zp_vec_col0);
            const int8x16_t b_col0_hi =
                vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(b_packed_col0, 4)), zp_vec_col0);
            const int8x16_t b_col1_lo =
                vsubq_s8(vreinterpretq_s8_u8(vandq_u8(b_packed_col1, nibble_mask)), zp_vec_col1);
            const int8x16_t b_col1_hi =
                vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(b_packed_col1, 4)), zp_vec_col1);

            // Integer dot products for this sub-block, starting from zero.
            const int32x4_t dot00 =
                vdotq_s32(vdotq_s32(zero_s32, a_row0_lo, b_col0_lo), a_row0_hi, b_col0_hi);
            const int32x4_t dot01 =
                vdotq_s32(vdotq_s32(zero_s32, a_row0_lo, b_col1_lo), a_row0_hi, b_col1_hi);
            const int32x4_t dot10 =
                vdotq_s32(vdotq_s32(zero_s32, a_row1_lo, b_col0_lo), a_row1_hi, b_col0_hi);
            const int32x4_t dot11 =
                vdotq_s32(vdotq_s32(zero_s32, a_row1_lo, b_col1_lo), a_row1_hi, b_col1_hi);

            // Convert to float, apply the combined scale, and accumulate.
            acc00 = vfmaq_f32(acc00, vcvtq_f32_s32(dot00), scale00);
            acc01 = vfmaq_f32(acc01, vcvtq_f32_s32(dot01), scale01);
            acc10 = vfmaq_f32(acc10, vcvtq_f32_s32(dot10), scale10);
            acc11 = vfmaq_f32(acc11, vcvtq_f32_s32(dot11), scale11);

            // The B data cursor is not reset between blocks; it advances
            // one sub-block at a time across the whole K range.
            b_data += SubBlkPackedBBytes;
        }

        // Step the per-block cursors.
        a_blk += Q8BlkSize(BlkLen);
        ++b_scale;

        if constexpr (HasZeroPoint) {
            // Two zero points share a byte, so only move on after the odd block.
            if ((k_blk_idx & 1) != 0) {
                ++b_zp;
            }
        }
    }

    // Horizontal reduction of each accumulator into its output element.
    float* const out_row0 = SumPtr;
    float* const out_row1 = SumPtr + ldc;

    out_row0[0] = vaddvq_f32(acc00);
    out_row0[1] = vaddvq_f32(acc01);
    out_row1[0] = vaddvq_f32(acc10);
    out_row1[1] = vaddvq_f32(acc11);

    if (BiasPtr == nullptr) {
        return;
    }

    // Bias is per column, so both rows of the tile get the same two values.
    out_row0[0] += BiasPtr[0];
    out_row0[1] += BiasPtr[1];
    out_row1[0] += BiasPtr[0];
    out_row1[1] += BiasPtr[1];
}

template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemm_CompInt8_Compute1x1_BlkLen16(
Expand Down Expand Up @@ -1278,7 +1144,7 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32(
float* SumRowPtr = C;

size_t m_remaining = CountM;
while (m_remaining > 1) {
while (m_remaining > 3) {
const std::byte* QuantBDataColPtr = QuantBData;
const float* QuantBScaleColPtr = QuantBScale;
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
Expand All @@ -1289,8 +1155,8 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32(

size_t n_remaining = CountN;
while (n_remaining > 1) {
// Compute 2x2 tiles of output
SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16<HasZeroPoint>(
// Compute 4x2 tiles of output
SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16<HasZeroPoint>(
BlkLen,
QuantARowPtr,
QuantBDataColPtr,
Expand All @@ -1316,40 +1182,31 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32(
}

if (n_remaining > 0) {
// Compute last 2x1 tile of output
SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32<HasZeroPoint>(
BlkLen,
QuantARowPtr,
QuantBDataColPtr,
QuantBScaleColPtr,
QuantBZeroPointColPtr,
BiasPtr,
SumPtr,
BlockCountK
);

SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32<HasZeroPoint>(
BlkLen,
QuantARowPtr + StrideQuantA,
QuantBDataColPtr,
QuantBScaleColPtr,
QuantBZeroPointColPtr,
BiasPtr,
SumPtr + ldc,
BlockCountK
);
// Compute last 4x1 tile of output
for (size_t i = 0; i < 4; ++i) {
SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32<HasZeroPoint>(
BlkLen,
QuantARowPtr + StrideQuantA * i,
QuantBDataColPtr,
QuantBScaleColPtr,
QuantBZeroPointColPtr,
BiasPtr,
SumPtr + ldc * i,
BlockCountK
);
}
}

// Move to next 2 rows
AdvanceRowPtrs<2>(
// Move to next 4 rows
AdvanceRowPtrs<4>(
StrideQuantA, ldc,
QuantARowPtr, SumRowPtr
);

m_remaining -= 2;
m_remaining -= 4;
}

if (m_remaining > 0) {
while (m_remaining > 0) {
const std::byte* QuantBDataColPtr = QuantBData;
const float* QuantBScaleColPtr = QuantBScale;
const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
Expand Down Expand Up @@ -1380,6 +1237,14 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32(

n_remaining -= 1;
}

// Move to next row
AdvanceRowPtrs<1>(
StrideQuantA, ldc,
QuantARowPtr, SumRowPtr
);

m_remaining -= 1;
}
}

Expand Down

0 comments on commit a2f5c6b

Please sign in to comment.