Skip to content

Commit

Permalink
Revert "add 2x4 impl for blklen 32"
Browse files Browse the repository at this point in the history
This reverts commit f837239.

The 4x2 implementation was faster in microbenchmark measurements.
  • Loading branch information
edgchen1 committed Jul 13, 2024
1 parent f837239 commit 277df8d
Showing 1 changed file with 2 additions and 333 deletions.
335 changes: 2 additions & 333 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,206 +153,6 @@ namespace
// The ComputeRxC functions compute an R row by C column tile of the output matrix.
//

// Computes a 2-row by 4-column tile of the output matrix for the int8-A /
// 4-bit-B GEMM kernel, for block lengths that are multiples of 32.
//
// Template parameter:
//   HasZeroPoint - when true, per-block 4-bit zero points are read from
//                  QuantBZeroPointColPtr; when false, the implicit zero
//                  point 8 is used for every B block.
//
// Parameters:
//   BlkLen                 - quantization block length along K (multiple of 32 assumed;
//                            processed in 32-element sub-blocks).
//   QuantARowPtr           - first of 2 rows of quantized A (Q8 blocks: scale + data).
//   QuantBDataColPtr       - first of 4 columns of packed 4-bit B data.
//   QuantBScaleColPtr      - first of 4 columns of per-block B scales.
//   QuantBZeroPointColPtr  - first of 4 columns of packed 4-bit B zero points
//                            (only dereferenced when HasZeroPoint).
//   BiasPtr                - optional bias, one value per column; may be nullptr.
//   SumPtr                 - output tile origin; writes SumPtr[r * ldc + c] for
//                            r in [0, 2), c in [0, 4).
//   BlockCountK            - number of quantization blocks along K.
//   StrideQuantA           - byte stride between consecutive rows of quantized A.
//   StrideQuantBData       - byte stride between consecutive columns of B data.
//   StrideQuantBScale      - element stride between consecutive columns of B scales.
//   StrideQuantBZeroPoint  - byte stride between consecutive columns of B zero points.
//   ldc                    - leading dimension (row stride, in elements) of the output C.
template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemm_CompInt8_Compute2x4_BlkLenGreaterThan16(
    size_t BlkLen,
    const std::byte* QuantARowPtr,
    const std::byte* QuantBDataColPtr,
    const float* QuantBScaleColPtr,
    const std::byte* QuantBZeroPointColPtr,
    const float* BiasPtr,
    float* SumPtr,
    size_t BlockCountK,
    size_t StrideQuantA,
    size_t StrideQuantBData,
    size_t StrideQuantBScale,
    size_t StrideQuantBZeroPoint,
    size_t ldc
)
{
    // process blocks in 32-element sub-blocks
    const size_t SubBlksPerBlk = BlkLen / 32;

    const std::byte* QuantAPtr = QuantARowPtr;
    const std::byte* QuantBDataPtr = QuantBDataColPtr;
    const float* QuantBScalePtr = QuantBScaleColPtr;
    const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr;

    // One float32x4 accumulator per output element of the 2x4 tile
    // (accRC = row R, column C); each is horizontally reduced at the end.
    float32x4_t acc00{}, acc01{}, acc02{}, acc03{}, acc10{}, acc11{}, acc12{}, acc13{};

    for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) {
        const std::byte* QuantABlkRow0 = QuantAPtr;
        const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA;

        const float QuantBScaleCol0 = *QuantBScalePtr;
        const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale * 1);
        const float QuantBScaleCol2 = *(QuantBScalePtr + StrideQuantBScale * 2);
        const float QuantBScaleCol3 = *(QuantBScalePtr + StrideQuantBScale * 3);

        // compute combined scales (A block scale * B block scale) for each
        // row/column pairing of this K block
        const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0;
        const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1;
        const float scale02 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol2;
        const float scale03 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol3;
        const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0;
        const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1;
        const float scale12 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol2;
        const float scale13 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol3;

        // load B zero point
        int8_t bzp_col0, bzp_col1, bzp_col2, bzp_col3;
        if constexpr (HasZeroPoint) {
            // Two consecutive K blocks share one zero point byte: even block
            // indices take the low nibble, odd ones the high nibble.
            const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr;
            const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 1);
            const std::byte QuantBZeroPointByteCol2 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 2);
            const std::byte QuantBZeroPointByteCol3 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 3);
            if ((k_blk_idx & 1) == 0) {
                bzp_col0 = std::to_integer<int8_t>(QuantBZeroPointByteCol0 & std::byte{0x0F});
                bzp_col1 = std::to_integer<int8_t>(QuantBZeroPointByteCol1 & std::byte{0x0F});
                bzp_col2 = std::to_integer<int8_t>(QuantBZeroPointByteCol2 & std::byte{0x0F});
                bzp_col3 = std::to_integer<int8_t>(QuantBZeroPointByteCol3 & std::byte{0x0F});
            } else {
                bzp_col0 = std::to_integer<int8_t>(QuantBZeroPointByteCol0 >> 4);
                bzp_col1 = std::to_integer<int8_t>(QuantBZeroPointByteCol1 >> 4);
                bzp_col2 = std::to_integer<int8_t>(QuantBZeroPointByteCol2 >> 4);
                bzp_col3 = std::to_integer<int8_t>(QuantBZeroPointByteCol3 >> 4);
            }
        } else {
            // default zero point when none is stored
            bzp_col0 = 8;
            bzp_col1 = 8;
            bzp_col2 = 8;
            bzp_col3 = 8;
        }

        const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0);
        const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1);

        for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) {
            // load A: 32 int8 values per row of the sub-block
            const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0);
            const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16);
            const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0);
            const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16);

            // columns 0 and 1 of B
            {
                // load B: 16 bytes hold 32 packed 4-bit values per column
                const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
                const uint8x16_t bv_packed_col1 =
                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData);

                const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);

                // unpack: low nibbles form the first 16 elements, high nibbles the next 16
                int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16));
                int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4));
                int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16));
                int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4));

                // subtract B zero point
                bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0));
                bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0));
                bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1));
                bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1));

                // quantized dot product (NEON dot-product instructions)
                int32x4_t dot00{}, dot01{}, dot10{}, dot11{};
                dot00 = vdotq_s32(vdotq_s32(dot00, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1);
                dot01 = vdotq_s32(vdotq_s32(dot01, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1);
                dot10 = vdotq_s32(vdotq_s32(dot10, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1);
                dot11 = vdotq_s32(vdotq_s32(dot11, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1);

                // convert to float
                const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00);
                const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01);
                const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10);
                const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11);

                // multiply by scale and update accumulator
                acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00));
                acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01));
                acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10));
                acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11));
            }

            // columns 2 and 3 of B
            {
                // load B
                const uint8x16_t bv_packed_col2 =
                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData * 2);
                const uint8x16_t bv_packed_col3 =
                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData * 3);

                const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);

                int8x16_t bv_col2_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col2, LowMaskU8x16));
                int8x16_t bv_col2_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col2, 4));
                int8x16_t bv_col3_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col3, LowMaskU8x16));
                int8x16_t bv_col3_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col3, 4));

                // subtract B zero point
                bv_col2_0 = vsubq_s8(bv_col2_0, vdupq_n_s8(bzp_col2));
                bv_col2_1 = vsubq_s8(bv_col2_1, vdupq_n_s8(bzp_col2));
                bv_col3_0 = vsubq_s8(bv_col3_0, vdupq_n_s8(bzp_col3));
                bv_col3_1 = vsubq_s8(bv_col3_1, vdupq_n_s8(bzp_col3));

                // quantized dot product
                int32x4_t dot02{}, dot03{}, dot12{}, dot13{};
                dot02 = vdotq_s32(vdotq_s32(dot02, av_row0_0, bv_col2_0), av_row0_1, bv_col2_1);
                dot03 = vdotq_s32(vdotq_s32(dot03, av_row0_0, bv_col3_0), av_row0_1, bv_col3_1);
                dot12 = vdotq_s32(vdotq_s32(dot12, av_row1_0, bv_col2_0), av_row1_1, bv_col2_1);
                dot13 = vdotq_s32(vdotq_s32(dot13, av_row1_0, bv_col3_0), av_row1_1, bv_col3_1);

                // convert to float
                const float32x4_t dot_f32_02 = vcvtq_f32_s32(dot02);
                const float32x4_t dot_f32_03 = vcvtq_f32_s32(dot03);
                const float32x4_t dot_f32_12 = vcvtq_f32_s32(dot12);
                const float32x4_t dot_f32_13 = vcvtq_f32_s32(dot13);

                // multiply by scale and update accumulator
                acc02 = vfmaq_f32(acc02, dot_f32_02, vdupq_n_f32(scale02));
                acc03 = vfmaq_f32(acc03, dot_f32_03, vdupq_n_f32(scale03));
                acc12 = vfmaq_f32(acc12, dot_f32_12, vdupq_n_f32(scale12));
                acc13 = vfmaq_f32(acc13, dot_f32_13, vdupq_n_f32(scale13));
            }

            // increment block data pointers to next sub-block
            // (32 int8 A values per row; 16 packed B bytes = 32 4-bit values)
            QuantADataPtrRow0 += 32;
            QuantADataPtrRow1 += 32;
            QuantBDataPtr += 16;
        }

        // increment other block pointers

        QuantAPtr += Q8BlkSize(BlkLen);
        QuantBScalePtr += 1;

        if constexpr (HasZeroPoint) {
            // advance only after consuming both nibbles of the zero point byte
            QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1;
        }
    }

    // horizontally reduce each accumulator into its output element
    SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00);
    SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01);
    SumPtr[ldc * 0 + 2] = vaddvq_f32(acc02);
    SumPtr[ldc * 0 + 3] = vaddvq_f32(acc03);
    SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10);
    SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11);
    SumPtr[ldc * 1 + 2] = vaddvq_f32(acc12);
    SumPtr[ldc * 1 + 3] = vaddvq_f32(acc13);

    // add per-column bias to both rows, if provided
    if (BiasPtr != nullptr) {
        SumPtr[ldc * 0 + 0] += BiasPtr[0];
        SumPtr[ldc * 0 + 1] += BiasPtr[1];
        SumPtr[ldc * 0 + 2] += BiasPtr[2];
        SumPtr[ldc * 0 + 3] += BiasPtr[3];
        SumPtr[ldc * 1 + 0] += BiasPtr[0];
        SumPtr[ldc * 1 + 1] += BiasPtr[1];
        SumPtr[ldc * 1 + 2] += BiasPtr[2];
        SumPtr[ldc * 1 + 3] += BiasPtr[3];
    }
}

