Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(math): add simd int #432

Merged
merged 7 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ IncludeBlocks: Regroup
IncludeCategories:
# C system headers. The header_dependency_test.py contains a copy of this
# list; be sure to update that test anytime this list changes.
- Regex: '^[<"](aio|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$'
- Regex: '^[<"](aio|arm_neon|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|immintrin|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$'
Priority: 20
# C++ system headers (as of C++23). The header_dependency_test.py contains a
# copy of this list; be sure to update that test anytime this list changes.
Expand Down
31 changes: 31 additions & 0 deletions tachyon/math/base/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load(
"//bazel:tachyon_cc.bzl",
"tachyon_avx512_defines",
"tachyon_cc_benchmark",
"tachyon_cc_library",
"tachyon_cc_unittest",
Expand Down Expand Up @@ -120,6 +122,33 @@ tachyon_cc_library(
],
)

tachyon_cc_library(
name = "simd_int",
srcs = if_x86_64([
"simd_int128_x86.cc",
"simd_int256_x86.cc",
]) + if_has_avx512([
"simd_int512_x86.cc",
]) + if_aarch64([
"simd_int128_arm64.cc",
]),
hdrs = ["simd_int.h"],
copts = if_x86_64([
"-mavx2",
]) + if_has_avx512([
"-mavx512f",
]),
defines = tachyon_avx512_defines(),
deps = [
"//tachyon:export",
"//tachyon/base:bit_cast",
"//tachyon/base:compiler_specific",
"//tachyon/base:logging",
"//tachyon/build:build_config",
"//tachyon/math/base:big_int",
],
)

