diff --git a/.clang-format b/.clang-format index 27d4923fe..239b228ab 100644 --- a/.clang-format +++ b/.clang-format @@ -17,7 +17,7 @@ IncludeBlocks: Regroup IncludeCategories: # C system headers. The header_dependency_test.py contains a copy of this # list; be sure to update that test anytime this list changes. - - Regex: '^[<"](aio|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$' + - Regex: '^[<"](aio|arm_neon|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|immintrin|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$' Priority: 20 # C++ system headers (as of C++23). The header_dependency_test.py contains a # copy of this list; be sure to update that test anytime this list changes. 
diff --git a/tachyon/math/base/BUILD.bazel b/tachyon/math/base/BUILD.bazel
index f1a519029..2180ae036 100644
--- a/tachyon/math/base/BUILD.bazel
+++ b/tachyon/math/base/BUILD.bazel
@@ -1,5 +1,7 @@
+load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
 load(
     "//bazel:tachyon_cc.bzl",
+    "tachyon_avx512_defines",
     "tachyon_cc_benchmark",
     "tachyon_cc_library",
     "tachyon_cc_unittest",
@@ -120,6 +122,33 @@ tachyon_cc_library(
     ],
 )
 
+tachyon_cc_library(
+    name = "simd_int",
+    srcs = if_x86_64([
+        "simd_int128_x86.cc",
+        "simd_int256_x86.cc",
+    ]) + if_has_avx512([
+        "simd_int512_x86.cc",
+    ]) + if_aarch64([
+        "simd_int128_arm64.cc",
+    ]),
+    hdrs = ["simd_int.h"],
+    copts = if_x86_64([
+        "-mavx2",
+    ]) + if_has_avx512([
+        "-mavx512f",
+    ]),
+    defines = tachyon_avx512_defines(),
+    deps = [
+        "//tachyon:export",
+        "//tachyon/base:bit_cast",
+        "//tachyon/base:compiler_specific",
+        "//tachyon/base:logging",
+        "//tachyon/build:build_config",
+        "//tachyon/math/base:big_int",
+    ],
+)
+
 tachyon_cc_unittest(
     name = "base_unittests",
     srcs = [
@@ -132,6 +161,7 @@ tachyon_cc_unittest(
         "rational_field_unittest.cc",
         "semigroups_unittest.cc",
         "sign_unittest.cc",
+        "simd_int_unittest.cc",
     ],
     deps = [
         ":big_int",
@@ -140,6 +170,7 @@ tachyon_cc_unittest(
         ":groups",
         ":rational_field",
         ":sign",
+        ":simd_int",
         "//tachyon/base:optional",
         "//tachyon/base/buffer:vector_buffer",
         "//tachyon/base/containers:container_util",
diff --git a/tachyon/math/base/big_int.h b/tachyon/math/base/big_int.h
index 37e124441..fadc60fd1 100644
--- a/tachyon/math/base/big_int.h
+++ b/tachyon/math/base/big_int.h
@@ -520,6 +520,28 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
     return *this;
   }
 
+  constexpr BigInt operator<<(uint32_t count) const {
+    BigInt ret = *this;
+    ret.MulBy2ExpInPlace(count);
+    return ret;
+  }
+
+  constexpr BigInt& operator<<=(uint32_t count) {
+    MulBy2ExpInPlace(count);
+    return *this;
+  }
+
+  constexpr BigInt operator>>(uint32_t count) const {
+    BigInt ret = *this;
+    ret.DivBy2ExpInPlace(count);
+    return ret;
+  }
+
+  constexpr BigInt& operator>>=(uint32_t count) {
+    DivBy2ExpInPlace(count);
+    return *this;
+  }
+
   constexpr BigInt Add(const BigInt& other) const {
     uint64_t unused = 0;
     return Add(other, unused);
diff --git a/tachyon/math/base/simd_int.h b/tachyon/math/base/simd_int.h
new file mode 100644
index 000000000..89baf3c68
--- /dev/null
+++ b/tachyon/math/base/simd_int.h
@@ -0,0 +1,134 @@
+// Copyright 2024 Ulvetanna Inc.
+// Use of this source code is governed by a Apache-2.0 style license that
+// can be found in the LICENSE.ulvetanna file.
+
+#ifndef TACHYON_MATH_BASE_SIMD_INT_H_
+#define TACHYON_MATH_BASE_SIMD_INT_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+#include <type_traits>
+
+#include "tachyon/build/build_config.h"
+#include "tachyon/export.h"
+#include "tachyon/math/base/big_int.h"
+
+namespace tachyon::math {
+
+// SimdInt is a fixed-width integer of |Bits| bits backed by a
+// BigInt<Bits / 64>. Broadcast(), comparison, bitwise and shift operators are
+// only declared here; they are defined per architecture through explicit
+// specializations (see SPECIALIZE_SIMD_INT below).
+template <size_t Bits>
+class SimdInt {
+ public:
+  constexpr static size_t kBits = Bits;
+  constexpr static size_t kLimbNums = Bits / 64;
+
+  using value_type = BigInt<kLimbNums>;
+
+  SimdInt() = default;
+  template <typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
+  explicit SimdInt(T value) : SimdInt(value_type(value)) {}
+  explicit SimdInt(const value_type& value) : value_(value) {}
+
+  static SimdInt Zero() { return SimdInt(); }
+  static SimdInt One() { return SimdInt(1); }
+  static SimdInt Max() { return SimdInt(value_type::Max()); }
+  // Returns a value in which every 8/16/32/64-bit lane holds |value|.
+  static SimdInt Broadcast(uint8_t value);
+  static SimdInt Broadcast(uint16_t value);
+  static SimdInt Broadcast(uint32_t value);
+  static SimdInt Broadcast(uint64_t value);
+  static SimdInt Random() { return SimdInt(value_type::Random()); }
+
+  const value_type& value() const { return value_; }
+
+  bool IsZero() const { return *this == Zero(); }
+  bool IsOne() const { return *this == One(); }
+  bool IsMax() const { return *this == Max(); }
+
+  bool operator==(const SimdInt& other) const;
+  bool operator!=(const SimdInt& other) const { return !operator==(other); }
+
+  SimdInt operator&(const SimdInt& other) const;
+  SimdInt& operator&=(const SimdInt& other) { return *this = *this & other; }
+
+  SimdInt operator|(const SimdInt& other) const;
+  SimdInt& operator|=(const SimdInt& other) { return *this = *this | other; }
+
+  SimdInt operator^(const SimdInt& other) const;
+  SimdInt& operator^=(const SimdInt& other) { return *this = *this ^ other; }
+
+  // NOTE: Returns |*this| XOR Max(), i.e. a bitwise complement provided Max()
+  // has every bit set. This is not a logical negation.
+  SimdInt operator!() const { return *this ^ Max(); }
+
+  SimdInt operator>>(uint32_t count) const;
+  SimdInt& operator>>=(uint32_t count) { return *this = *this >> count; }
+
+  SimdInt operator<<(uint32_t count) const;
+  SimdInt& operator<<=(uint32_t count) { return *this = *this << count; }
+
+  std::string ToString() const { return value_.ToString(); }
+  std::string ToHexString(bool pad_zero = false) const {
+    return value_.ToHexString(pad_zero);
+  }
+
+ private:
+  value_type value_;
+};
+
+// Declares the SimdInt<bits> alias and the explicit specializations that the
+// architecture-specific source files define.
+// clang-format off
+#define SPECIALIZE_SIMD_INT(bits)                                             \
+  using SimdInt##bits = SimdInt<bits>;                                        \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::Broadcast(uint8_t value);                      \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::Broadcast(uint16_t value);                     \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::Broadcast(uint32_t value);                     \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::Broadcast(uint64_t value);                     \
+                                                                              \
+  template <>                                                                 \
+  bool SimdInt##bits::operator==(const SimdInt##bits& value) const;           \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::operator&(const SimdInt##bits& value) const;   \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::operator|(const SimdInt##bits& value) const;   \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::operator^(const SimdInt##bits& value) const;   \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::operator>>(uint32_t count) const;              \
+                                                                              \
+  template <>                                                                 \
+  SimdInt##bits SimdInt##bits::operator<<(uint32_t count) const
+// clang-format on
+
+SPECIALIZE_SIMD_INT(128);
+#if ARCH_CPU_X86_64
+SPECIALIZE_SIMD_INT(256);
+#if defined(TACHYON_HAS_AVX512)
+SPECIALIZE_SIMD_INT(512);
+#endif
+#endif
+
+#undef SPECIALIZE_SIMD_INT
+
+}  // namespace tachyon::math
+
+#endif  // TACHYON_MATH_BASE_SIMD_INT_H_
diff --git a/tachyon/math/base/simd_int128_arm64.cc b/tachyon/math/base/simd_int128_arm64.cc
new file mode 100644
index 000000000..f4a8c73cd
--- /dev/null
+++ b/tachyon/math/base/simd_int128_arm64.cc
@@ -0,0 +1,110 @@
+// Copyright 2024 Ulvetanna Inc.
+// Use of this source code is governed by a Apache-2.0 style license that
+// can be found in the LICENSE.ulvetanna file.
+
+#include <arm_neon.h>
+
+#include <algorithm>
+
+#include "tachyon/base/bit_cast.h"
+#include "tachyon/base/compiler_specific.h"
+#include "tachyon/base/logging.h"
+#include "tachyon/math/base/simd_int.h"
+
+namespace tachyon::math {
+
+namespace {
+
+// Loads the 16 bytes of |value| into a NEON register.
+uint8x16_t ToVector(const SimdInt128& value) {
+  return vld1q_u8(reinterpret_cast<const uint8_t*>(&value));
+}
+
+// Each FromVector() overload stores |vector| into a new SimdInt128.
+SimdInt128 FromVector(uint8x16_t vector) {
+  SimdInt128 ret;
+  vst1q_u8(reinterpret_cast<uint8_t*>(&ret), vector);
+  return ret;
+}
+
+SimdInt128 FromVector(uint16x8_t vector) {
+  SimdInt128 ret;
+  vst1q_u16(reinterpret_cast<uint16_t*>(&ret), vector);
+  return ret;
+}
+
+SimdInt128 FromVector(uint32x4_t vector) {
+  SimdInt128 ret;
+  vst1q_u32(reinterpret_cast<uint32_t*>(&ret), vector);
+  return ret;
+}
+
+SimdInt128 FromVector(uint64x2_t vector) {
+  SimdInt128 ret;
+  vst1q_u64(reinterpret_cast<uint64_t*>(&ret), vector);
+  return ret;
+}
+
+}  // namespace
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint8_t value) {
+  return FromVector(vdupq_n_u8(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint16_t value) {
+  return FromVector(vdupq_n_u16(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint32_t value) {
+  return FromVector(vdupq_n_u32(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint64_t value) {
+  return FromVector(vdupq_n_u64(value));
+}
+
+template <>
+bool SimdInt128::operator==(const SimdInt128& other) const {
+  return value_ == other.value_;
+}
+
+template <>
+SimdInt128 SimdInt128::operator&(const SimdInt128& other) const {
+  return FromVector(vandq_u8(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator|(const SimdInt128& other) const {
+  return FromVector(vorrq_u8(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator^(const SimdInt128& other) const {
+  return FromVector(veorq_u8(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator>>(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 128)) return SimdInt128();
+  // Delegates to the BigInt shift rather than using NEON intrinsics.
+  return SimdInt128(value_ >> count);
+}
+
+template <>
+SimdInt128 SimdInt128::operator<<(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 128)) return SimdInt128();
+  // Delegates to the BigInt shift rather than using NEON intrinsics.
+  return SimdInt128(value_ << count);
+}
+
+}  // namespace tachyon::math
diff --git a/tachyon/math/base/simd_int128_x86.cc b/tachyon/math/base/simd_int128_x86.cc
new file mode 100644
index 000000000..597354dd6
--- /dev/null
+++ b/tachyon/math/base/simd_int128_x86.cc
@@ -0,0 +1,123 @@
+// Copyright 2024 Ulvetanna Inc.
+// Use of this source code is governed by a Apache-2.0 style license that
+// can be found in the LICENSE.ulvetanna file.
+
+#include <immintrin.h>
+
+#include <algorithm>
+
+#include "tachyon/base/bit_cast.h"
+#include "tachyon/base/compiler_specific.h"
+#include "tachyon/base/logging.h"
+#include "tachyon/math/base/simd_int.h"
+
+namespace tachyon::math {
+
+namespace {
+
+// Loads the 16 bytes of |value| into an SSE register (unaligned load).
+__m128i ToVector(const SimdInt128& value) {
+  return _mm_loadu_si128(reinterpret_cast<const __m128i*>(&value));
+}
+
+// Stores |vector| into a new SimdInt128 (unaligned store).
+SimdInt128 FromVector(__m128i vector) {
+  SimdInt128 ret;
+  _mm_storeu_si128(reinterpret_cast<__m128i*>(&ret), vector);
+  return ret;
+}
+
+// Places |count| in the low 64 bits of a vector, which is the operand form
+// taken by the register-count shifts _mm_srl_epi64() / _mm_sll_epi64().
+__m128i ToShiftCount(uint32_t count) {
+  return _mm_cvtsi32_si128(static_cast<int>(count));
+}
+
+}  // namespace
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint8_t value) {
+  return FromVector(_mm_set1_epi8(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint16_t value) {
+  return FromVector(_mm_set1_epi16(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint32_t value) {
+  return FromVector(_mm_set1_epi32(value));
+}
+
+// static
+template <>
+SimdInt128 SimdInt128::Broadcast(uint64_t value) {
+  // _mm_set1_epi64x() takes the scalar directly and avoids the MMX __m64 type.
+  return FromVector(_mm_set1_epi64x(static_cast<int64_t>(value)));
+}
+
+template <>
+bool SimdInt128::operator==(const SimdInt128& other) const {
+  // The XOR is all-zero iff both operands are bitwise identical.
+  __m128i neq = _mm_xor_si128(ToVector(*this), ToVector(other));
+  return _mm_test_all_zeros(neq, neq) == 1;
+}
+
+template <>
+SimdInt128 SimdInt128::operator&(const SimdInt128& other) const {
+  return FromVector(_mm_and_si128(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator|(const SimdInt128& other) const {
+  return FromVector(_mm_or_si128(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator^(const SimdInt128& other) const {
+  return FromVector(_mm_xor_si128(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt128 SimdInt128::operator>>(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 128)) return SimdInt128();
+  // See
+  // https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
+  const __m128i vec = ToVector(*this);
+  // Bring the upper 64-bit lane down; the upper lane becomes zero.
+  __m128i carry = _mm_bsrli_si128(vec, 8);
+  if (count >= 64) {
+    // Only bits that started in the upper lane can remain.
+    return FromVector(_mm_srl_epi64(carry, ToShiftCount(count - 64)));
+  }
+  // Bits that cross from the upper lane into the lower lane.
+  carry = _mm_sll_epi64(carry, ToShiftCount(64 - count));
+  const __m128i val = _mm_srl_epi64(vec, ToShiftCount(count));
+  return FromVector(_mm_or_si128(val, carry));
+}
+
+template <>
+SimdInt128 SimdInt128::operator<<(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 128)) return SimdInt128();
+  // See
+  // https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
+  const __m128i vec = ToVector(*this);
+  // Move the lower 64-bit lane up; the lower lane becomes zero.
+  __m128i carry = _mm_bslli_si128(vec, 8);
+  if (count >= 64) {
+    // Only bits that started in the lower lane can remain.
+    return FromVector(_mm_sll_epi64(carry, ToShiftCount(count - 64)));
+  }
+  // Bits that cross from the lower lane into the upper lane.
+  carry = _mm_srl_epi64(carry, ToShiftCount(64 - count));
+  const __m128i val = _mm_sll_epi64(vec, ToShiftCount(count));
+  return FromVector(_mm_or_si128(val, carry));
+}
+
+}  // namespace tachyon::math
diff --git a/tachyon/math/base/simd_int256_x86.cc b/tachyon/math/base/simd_int256_x86.cc
new file mode 100644
index 000000000..794accaaf
--- /dev/null
+++ b/tachyon/math/base/simd_int256_x86.cc
@@ -0,0 +1,90 @@
+// Copyright 2024 Ulvetanna Inc.
+// Use of this source code is governed by a Apache-2.0 style license that
+// can be found in the LICENSE.ulvetanna file.
+
+#include <immintrin.h>
+
+#include "tachyon/base/bit_cast.h"
+#include "tachyon/base/compiler_specific.h"
+#include "tachyon/math/base/simd_int.h"
+
+namespace tachyon::math {
+
+namespace {
+
+__m256i ToVector(const SimdInt256& value) {
+  return base::bit_cast<__m256i>(value);
+}
+
+SimdInt256 FromVector(__m256i vector) {
+  SimdInt256 ret;
+  _mm256_storeu_si256(reinterpret_cast<__m256i*>(&ret), vector);
+  return ret;
+}
+
+}  // namespace
+
+// static
+template <>
+SimdInt256 SimdInt256::Broadcast(uint8_t value) {
+  return FromVector(_mm256_set1_epi8(value));
+}
+
+// static
+template <>
+SimdInt256 SimdInt256::Broadcast(uint16_t value) {
+  return FromVector(_mm256_set1_epi16(value));
+}
+
+// static
+template <>
+SimdInt256 SimdInt256::Broadcast(uint32_t value) {
+  return FromVector(_mm256_set1_epi32(value));
+}
+
+// static
+template <>
+SimdInt256 SimdInt256::Broadcast(uint64_t value) {
+  return FromVector(_mm256_set1_epi64x(value));
+}
+
+template <>
+bool SimdInt256::operator==(const SimdInt256& other) const {
+  __m256i pcmp = _mm256_cmpeq_epi32(ToVector(*this), ToVector(other));
+  unsigned int bitmask =
+      base::bit_cast<unsigned int>(_mm256_movemask_epi8(pcmp));
+  return bitmask == 0xffffffff;  // All 32 byte lanes compared equal.
+}
+
+template <>
+SimdInt256 SimdInt256::operator&(const SimdInt256& other) const {
+  return FromVector(_mm256_and_si256(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt256 SimdInt256::operator|(const SimdInt256& other) const {
+  return FromVector(_mm256_or_si256(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt256 SimdInt256::operator^(const SimdInt256& other) const {
+  return FromVector(_mm256_xor_si256(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt256 SimdInt256::operator>>(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 256)) return SimdInt256();
+  // TODO(chokobole): Optimize this.
+  return SimdInt256(value_ >> count);
+}
+
+template <>
+SimdInt256 SimdInt256::operator<<(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 256)) return SimdInt256();
+  // TODO(chokobole): Optimize this.
+  return SimdInt256(value_ << count);
+}
+
+}  // namespace tachyon::math
diff --git a/tachyon/math/base/simd_int512_x86.cc b/tachyon/math/base/simd_int512_x86.cc
new file mode 100644
index 000000000..06a478870
--- /dev/null
+++ b/tachyon/math/base/simd_int512_x86.cc
@@ -0,0 +1,86 @@
+// Copyright 2024 Ulvetanna Inc.
+// Use of this source code is governed by a Apache-2.0 style license that
+// can be found in the LICENSE.ulvetanna file.
+
+#include <immintrin.h>
+
+#include "tachyon/base/bit_cast.h"
+#include "tachyon/base/compiler_specific.h"
+#include "tachyon/math/base/simd_int.h"
+
+namespace tachyon::math {
+
+namespace {
+
+__m512i ToVector(const SimdInt512& value) { return _mm512_loadu_si512(&value); }
+
+SimdInt512 FromVector(__m512i vector) {
+  SimdInt512 ret;
+  _mm512_storeu_si512(&ret, vector);
+  return ret;
+}
+
+}  // namespace
+
+// static
+template <>
+SimdInt512 SimdInt512::Broadcast(uint8_t value) {
+  return FromVector(_mm512_set1_epi8(value));
+}
+
+// static
+template <>
+SimdInt512 SimdInt512::Broadcast(uint16_t value) {
+  return FromVector(_mm512_set1_epi16(value));
+}
+
+// static
+template <>
+SimdInt512 SimdInt512::Broadcast(uint32_t value) {
+  return FromVector(_mm512_set1_epi32(value));
+}
+
+// static
+template <>
+SimdInt512 SimdInt512::Broadcast(uint64_t value) {
+  return FromVector(_mm512_set1_epi64(value));
+}
+
+template <>
+bool SimdInt512::operator==(const SimdInt512& other) const {
+  __mmask16 pcmp = _mm512_cmpeq_epi32_mask(ToVector(*this), ToVector(other));
+  return pcmp == 0xffff;  // All 16 32-bit lanes compared equal.
+}
+
+template <>
+SimdInt512 SimdInt512::operator&(const SimdInt512& other) const {
+  return FromVector(_mm512_and_si512(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt512 SimdInt512::operator|(const SimdInt512& other) const {
+  return FromVector(_mm512_or_si512(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt512 SimdInt512::operator^(const SimdInt512& other) const {
+  return FromVector(_mm512_xor_si512(ToVector(*this), ToVector(other)));
+}
+
+template <>
+SimdInt512 SimdInt512::operator>>(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 512)) return SimdInt512();
+  // TODO(chokobole): Optimize this.
+  return SimdInt512(value_ >> count);
+}
+
+template <>
+SimdInt512 SimdInt512::operator<<(uint32_t count) const {
+  if (UNLIKELY(count == 0)) return *this;
+  if (UNLIKELY(count >= 512)) return SimdInt512();
+  // TODO(chokobole): Optimize this.
+  return SimdInt512(value_ << count);
+}
+
+}  // namespace tachyon::math
diff --git a/tachyon/math/base/simd_int_unittest.cc b/tachyon/math/base/simd_int_unittest.cc
new file mode 100644
index 000000000..758080399
--- /dev/null
+++ b/tachyon/math/base/simd_int_unittest.cc
@@ -0,0 +1,102 @@
+#include "tachyon/math/base/simd_int.h"
+
+#include "gtest/gtest.h"
+
+namespace tachyon::math {
+
+template <typename SimdInt>
+class SimdIntTest : public testing::Test {};
+
+using SimdIntTypes = testing::Types<SimdInt128
+#if ARCH_CPU_X86_64
+                                    ,
+                                    SimdInt256
+#if defined(TACHYON_HAS_AVX512)
+                                    ,
+                                    SimdInt512
+#endif
+#endif
+                                    >;
+TYPED_TEST_SUITE(SimdIntTest, SimdIntTypes);
+
+TYPED_TEST(SimdIntTest, Zero) {
+  using SimdInt = TypeParam;
+
+  EXPECT_TRUE(SimdInt().IsZero());
+  EXPECT_TRUE(SimdInt::Zero().IsZero());
+  EXPECT_FALSE(SimdInt::One().IsZero());
+  EXPECT_FALSE(SimdInt::Max().IsZero());
+}
+
+TYPED_TEST(SimdIntTest, One) {
+  using SimdInt = TypeParam;
+
+  EXPECT_FALSE(SimdInt().IsOne());
+  EXPECT_FALSE(SimdInt::Zero().IsOne());
+  EXPECT_TRUE(SimdInt::One().IsOne());
+  EXPECT_FALSE(SimdInt::Max().IsOne());
+}
+
+TYPED_TEST(SimdIntTest, Max) {
+  using SimdInt = TypeParam;
+
+  EXPECT_FALSE(SimdInt().IsMax());
+  EXPECT_FALSE(SimdInt::Zero().IsMax());
+  EXPECT_FALSE(SimdInt::One().IsMax());
+  EXPECT_TRUE(SimdInt::Max().IsMax());
+}
+
+template <typename SimdInt, typename T>
+void TestBroadcast() {
+  using BigInt = typename SimdInt::value_type;
+
+  T v = base::Uniform(base::Range<T>());
+  BigInt expected;
+  for (size_t i = 0; i < BigInt::kByteNums / sizeof(T); ++i) {
+    reinterpret_cast<T*>(&expected)[i] = v;
+  }
+  BigInt actual = SimdInt::Broadcast(v).value();
+  EXPECT_EQ(actual, expected);
+}
+
+TYPED_TEST(SimdIntTest, Broadcast) {
+  using SimdInt = TypeParam;
+
+  TestBroadcast<SimdInt, uint8_t>();
+  TestBroadcast<SimdInt, uint16_t>();
+  TestBroadcast<SimdInt, uint32_t>();
+  TestBroadcast<SimdInt, uint64_t>();
+}
+
+TYPED_TEST(SimdIntTest, EqualityOperations) {
+  using SimdInt = TypeParam;
+  using BigInt = typename SimdInt::value_type;
+
+  SimdInt a = SimdInt::Random();
+  SimdInt b = SimdInt(a.value() + BigInt(1));
+  EXPECT_EQ(a, a);
+  EXPECT_NE(a, b);
+}
+
+TYPED_TEST(SimdIntTest, BitOperations) {
+  using SimdInt = TypeParam;
+  using BigInt = typename SimdInt::value_type;
+
+  SimdInt a = SimdInt::Random();
+  SimdInt not_a = !a;
+  EXPECT_EQ(a & a, a);
+  EXPECT_TRUE((a & not_a).IsZero());
+  EXPECT_EQ(a | a, a);
+  EXPECT_TRUE((a | not_a).IsMax());
+  EXPECT_TRUE((a ^ a).IsZero());
+  EXPECT_TRUE((a ^ not_a).IsMax());
+
+  size_t count = base::Uniform(base::Range<size_t>(0, SimdInt::kBits));
+  BigInt expected = a.value() << count;
+  EXPECT_EQ((a << count).value(), expected);
+
+  expected = a.value() >> count;
+  EXPECT_EQ((a >> count).value(), expected);
+}
+
+}  // namespace tachyon::math
diff --git a/tachyon/math/finite_fields/baby_bear/packed_baby_bear_avx512.cc b/tachyon/math/finite_fields/baby_bear/packed_baby_bear_avx512.cc
index 32be1f197..296096bdd 100644
--- a/tachyon/math/finite_fields/baby_bear/packed_baby_bear_avx512.cc
+++ b/tachyon/math/finite_fields/baby_bear/packed_baby_bear_avx512.cc
@@ -19,14 +19,12 @@ __m512i kZero;
 __m512i kOne;
 
 __m512i ToVector(const PackedBabyBearAVX512& packed) {
-  return _mm512_loadu_si512(
-      reinterpret_cast<const __m512i_u*>(packed.values().data()));
+  return _mm512_loadu_si512(packed.values().data());
 }
 
 PackedBabyBearAVX512 FromVector(__m512i vector) {
   PackedBabyBearAVX512 ret;
-  _mm512_storeu_si512(reinterpret_cast<__m512i_u*>(ret.values().data()),
-                      vector);
+  _mm512_storeu_si512(ret.values().data(), vector);
   return ret;
 }
 
diff --git a/tachyon/math/finite_fields/goldilocks/BUILD.bazel b/tachyon/math/finite_fields/goldilocks/BUILD.bazel
index e6872cfcb..2829d55f6 100644
--- a/tachyon/math/finite_fields/goldilocks/BUILD.bazel
+++ b/tachyon/math/finite_fields/goldilocks/BUILD.bazel
@@ -1,5 +1,5 @@
 load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
-load("//bazel:tachyon.bzl", "if_x86_64")
+load("//bazel:tachyon.bzl", "if_has_avx512", "if_x86_64")
 load("//bazel:tachyon_cc.bzl", "tachyon_asm_prime_field_defines", "tachyon_cc_library", "tachyon_cc_unittest")
 load(
     "//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl",
@@ -44,7 +44,11 @@ tachyon_cc_library(
     name = "goldilocks_prime_field_x86_special",
     srcs = if_x86_64(["goldilocks_prime_field_x86_special.cc"]),
     hdrs = if_x86_64(["goldilocks_prime_field_x86_special.h"]),
-    copts = if_x86_64(["-mavx2"]),
+    copts = if_x86_64([
+        "-mavx2",
+    ]) + if_has_avx512([
+        "-mavx512f",
+    ]),
     defines = tachyon_asm_prime_field_defines(),
     deps = if_x86_64([
         ":goldilocks_config",
diff --git a/tachyon/math/finite_fields/koala_bear/packed_koala_bear_avx512.cc b/tachyon/math/finite_fields/koala_bear/packed_koala_bear_avx512.cc
index 71144799a..fab1914da 100644
--- a/tachyon/math/finite_fields/koala_bear/packed_koala_bear_avx512.cc
+++ b/tachyon/math/finite_fields/koala_bear/packed_koala_bear_avx512.cc
@@ -19,14 +19,12 @@ __m512i kZero;
 __m512i kOne;
 
 __m512i ToVector(const PackedKoalaBearAVX512& packed) {
-  return _mm512_loadu_si512(
-      reinterpret_cast<const __m512i_u*>(packed.values().data()));
+  return _mm512_loadu_si512(packed.values().data());
 }
 
 PackedKoalaBearAVX512 FromVector(__m512i vector) {
   PackedKoalaBearAVX512 ret;
-  _mm512_storeu_si512(reinterpret_cast<__m512i_u*>(ret.values().data()),
-                      vector);
+  _mm512_storeu_si512(ret.values().data(), vector);
   return ret;
 }
 
diff --git a/tachyon/math/finite_fields/mersenne31/packed_mersenne31_avx512.cc b/tachyon/math/finite_fields/mersenne31/packed_mersenne31_avx512.cc
index 316d5394e..b812aa425 100644
--- a/tachyon/math/finite_fields/mersenne31/packed_mersenne31_avx512.cc
+++ b/tachyon/math/finite_fields/mersenne31/packed_mersenne31_avx512.cc
@@ -21,14 +21,12 @@ __mmask16 kEvens = 0b0101010101010101;
 __mmask16 kOdds = 0b1010101010101010;
 
 __m512i ToVector(const PackedMersenne31AVX512& packed) {
-  return _mm512_loadu_si512(
-      reinterpret_cast<const __m512i_u*>(packed.values().data()));
+  return _mm512_loadu_si512(packed.values().data());
 }
 
 PackedMersenne31AVX512 FromVector(__m512i vector) {
   PackedMersenne31AVX512 ret;
-  _mm512_storeu_si512(reinterpret_cast<__m512i_u*>(ret.values().data()),
-                      vector);
+  _mm512_storeu_si512(ret.values().data(), vector);
   return ret;
 }
 