diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
index ca829bc5bfca..7a60a8f5c6a4 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
@@ -153,206 +153,6 @@ namespace
 // The ComputeRxC functions compute an R row by C column tile of the output matrix.
 //
 
-template <bool HasZeroPoint>
-MLAS_FORCEINLINE void
-SQ4BitGemm_CompInt8_Compute2x4_BlkLenGreaterThan16(
-    size_t BlkLen,
-    const std::byte* QuantARowPtr,
-    const std::byte* QuantBDataColPtr,
-    const float* QuantBScaleColPtr,
-    const std::byte* QuantBZeroPointColPtr,
-    const float* BiasPtr,
-    float* SumPtr,
-    size_t BlockCountK,
-    size_t StrideQuantA,
-    size_t StrideQuantBData,
-    size_t StrideQuantBScale,
-    size_t StrideQuantBZeroPoint,
-    size_t ldc
-)
-{
-    // process blocks in 32-element sub-blocks
-    const size_t SubBlksPerBlk = BlkLen / 32;
-
-    const std::byte* QuantAPtr = QuantARowPtr;
-    const std::byte* QuantBDataPtr = QuantBDataColPtr;
-    const float* QuantBScalePtr = QuantBScaleColPtr;
-    const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr;
-
-    float32x4_t acc00{}, acc01{}, acc02{}, acc03{}, acc10{}, acc11{}, acc12{}, acc13{};
-
-    for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) {
-        const std::byte* QuantABlkRow0 = QuantAPtr;
-        const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA;
-
-        const float QuantBScaleCol0 = *QuantBScalePtr;
-        const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale * 1);
-        const float QuantBScaleCol2 = *(QuantBScalePtr + StrideQuantBScale * 2);
-        const float QuantBScaleCol3 = *(QuantBScalePtr + StrideQuantBScale * 3);
-
-        // compute combined scales
-        const float scale00 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol0;
-        const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1;
-        const float scale02 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol2;
-        const float scale03 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol3;
-        const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0;
-        const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1;
-        const float scale12 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol2;
-        const float scale13 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol3;
-
-        // load B zero point
-        int8_t bzp_col0, bzp_col1, bzp_col2, bzp_col3;
-        if constexpr (HasZeroPoint) {
-            const std::byte QuantBZeroPointByteCol0 = *QuantBZeroPointPtr;
-            const std::byte QuantBZeroPointByteCol1 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 1);
-            const std::byte QuantBZeroPointByteCol2 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 2);
-            const std::byte QuantBZeroPointByteCol3 = *(QuantBZeroPointPtr + StrideQuantBZeroPoint * 3);
-            if ((k_blk_idx & 1) == 0) {
-                bzp_col0 = std::to_integer<int8_t>(QuantBZeroPointByteCol0 & std::byte{0x0F});
-                bzp_col1 = std::to_integer<int8_t>(QuantBZeroPointByteCol1 & std::byte{0x0F});
-                bzp_col2 = std::to_integer<int8_t>(QuantBZeroPointByteCol2 & std::byte{0x0F});
-                bzp_col3 = std::to_integer<int8_t>(QuantBZeroPointByteCol3 & std::byte{0x0F});
-            } else {
-                bzp_col0 = std::to_integer<int8_t>(QuantBZeroPointByteCol0 >> 4);
-                bzp_col1 = std::to_integer<int8_t>(QuantBZeroPointByteCol1 >> 4);
-                bzp_col2 = std::to_integer<int8_t>(QuantBZeroPointByteCol2 >> 4);
-                bzp_col3 = std::to_integer<int8_t>(QuantBZeroPointByteCol3 >> 4);
-            }
-        } else {
-            bzp_col0 = 8;
-            bzp_col1 = 8;
-            bzp_col2 = 8;
-            bzp_col3 = 8;
-        }
-
-        const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0);
-        const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1);
-
-        for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) {
-            // load A
-            const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0);
-            const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16);
-            const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0);
-            const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16);
-
-            // columns 0 and 1 of B
-            {
-                // load B
-                const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
-                const uint8x16_t bv_packed_col1 =
-                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData);
-
-                const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
-
-                int8x16_t bv_col0_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col0, LowMaskU8x16));
-                int8x16_t bv_col0_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col0, 4));
-                int8x16_t bv_col1_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col1, LowMaskU8x16));
-                int8x16_t bv_col1_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col1, 4));
-
-                // subtract B zero point
-                bv_col0_0 = vsubq_s8(bv_col0_0, vdupq_n_s8(bzp_col0));
-                bv_col0_1 = vsubq_s8(bv_col0_1, vdupq_n_s8(bzp_col0));
-                bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1));
-                bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1));
-
-                // quantized dot product
-                int32x4_t dot00{}, dot01{}, dot10{}, dot11{};
-                dot00 = vdotq_s32(vdotq_s32(dot00, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1);
-                dot01 = vdotq_s32(vdotq_s32(dot01, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1);
-                dot10 = vdotq_s32(vdotq_s32(dot10, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1);
-                dot11 = vdotq_s32(vdotq_s32(dot11, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1);
-
-                // convert to float
-                const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00);
-                const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01);
-                const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10);
-                const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11);
-
-                // multiply by scale and update accumulator
-                acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00));
-                acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01));
-                acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10));
-                acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11));
-            }
-
-            // columns 2 and 3 of B
-            {
-                // load B
-                const uint8x16_t bv_packed_col2 =
-                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData * 2);
-                const uint8x16_t bv_packed_col3 =
-                    vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + StrideQuantBData * 3);
-
-                const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
-
-                int8x16_t bv_col2_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col2, LowMaskU8x16));
-                int8x16_t bv_col2_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col2, 4));
-                int8x16_t bv_col3_0 = vreinterpretq_s8_u8(vandq_u8(bv_packed_col3, LowMaskU8x16));
-                int8x16_t bv_col3_1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed_col3, 4));
-
-                // subtract B zero point
-                bv_col2_0 = vsubq_s8(bv_col2_0, vdupq_n_s8(bzp_col2));
-                bv_col2_1 = vsubq_s8(bv_col2_1, vdupq_n_s8(bzp_col2));
-                bv_col3_0 = vsubq_s8(bv_col3_0, vdupq_n_s8(bzp_col3));
-                bv_col3_1 = vsubq_s8(bv_col3_1, vdupq_n_s8(bzp_col3));
-
-                // quantized dot product
-                int32x4_t dot02{}, dot03{}, dot12{}, dot13{};
-                dot02 = vdotq_s32(vdotq_s32(dot02, av_row0_0, bv_col2_0), av_row0_1, bv_col2_1);
-                dot03 = vdotq_s32(vdotq_s32(dot03, av_row0_0, bv_col3_0), av_row0_1, bv_col3_1);
-                dot12 = vdotq_s32(vdotq_s32(dot12, av_row1_0, bv_col2_0), av_row1_1, bv_col2_1);
-                dot13 = vdotq_s32(vdotq_s32(dot13, av_row1_0, bv_col3_0), av_row1_1, bv_col3_1);
-
-                // convert to float
-                const float32x4_t dot_f32_02 = vcvtq_f32_s32(dot02);
-                const float32x4_t dot_f32_03 = vcvtq_f32_s32(dot03);
-                const float32x4_t dot_f32_12 = vcvtq_f32_s32(dot12);
-                const float32x4_t dot_f32_13 = vcvtq_f32_s32(dot13);
-
-                // multiply by scale and update accumulator
-                acc02 = vfmaq_f32(acc02, dot_f32_02, vdupq_n_f32(scale02));
-                acc03 = vfmaq_f32(acc03, dot_f32_03, vdupq_n_f32(scale03));
-                acc12 = vfmaq_f32(acc12, dot_f32_12, vdupq_n_f32(scale12));
-                acc13 = vfmaq_f32(acc13, dot_f32_13, vdupq_n_f32(scale13));
-            }
-
-            // increment block data pointers to next sub-block
-            QuantADataPtrRow0 += 32;
-            QuantADataPtrRow1 += 32;
-            QuantBDataPtr += 16;
-        }
-
-        // increment other block pointers
-
-        QuantAPtr += Q8BlkSize(BlkLen);
-        QuantBScalePtr += 1;
-
-        if constexpr (HasZeroPoint) {
-            QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1;
-        }
-    }
-
-    SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00);
-    SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01);
-    SumPtr[ldc * 0 + 2] = vaddvq_f32(acc02);
-    SumPtr[ldc * 0 + 3] = vaddvq_f32(acc03);
-    SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10);
-    SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11);
-    SumPtr[ldc * 1 + 2] = vaddvq_f32(acc12);
-    SumPtr[ldc * 1 + 3] = vaddvq_f32(acc13);
-
-    if (BiasPtr != nullptr) {
-        SumPtr[ldc * 0 + 0] += BiasPtr[0];
-        SumPtr[ldc * 0 + 1] += BiasPtr[1];
-        SumPtr[ldc * 0 + 2] += BiasPtr[2];
-        SumPtr[ldc * 0 + 3] += BiasPtr[3];
-        SumPtr[ldc * 1 + 0] += BiasPtr[0];
-        SumPtr[ldc * 1 + 1] += BiasPtr[1];
-        SumPtr[ldc * 1 + 2] += BiasPtr[2];
-        SumPtr[ldc * 1 + 3] += BiasPtr[3];
-    }
-}
-
 template <bool HasZeroPoint>
 MLAS_FORCEINLINE void
 SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16(
@@ -1320,138 +1120,7 @@ SQ4BitGemmKernel_CompInt8_BlkLen16(
 
 template <bool HasZeroPoint>
 void
-SQ4BitGemmKernel_CompInt8_BlkLen32_2x4Tiling(
-    const std::byte* QuantA,
-    const std::byte* QuantBData,
-    const float* QuantBScale,
-    const std::byte* QuantBZeroPoint,
-    float* C,
-    size_t CountM,
-    size_t CountN,
-    size_t BlockCountK,
-    size_t ldc,
-    const float* Bias
-)
-{
-    constexpr size_t BlkBitWidth = 4;
-    constexpr size_t BlkLen = 32;
-
-    const size_t StrideQuantA = BlockCountK * Q8BlkSize(BlkLen);
-
-    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
-    const size_t StrideQuantBScale = BlockCountK;
-    const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
-
-    const std::byte* QuantARowPtr = QuantA;
-
-    float* SumRowPtr = C;
-
-    size_t m_remaining = CountM;
-    while (m_remaining > 1) {
-        const std::byte* QuantBDataColPtr = QuantBData;
-        const float* QuantBScaleColPtr = QuantBScale;
-        const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
-
-        const float* BiasPtr = Bias;
-
-        float* SumPtr = SumRowPtr;
-
-        size_t n_remaining = CountN;
-        while (n_remaining > 3) {
-            // Compute 2x4 tiles of output
-            SQ4BitGemm_CompInt8_Compute2x4_BlkLenGreaterThan16<HasZeroPoint>(
-                BlkLen,
-                QuantARowPtr,
-                QuantBDataColPtr,
-                QuantBScaleColPtr,
-                QuantBZeroPointColPtr,
-                BiasPtr,
-                SumPtr,
-                BlockCountK,
-                StrideQuantA,
-                StrideQuantBData,
-                StrideQuantBScale,
-                StrideQuantBZeroPoint,
-                ldc
-            );
-
-            // Move to next 4 columns
-            AdvanceColPtrs<4, HasZeroPoint>(
-                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
-                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
-            );
-
-            n_remaining -= 4;
-        }
-
-        while (n_remaining > 0) {
-            // Compute 2x1 tiles of output
-            for (size_t i = 0; i < 2; ++i) {
-                SQ4BitGemm_CompInt8_Compute1x1_BlkLen32<HasZeroPoint>(
-                    QuantARowPtr + StrideQuantA * i,
-                    QuantBDataColPtr,
-                    QuantBScaleColPtr,
-                    QuantBZeroPointColPtr,
-                    BiasPtr,
-                    SumPtr + ldc * i,
-                    BlockCountK
-                );
-            }
-
-            // Move to next column
-            AdvanceColPtrs<1, HasZeroPoint>(
-                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
-                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
-            );
-
-            n_remaining -= 1;
-        }
-
-        // Move to next 2 rows
-        AdvanceRowPtrs<2>(
-            StrideQuantA, ldc,
-            QuantARowPtr, SumRowPtr
-        );
-
-        m_remaining -= 2;
-    }
-
-    if (m_remaining > 0) {
-        const std::byte* QuantBDataColPtr = QuantBData;
-        const float* QuantBScaleColPtr = QuantBScale;
-        const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
-
-        const float* BiasPtr = Bias;
-
-        float* SumPtr = SumRowPtr;
-
-        size_t n_remaining = CountN;
-        while (n_remaining > 0) {
-            // Compute 1x1 tiles of output
-            SQ4BitGemm_CompInt8_Compute1x1_BlkLen32<HasZeroPoint>(
-                QuantARowPtr,
-                QuantBDataColPtr,
-                QuantBScaleColPtr,
-                QuantBZeroPointColPtr,
-                BiasPtr,
-                SumPtr,
-                BlockCountK
-            );
-
-            // Move to next column
-            AdvanceColPtrs<1, HasZeroPoint>(
-                StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
-                QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, BiasPtr, SumPtr
-            );
-
-            n_remaining -= 1;
-        }
-    }
-}
-
-template <bool HasZeroPoint>
-void
-SQ4BitGemmKernel_CompInt8_BlkLen32_4x2Tiling(
+SQ4BitGemmKernel_CompInt8_BlkLen32(
     const std::byte* QuantA,
     const std::byte* QuantBData,
     const float* QuantBScale,
@@ -1744,7 +1413,7 @@ SQ4BitGemmKernel_CompInt8_DispatchOnBlkLen(
             Bias
         );
     } else if (BlkLen == 32) {
-        SQ4BitGemmKernel_CompInt8_BlkLen32_2x4Tiling<HasZeroPoint>(
+        SQ4BitGemmKernel_CompInt8_BlkLen32<HasZeroPoint>(
             QuantA,
             QuantBData,
             QuantBScale,