tachyon_cc_unittest(
name = "base_unittests",
srcs = [
Expand All @@ -132,6 +161,7 @@ tachyon_cc_unittest(
"rational_field_unittest.cc",
"semigroups_unittest.cc",
"sign_unittest.cc",
"simd_int_unittest.cc",
],
deps = [
":big_int",
Expand All @@ -140,6 +170,7 @@ tachyon_cc_unittest(
":groups",
":rational_field",
":sign",
":simd_int",
"//tachyon/base:optional",
"//tachyon/base/buffer:vector_buffer",
"//tachyon/base/containers:container_util",
Expand Down
22 changes: 22 additions & 0 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,28 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return *this;
}

constexpr BigInt operator<<(uint32_t count) const {
BigInt ret = *this;
ret.MulBy2ExpInPlace(count);
return ret;
}

constexpr BigInt& operator<<=(uint32_t count) {
MulBy2ExpInPlace(count);
return *this;
}

constexpr BigInt operator>>(uint32_t count) const {
BigInt ret = *this;
ret.DivBy2ExpInPlace(count);
return ret;
}

constexpr BigInt& operator>>=(uint32_t count) {
DivBy2ExpInPlace(count);
return *this;
}

constexpr BigInt Add(const BigInt& other) const {
uint64_t unused = 0;
return Add(other, unused);
Expand Down
125 changes: 125 additions & 0 deletions tachyon/math/base/simd_int.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright 2024 Ulvetanna Inc.
// Use of this source code is governed by a Apache-2.0 style license that
// can be found in the LICENSE.ulvetanna file.

#ifndef TACHYON_MATH_BASE_SIMD_INT_H_
#define TACHYON_MATH_BASE_SIMD_INT_H_

#include <stddef.h>
#include <stdint.h>

#include <string>
#include <type_traits>

#include "tachyon/build/build_config.h"
#include "tachyon/export.h"
#include "tachyon/math/base/big_int.h"

namespace tachyon::math {

template <size_t Bits>
class SimdInt {
public:
constexpr static size_t kBits = Bits;
constexpr static size_t kLimbNums = Bits / 64;

using value_type = BigInt<kLimbNums>;

SimdInt() = default;
template <typename T,
std::enable_if_t<std::is_constructible_v<value_type, T>>* = nullptr>
explicit SimdInt(T value) : SimdInt(value_type(value)) {}
explicit SimdInt(const value_type& value) : value_(value) {}

static SimdInt Zero() { return SimdInt(); }
static SimdInt One() { return SimdInt(1); }
static SimdInt Max() { return SimdInt(value_type::Max()); }
static SimdInt Broadcast(uint8_t value);
static SimdInt Broadcast(uint16_t value);
static SimdInt Broadcast(uint32_t value);
static SimdInt Broadcast(uint64_t value);
static SimdInt Random() { return SimdInt(value_type::Random()); }

const value_type& value() const { return value_; }

bool IsZero() const { return *this == Zero(); }
bool IsOne() const { return *this == One(); }
bool IsMax() const { return *this == Max(); }

bool operator==(const SimdInt& other) const;
bool operator!=(const SimdInt& other) const { return !operator==(other); }

SimdInt operator&(const SimdInt& other) const;
SimdInt& operator&=(const SimdInt& other) { return *this = *this & other; }

SimdInt operator|(const SimdInt& other) const;
SimdInt& operator|=(const SimdInt& other) { return *this = *this | other; }

SimdInt operator^(const SimdInt& other) const;
SimdInt& operator^=(const SimdInt& other) { return *this = *this ^ other; }

SimdInt operator!() const { return *this ^ Max(); }

SimdInt operator>>(uint32_t count) const;
SimdInt& operator>>=(uint32_t count) { return *this = *this >> count; }

SimdInt operator<<(uint32_t count) const;
SimdInt& operator<<=(uint32_t count) { return *this = *this << count; }

std::string ToString() const { return value_.ToString(); }
std::string ToHexString(bool pad_zero = false) const {
return value_.ToHexString(pad_zero);
}

private:
value_type value_;
};

// clang-format off
#define SPECIALIZE_SIMD_INT(bits) \
using SimdInt##bits = SimdInt<bits>; \
\
template <> \
SimdInt##bits SimdInt##bits::Broadcast(uint8_t value); \
\
template <> \
SimdInt##bits SimdInt##bits::Broadcast(uint16_t value); \
\
template <> \
SimdInt##bits SimdInt##bits::Broadcast(uint32_t value); \
\
template <> \
SimdInt##bits SimdInt##bits::Broadcast(uint64_t value); \
\
template <> \
bool SimdInt##bits::operator==(const SimdInt##bits& value) const; \
\
template <> \
SimdInt##bits SimdInt##bits::operator&(const SimdInt##bits& value) const; \
\
template <> \
SimdInt##bits SimdInt##bits::operator|(const SimdInt##bits& value) const; \
\
template <> \
SimdInt##bits SimdInt##bits::operator^(const SimdInt##bits& value) const; \
\
template <> \
SimdInt##bits SimdInt##bits::operator>>(uint32_t count) const; \
\
template <> \
SimdInt##bits SimdInt##bits::operator<<(uint32_t count) const
// clang-format on

SPECIALIZE_SIMD_INT(128);
#if ARCH_CPU_X86_64
SPECIALIZE_SIMD_INT(256);
#if defined(TACHYON_HAS_AVX512)
SPECIALIZE_SIMD_INT(512);
#endif
#endif

#undef SPECIALIZE_SIMD_INT

} // namespace tachyon::math

#endif // TACHYON_MATH_BASE_SIMD_INT_H_
106 changes: 106 additions & 0 deletions tachyon/math/base/simd_int128_arm64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024 Ulvetanna Inc.
// Use of this source code is governed by a Apache-2.0 style license that
// can be found in the LICENSE.ulvetanna file.

#include <arm_neon.h>

#include <algorithm>

#include "tachyon/base/bit_cast.h"
#include "tachyon/base/compiler_specific.h"
#include "tachyon/base/logging.h"
#include "tachyon/math/base/simd_int.h"

namespace tachyon::math {

namespace {

uint8x16_t ToVector(const SimdInt128& value) {
return vld1q_u8(reinterpret_cast<const uint8_t*>(&value));
}

SimdInt128 FromVector(uint8x16_t vector) {
SimdInt128 ret;
vst1q_u8(reinterpret_cast<uint8_t*>(&ret), vector);
return ret;
}

SimdInt128 FromVector(uint16x8_t vector) {
SimdInt128 ret;
vst1q_u16(reinterpret_cast<uint16_t*>(&ret), vector);
return ret;
}

SimdInt128 FromVector(uint32x4_t vector) {
SimdInt128 ret;
vst1q_u32(reinterpret_cast<uint32_t*>(&ret), vector);
return ret;
}

SimdInt128 FromVector(uint64x2_t vector) {
SimdInt128 ret;
vst1q_u64(reinterpret_cast<uint64_t*>(&ret), vector);
return ret;
}

} // namespace

// static
template <>
SimdInt128 SimdInt128::Broadcast(uint8_t value) {
return FromVector(vdupq_n_u8(value));
}

// static
template <>
SimdInt128 SimdInt128::Broadcast(uint16_t value) {
return FromVector(vdupq_n_u16(value));
}

// static
template <>
SimdInt128 SimdInt128::Broadcast(uint32_t value) {
return FromVector(vdupq_n_u32(value));
}

// static
template <>
SimdInt128 SimdInt128::Broadcast(uint64_t value) {
return FromVector(vdupq_n_u64(value));
}

template <>
bool SimdInt128::operator==(const SimdInt128& other) const {
return value_ == other.value_;
}

template <>
SimdInt128 SimdInt128::operator&(const SimdInt128& other) const {
return FromVector(vandq_u8(ToVector(*this), ToVector(other)));
}

template <>
SimdInt128 SimdInt128::operator|(const SimdInt128& other) const {
return FromVector(vorrq_u8(ToVector(*this), ToVector(other)));
}

template <>
SimdInt128 SimdInt128::operator^(const SimdInt128& other) const {
return FromVector(veorq_u8(ToVector(*this), ToVector(other)));
}

template <>
SimdInt128 SimdInt128::operator>>(uint32_t count) const {
if (UNLIKELY(count == 0)) return *this;
if (UNLIKELY(count >= 128)) return SimdInt128();
return SimdInt128(value_ >> count);
}

template <>
SimdInt128 SimdInt128::operator<<(uint32_t count) const {
if (UNLIKELY(count == 0)) return *this;
if (UNLIKELY(count >= 128)) return SimdInt128();
return SimdInt128(value_ << count);
}

} // namespace tachyon::math
Loading
Loading