template <bool HasZeroPoint>
MLAS_FORCEINLINE void
SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16(
Expand Down Expand Up @@ -1320,138 +1120,7 @@ SQ4BitGemmKernel_CompInt8_BlkLen16(

// Kernel driver for BlkLen == 32: tiles the CountM x CountN output with
// 2x4 tiles, handling leftover columns with 2x1 tiles and a leftover
// odd row with 1x1 tiles.
//
// Template parameter:
//   HasZeroPoint - forwarded to the tile compute functions; selects whether
//                  explicit B zero points are read.
//
// Parameters mirror the per-tile compute functions: QuantA / QuantBData /
// QuantBScale / QuantBZeroPoint are the packed operands, C is the output
// (leading dimension ldc), Bias is optional per-column bias.
template <bool HasZeroPoint>
void
SQ4BitGemmKernel_CompInt8_BlkLen32_2x4Tiling(
    const std::byte* QuantA,
    const std::byte* QuantBData,
    const float* QuantBScale,
    const std::byte* QuantBZeroPoint,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t BlockCountK,
    size_t ldc,
    const float* Bias
)
{
    constexpr size_t BlkBitWidth = 4;
    constexpr size_t BlkLen = 32;

    // Per-row / per-column strides through the packed operands.
    const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen);
    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
    const size_t StrideQuantBScale = BlockCountK;
    const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);

    const std::byte* QuantARowPtr = QuantA;
    float* SumRowPtr = C;

    size_t RowsLeft = CountM;

    // Process rows in pairs.
    for (; RowsLeft > 1; RowsLeft -= 2) {
        const std::byte* QuantBDataColPtr = QuantBData;
        const float* QuantBScaleColPtr = QuantBScale;
        const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;

        const float* BiasPtr = Bias;
        float* SumPtr = SumRowPtr;

        size_t ColsLeft = CountN;

        // 2x4 tiles across as many full groups of 4 columns as fit.
        for (; ColsLeft > 3; ColsLeft -= 4) {
            SQ4BitGemm_CompInt8_Compute2x4_BlkLenGreaterThan16<HasZeroPoint>(
                BlkLen,
                QuantARowPtr,
                QuantBDataColPtr,
                QuantBScaleColPtr,
                QuantBZeroPointColPtr,
                BiasPtr,
                SumPtr,
                BlockCountK,
                StrideQuantA,
                StrideQuantBData,
                StrideQuantBScale,
                StrideQuantBZeroPoint,
                ldc
            );

            // Move to next 4 columns.
            AdvanceColPtrs<4, HasZeroPoint>(
                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
            );
        }

        // Remaining columns: 2x1 tiles, one 1x1 computation per row of the pair.
        for (; ColsLeft > 0; --ColsLeft) {
            SQ4BitGemm_CompInt8_Compute1x1_BlkLen32<HasZeroPoint>(
                QuantARowPtr,
                QuantBDataColPtr,
                QuantBScaleColPtr,
                QuantBZeroPointColPtr,
                BiasPtr,
                SumPtr,
                BlockCountK
            );
            SQ4BitGemm_CompInt8_Compute1x1_BlkLen32<HasZeroPoint>(
                QuantARowPtr + StrideQuantA,
                QuantBDataColPtr,
                QuantBScaleColPtr,
                QuantBZeroPointColPtr,
                BiasPtr,
                SumPtr + ldc,
                BlockCountK
            );

            // Move to next column.
            AdvanceColPtrs<1, HasZeroPoint>(
                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
            );
        }

        // Move to next 2 rows.
        AdvanceRowPtrs<2>(
            StrideQuantA, ldc,
            QuantARowPtr, SumRowPtr
        );
    }

    // Odd leftover row: 1x1 tiles across all columns.
    if (RowsLeft > 0) {
        const std::byte* QuantBDataColPtr = QuantBData;
        const float* QuantBScaleColPtr = QuantBScale;
        const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;

        const float* BiasPtr = Bias;
        float* SumPtr = SumRowPtr;

        for (size_t ColsLeft = CountN; ColsLeft > 0; --ColsLeft) {
            SQ4BitGemm_CompInt8_Compute1x1_BlkLen32<HasZeroPoint>(
                QuantARowPtr,
                QuantBDataColPtr,
                QuantBScaleColPtr,
                QuantBZeroPointColPtr,
                BiasPtr,
                SumPtr,
                BlockCountK
            );

            // Move to next column.
            AdvanceColPtrs<1, HasZeroPoint>(
                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
            );
        }
    }
}

template <bool HasZeroPoint>
void
SQ4BitGemmKernel_CompInt8_BlkLen32_4x2Tiling(
SQ4BitGemmKernel_CompInt8_BlkLen32(
const std::byte* QuantA,
const std::byte* QuantBData,
const float* QuantBScale,
Expand Down Expand Up @@ -1744,7 +1413,7 @@ SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen(
Bias
);
} else if (BlkLen == 32) {
SQ4BitGemmKernel_CompInt8_BlkLen32_2x4Tiling<HasZeroPoint>(
SQ4BitGemmKernel_CompInt8_BlkLen32<HasZeroPoint>(
QuantA,
QuantBData,
QuantBScale,
Expand Down

0 comments on commit 277df8d

Please sign in to comment.