Skip to content

Commit

Permalink
feat(math): add SimdInt<512>
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed Jun 11, 2024
1 parent a2e7c2d commit 84f7475
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tachyon/math/base/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("//bazel:tachyon.bzl", "if_aarch64", "if_x86_64")
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load(
"//bazel:tachyon_cc.bzl",
"tachyon_avx512_defines",
"tachyon_cc_benchmark",
"tachyon_cc_library",
"tachyon_cc_unittest",
Expand Down Expand Up @@ -126,13 +127,18 @@ tachyon_cc_library(
srcs = if_x86_64([
"simd_int128_x86.cc",
"simd_int256_x86.cc",
]) + if_has_avx512([
"simd_int512_x86.cc",
]) + if_aarch64([
"simd_int128_arm64.cc",
]),
hdrs = ["simd_int.h"],
copts = if_x86_64([
"-mavx2",
]) + if_has_avx512([
"-mavx512f",
]),
defines = tachyon_avx512_defines(),
deps = [
"//tachyon:export",
"//tachyon/base:bit_cast",
Expand Down
3 changes: 3 additions & 0 deletions tachyon/math/base/simd_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class SimdInt {
SPECIALIZE_SIMD_INT(128);
#if ARCH_CPU_X86_64
SPECIALIZE_SIMD_INT(256);
#if defined(TACHYON_HAS_AVX512)
SPECIALIZE_SIMD_INT(512);
#endif
#endif

#undef SPECIALIZE_SIMD_INT
Expand Down
86 changes: 86 additions & 0 deletions tachyon/math/base/simd_int512_x86.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2024 Ulvetanna Inc.
// Use of this source code is governed by a Apache-2.0 style license that
// can be found in the LICENSE.ulvetanna file.

#include <immintrin.h>

#include "tachyon/base/bit_cast.h"
#include "tachyon/base/compiler_specific.h"
#include "tachyon/math/base/simd_int.h"

namespace tachyon::math {

namespace {

__m512i ToVector(const SimdInt512& value) { return _mm512_loadu_si512(&value); }

SimdInt512 FromVector(__m512i vector) {
SimdInt512 ret;
_mm512_storeu_si512(&ret, vector);
return ret;
}

} // namespace

// static
template <>
SimdInt512 SimdInt512::Broadcast(uint8_t value) {
return FromVector(_mm512_set1_epi8(value));
}

// static
template <>
SimdInt512 SimdInt512::Broadcast(uint16_t value) {
return FromVector(_mm512_set1_epi16(value));
}

// static
template <>
SimdInt512 SimdInt512::Broadcast(uint32_t value) {
return FromVector(_mm512_set1_epi32(value));
}

// static
template <>
SimdInt512 SimdInt512::Broadcast(uint64_t value) {
return FromVector(_mm512_set1_epi64(value));
}

template <>
bool SimdInt512::operator==(const SimdInt512& other) const {
__mmask16 pcmp = _mm512_cmpeq_epi32_mask(ToVector(*this), ToVector(other));
return pcmp == 0xffff;
}

template <>
SimdInt512 SimdInt512::operator&(const SimdInt512& other) const {
return FromVector(_mm512_and_si512(ToVector(*this), ToVector(other)));
}

template <>
SimdInt512 SimdInt512::operator|(const SimdInt512& other) const {
return FromVector(_mm512_or_si512(ToVector(*this), ToVector(other)));
}

template <>
SimdInt512 SimdInt512::operator^(const SimdInt512& other) const {
return FromVector(_mm512_xor_si512(ToVector(*this), ToVector(other)));
}

template <>
SimdInt512 SimdInt512::operator>>(uint32_t count) const {
if (UNLIKELY(count == 0)) return *this;
if (UNLIKELY(count >= 512)) return SimdInt512();
// TODO(chokobole): Optimize this.
return SimdInt512(value_ >> count);
}

template <>
SimdInt512 SimdInt512::operator<<(uint32_t count) const {
if (UNLIKELY(count == 0)) return *this;
if (UNLIKELY(count >= 512)) return SimdInt512();
// TODO(chokobole): Optimize this.
return SimdInt512(value_ << count);
}

} // namespace tachyon::math
4 changes: 4 additions & 0 deletions tachyon/math/base/simd_int_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ using SimdIntTypes = testing::Types<SimdInt128
#if ARCH_CPU_X86_64
,
SimdInt256
#if defined(TACHYON_HAS_AVX512)
,
SimdInt512
#endif
#endif
>;
TYPED_TEST_SUITE(SimdIntTest, SimdIntTypes);
Expand Down

0 comments on commit 84f7475

Please sign in to comment.