Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature SIMD support for Hamming Distance. #160

Merged
merged 3 commits into from
Jun 27, 2024

Conversation

kpango
Copy link
Member

@kpango kpango commented Jun 24, 2024

I have added full support for SSE2, AVX2, and AVX512 for performance improvement when using Distance Type Hamming.
The code was expanded to the required number of loops per instruction set using macro to UNROLL with bit specification.
The code has been designed to support dimensional data in multiples of 16.

AVX2 uses SSE2 for fraction processing, and AVX512 uses AVX2 and SSE2 for fraction processing.
No fractional processing is performed in SSE2 because it is assumed that no fractional numbers will be produced.

I separately wrote a following test code to test the comparison between the function implemented this time and existing functions and general Hamming functions.
The benchmark results for AVX2 and SSE2 are Error 0, so there should be no major problems with calculation accuracy and speed as shown in the following result.

However, I have not been able to test AVX512 because I do not have a test environment at this time.

Test Code is below:

#include <emmintrin.h>
#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <chrono>
#include <random>
#include <vector>
#include <iostream>

#define UNROLL_MACRO_1(MACRO, BIT_SIZE) MACRO(0, 0)
#define UNROLL_MACRO_2(MACRO, BIT_SIZE) UNROLL_MACRO_1(MACRO, BIT_SIZE) MACRO(1, BIT_SIZE)
#define UNROLL_MACRO_4(MACRO, BIT_SIZE) UNROLL_MACRO_2(MACRO, BIT_SIZE) MACRO(2, BIT_SIZE * 2) MACRO(3, BIT_SIZE * 3)
#define UNROLL_MACRO_8(MACRO, BIT_SIZE) UNROLL_MACRO_4(MACRO, BIT_SIZE) MACRO(4, BIT_SIZE * 4) MACRO(5, BIT_SIZE * 5) MACRO(6, BIT_SIZE * 6) MACRO(7, BIT_SIZE * 7)

#define UNROLL_BODY_SSE2(i, BIT_SIZE)                                                                                                                                            \
    __m128i vres##i = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(uinta + BIT_SIZE)), _mm_loadu_si128(reinterpret_cast<const __m128i *>(uintb + BIT_SIZE))); \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 0));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 1));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 2));                                                                                                                      \
    count += _mm_popcnt_u32(_mm_extract_epi32(vres##i, 3));

#define UNROLL_BODY_AVX2(i, BIT_SIZE)                                                                                                                                                     \
    __m256i vres##i = _mm256_xor_si256(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(uinta + BIT_SIZE)), _mm256_loadu_si256(reinterpret_cast<const __m256i *>(uintb + BIT_SIZE))); \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 0));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 1));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 2));                                                                                                                            \
    count += _mm_popcnt_u64(_mm256_extract_epi64(vres##i, 3));

#define UNROLL_BODY_AVX512(i, BIT_SIZE)                                                                                                                                                   \
    __m512i vres##i = _mm512_xor_si512(_mm512_loadu_si512(reinterpret_cast<const __m512i *>(uinta + BIT_SIZE)), _mm512_loadu_si512(reinterpret_cast<const __m512i *>(uintb + BIT_SIZE))); \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 0));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 1));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 2));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 3));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 4));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 5));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 6));                                                                                                                            \
    count += _mm_popcnt_u64(_mm512_extract_epi64(vres##i, 7));

#define UNROLL_LOOP_SSE2(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_SSE2, 16)
#define UNROLL_LOOP_AVX2(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_AVX2, 32)
#define UNROLL_LOOP_AVX512(FACTOR) UNROLL_MACRO_##FACTOR(UNROLL_BODY_AVX512, 64)

#define PROCESS_LOOP(INSTRUCTION_SET, FACTOR, STEP) \
    while (uinta + STEP <= last)                    \
    {                                               \
        UNROLL_LOOP_##INSTRUCTION_SET(FACTOR)       \
            uinta += STEP;                          \
        uintb += STEP;                              \
    }

#define PROCESS_ALL_LOOPS(INSTRUCTION_SET, BIT_SIZE) \
    PROCESS_LOOP(INSTRUCTION_SET, 8, BIT_SIZE * 8)   \
    PROCESS_LOOP(INSTRUCTION_SET, 4, BIT_SIZE * 4)   \
    PROCESS_LOOP(INSTRUCTION_SET, 2, BIT_SIZE * 2)   \
    PROCESS_LOOP(INSTRUCTION_SET, 1, BIT_SIZE)

#define PROCESS_REMAINING_DATA_WITH_SSE2_AND_AVX2_AVX512() \
    if (uinta < last)                                      \
    {                                                      \
        PROCESS_LOOP(AVX2, 1, 32)                          \
        PROCESS_LOOP(SSE2, 1, 16)                          \
    }

#define PROCESS_REMAINING_DATA_WITH_SSE2_AVX2() \
    if (uinta < last)                           \
    {                                           \
        PROCESS_LOOP(SSE2, 1, 16)               \
    }

#define DO_NOTHING()

#define COMPARE_HAMMING_DISTANCE(INSTRUCTION_SET, BIT_SIZE, PROCESS_REMAINING_DATA) \
    const uint8_t *last = reinterpret_cast<const uint8_t *>(a + size);              \
    const uint8_t *uinta = reinterpret_cast<const uint8_t *>(a);                    \
    const uint8_t *uintb = reinterpret_cast<const uint8_t *>(b);                    \
    size_t count = 0;                                                               \
    PROCESS_ALL_LOOPS(INSTRUCTION_SET, BIT_SIZE)                                    \
    PROCESS_REMAINING_DATA                                                          \
    return static_cast<double>(count);

template <typename OBJECT_TYPE>
inline static double compareHammingDistance(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
#if defined(__AVX512F__)
    COMPARE_HAMMING_DISTANCE(AVX512, 64, PROCESS_REMAINING_DATA_WITH_SSE2_AND_AVX2_AVX512())
#elif defined(__AVX2__)
    COMPARE_HAMMING_DISTANCE(AVX2, 32, PROCESS_REMAINING_DATA_WITH_SSE2_AVX2())
#else
    COMPARE_HAMMING_DISTANCE(SSE2, 16, DO_NOTHING())
#endif
}

template <typename OBJECT_TYPE>
inline static double compareHammingDistanceBuiltin(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    size_t count = 0;
    for (size_t i = 0; i < size; ++i)
    {
        count += __builtin_popcount(a[i] ^ b[i]);
    }
    return static_cast<double>(count);
}

template <typename OBJECT_TYPE>
inline static double compareHammingDistanceNGT(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    const uint64_t *last = reinterpret_cast<const uint64_t *>(a + size);

    const uint64_t *uinta = reinterpret_cast<const uint64_t *>(a);
    const uint64_t *uintb = reinterpret_cast<const uint64_t *>(b);
    size_t count = 0;
    while (uinta < last)
    {
        count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
        count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
    }

    return static_cast<double>(count);
}

inline static double popCount(uint32_t x)
{
    x = (x & 0x55555555) + (x >> 1 & 0x55555555);
    x = (x & 0x33333333) + (x >> 2 & 0x33333333);
    x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F);
    x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF);
    x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF);
    return x;
}
template <typename OBJECT_TYPE>
inline static double compareHammingDistanceNoSIMD(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size)
{
    const uint32_t *last = reinterpret_cast<const uint32_t *>(a + size);
    const uint32_t *uinta = reinterpret_cast<const uint32_t *>(a);
    const uint32_t *uintb = reinterpret_cast<const uint32_t *>(b);
    size_t count = 0;
    while (uinta < last)
    {
        count += popCount(*uinta++ ^ *uintb++);
    }
    return static_cast<double>(count);
}

void generateTestData(std::vector<uint8_t> &data, std::vector<size_t> &dimensions, size_t num_tests, size_t max_dimension)
{
    std::mt19937 gen(42);
    std::uniform_int_distribution<> dis(0, 255);
    for (size_t i = 0; i < num_tests; ++i)
    {
        size_t dimension = ((rand() % (max_dimension / 16)) + 1) * 16;
        dimensions[i] = dimension;
        for (size_t j = 0; j < dimension; ++j)
        {
            data[i * max_dimension + j] = dis(gen);
        }
    }
}

int main()
{
    const size_t num_tests = 6000000;
    const size_t max_dimension = 8192;

    std::vector<uint8_t> test_data_a(max_dimension * num_tests);
    std::vector<uint8_t> test_data_b(max_dimension * num_tests);
    std::vector<size_t> dimensions(num_tests);

    generateTestData(test_data_a, dimensions, num_tests, max_dimension);
    generateTestData(test_data_b, dimensions, num_tests, max_dimension);

    std::vector<double> primitive_distances(num_tests);
    for (size_t i = 0; i < num_tests; ++i)
    {
        primitive_distances[i] = compareHammingDistanceBuiltin(test_data_a.data() + i * max_dimension, test_data_b.data() + i * max_dimension, dimensions[i]);
    }

    auto benchmark = [&](auto hammingDistanceFunc, const char *label)
    {
        double total_time = 0.0;
        size_t errors = 0;

        for (size_t n = 0; n < 100; ++n)
        {
            auto start = std::chrono::high_resolution_clock::now();
            for (size_t i = 0; i < num_tests; ++i)
            {
                double simd_distance = hammingDistanceFunc(test_data_a.data() + i * max_dimension, test_data_b.data() + i * max_dimension, dimensions[i]);
                if (simd_distance != primitive_distances[i])
                {
                    ++errors;
                }
            }
            auto end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> elapsed = end - start;
            total_time += elapsed.count();
        }

        std::cout << label << ": " << total_time << " seconds, Errors: " << errors << "\n";
    };

    std::cout << "start benchmarking \n";
    benchmark(compareHammingDistanceBuiltin<uint8_t>, "Builtin POPCNT Hamming Distance");
    benchmark(compareHammingDistanceNoSIMD<uint8_t>, "NGT Original Non-SIMD Hamming Distance");
    benchmark(compareHammingDistanceNGT<uint8_t>, "NGT Original Hamming Distance");
    benchmark(compareHammingDistance<uint8_t>, "New Macro and SIMD Hamming Distance");
    std::cout << "benchmark finished \n";

    return 0;
}

Test Result:

> ./hamming_distance_avx2
start benchmarking
Builtin POPCNT Hamming Distance: 1481.64 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2569.08 seconds, Errors: 0
NGT Original Hamming Distance: 393.814 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 367.782 seconds, Errors: 0
benchmark finished

> ./hamming_distance_sse2
start benchmarking
Builtin POPCNT Hamming Distance: 1474.34 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2581.55 seconds, Errors: 0
NGT Original Hamming Distance: 397.84 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 433.393 seconds, Errors: 0
benchmark finished

> ./hamming_distance_avx2
start benchmarking
Builtin POPCNT Hamming Distance: 1477.9 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2585.72 seconds, Errors: 0
NGT Original Hamming Distance: 402.682 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 381.26 seconds, Errors: 0
New Template and SIMD Hamming Distance: 380.898 seconds, Errors: 0
benchmark finished

> ./hamming_distance_sse2
start benchmarking
Builtin POPCNT Hamming Distance: 1488.5 seconds, Errors: 0
NGT Original Non-SIMD Hamming Distance: 2572.27 seconds, Errors: 0
NGT Original Hamming Distance: 407.198 seconds, Errors: 0
New Macro and SIMD Hamming Distance: 429.891 seconds, Errors: 0
New Template and SIMD Hamming Distance: 433.764 seconds, Errors: 0
benchmark finished

@kpango kpango force-pushed the feature/hamming-distance/support-simd branch from 671603e to 949a425 Compare June 24, 2024 17:58
…improve PrimitiveComparator performance

Signed-off-by: kpango <kpango@vdaas.org>
@kpango kpango force-pushed the feature/hamming-distance/support-simd branch from 949a425 to a60af08 Compare June 24, 2024 17:58
@masajiro
Copy link
Member

Thanks!

@masajiro masajiro merged commit 65e5729 into main Jun 27, 2024
@kpango kpango deleted the feature/hamming-distance/support-simd branch June 27, 2024 04:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants