Skip to content

Commit

Permalink
Merge pull request #432 from kroma-network/feat/add-simd-int
Browse files Browse the repository at this point in the history
feat(math): add simd int
  • Loading branch information
chokobole authored Jun 11, 2024
2 parents 77dc287 + 84f7475 commit 8284250
Show file tree
Hide file tree
Showing 13 changed files with 698 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ IncludeBlocks: Regroup
IncludeCategories:
# C system headers. The header_dependency_test.py contains a copy of this
# list; be sure to update that test anytime this list changes.
- Regex: '^[<"](aio|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$'
- Regex: '^[<"](aio|arm_neon|arpa/inet|assert|complex|cpio|ctype|curses|dirent|dlfcn|errno|fcntl|fenv|float|fmtmsg|fnmatch|ftw|glob|grp|iconv|immintrin|inttypes|iso646|langinfo|libgen|limits|locale|math|monetary|mqueue|ndbm|netdb|net/if|netinet/in|netinet/tcp|nl_types|poll|pthread|pwd|regex|sched|search|semaphore|setjmp|signal|spawn|stdalign|stdarg|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|strings|stropts|sys/ioctl|sys/ipc|syslog|sys/mman|sys/msg|sys/resource|sys/select|sys/sem|sys/shm|sys/socket|sys/stat|sys/statvfs|sys/time|sys/times|sys/types|sys/uio|sys/un|sys/utsname|sys/wait|tar|term|termios|tgmath|threads|time|trace|uchar|ulimit|uncntrl|unistd|utime|utmpx|wchar|wctype|wordexp)\.h[">]$'
Priority: 20
# C++ system headers (as of C++23). The header_dependency_test.py contains a
# copy of this list; be sure to update that test anytime this list changes.
Expand Down
31 changes: 31 additions & 0 deletions tachyon/math/base/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load(
"//bazel:tachyon_cc.bzl",
"tachyon_avx512_defines",
"tachyon_cc_benchmark",
"tachyon_cc_library",
"tachyon_cc_unittest",
Expand Down Expand Up @@ -120,6 +122,33 @@ tachyon_cc_library(
],
)

# Fixed-width SIMD integer wrapper (simd_int.h). The implementation file set
# is chosen per target: x86-64 builds get the 128- and 256-bit sources, AVX-512
# capable builds additionally get the 512-bit source, and aarch64 builds get
# the NEON-based 128-bit source.
tachyon_cc_library(
    name = "simd_int",
    srcs = if_x86_64([
        "simd_int128_x86.cc",
        "simd_int256_x86.cc",
    ]) + if_has_avx512([
        "simd_int512_x86.cc",
    ]) + if_aarch64([
        "simd_int128_arm64.cc",
    ]),
    hdrs = ["simd_int.h"],
    # Instruction-set flags required by the x86 sources listed above.
    copts = if_x86_64([
        "-mavx2",
    ]) + if_has_avx512([
        "-mavx512f",
    ]),
    defines = tachyon_avx512_defines(),
    deps = [
        "//tachyon:export",
        "//tachyon/base:bit_cast",
        "//tachyon/base:compiler_specific",
        "//tachyon/base:logging",
        "//tachyon/build:build_config",
        "//tachyon/math/base:big_int",
    ],
)

tachyon_cc_unittest(
name = "base_unittests",
srcs = [
Expand All @@ -132,6 +161,7 @@ tachyon_cc_unittest(
"rational_field_unittest.cc",
"semigroups_unittest.cc",
"sign_unittest.cc",
"simd_int_unittest.cc",
],
deps = [
":big_int",
Expand All @@ -140,6 +170,7 @@ tachyon_cc_unittest(
":groups",
":rational_field",
":sign",
":simd_int",
"//tachyon/base:optional",
"//tachyon/base/buffer:vector_buffer",
"//tachyon/base/containers:container_util",
Expand Down
22 changes: 22 additions & 0 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,28 @@ struct ALIGNAS(internal::LimbsAlignment(N)) BigInt {
return *this;
}

// Returns a copy of this value with MulBy2ExpInPlace(|count|) applied; *this
// is not modified.
constexpr BigInt operator<<(uint32_t count) const {
  BigInt shifted(*this);
  shifted.MulBy2ExpInPlace(count);
  return shifted;
}

// Applies MulBy2ExpInPlace(|count|) to this value and returns a reference to
// it so that the operation can be chained.
constexpr BigInt& operator<<=(uint32_t count) {
  this->MulBy2ExpInPlace(count);
  return *this;
}

// Returns a copy of this value with DivBy2ExpInPlace(|count|) applied; *this
// is not modified.
constexpr BigInt operator>>(uint32_t count) const {
  BigInt shifted(*this);
  shifted.DivBy2ExpInPlace(count);
  return shifted;
}

// Applies DivBy2ExpInPlace(|count|) to this value and returns a reference to
// it so that the operation can be chained.
constexpr BigInt& operator>>=(uint32_t count) {
  this->DivBy2ExpInPlace(count);
  return *this;
}

constexpr BigInt Add(const BigInt& other) const {
uint64_t unused = 0;
return Add(other, unused);
Expand Down
125 changes: 125 additions & 0 deletions tachyon/math/base/simd_int.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright 2024 Ulvetanna Inc.
// Use of this source code is governed by a Apache-2.0 style license that
// can be found in the LICENSE.ulvetanna file.

#ifndef TACHYON_MATH_BASE_SIMD_INT_H_
#define TACHYON_MATH_BASE_SIMD_INT_H_

#include <stddef.h>
#include <stdint.h>

#include <string>
#include <type_traits>

#include "tachyon/build/build_config.h"
#include "tachyon/export.h"
#include "tachyon/math/base/big_int.h"

namespace tachyon::math {

// A |Bits|-wide unsigned integer stored as a BigInt with Bits / 64 limbs.
//
// Only the trivial members are defined inline. Broadcast(), operator==, the
// bitwise operators and the shifts are declared here and defined as explicit
// specializations in the per-architecture source files.
template <size_t Bits>
class SimdInt {
 public:
  // Bits / 64 below would silently truncate for any other width.
  static_assert(Bits > 0 && Bits % 64 == 0,
                "SimdInt width must be a positive multiple of 64 bits");

  constexpr static size_t kBits = Bits;
  constexpr static size_t kLimbNums = Bits / 64;

  using value_type = BigInt<kLimbNums>;

  SimdInt() = default;
  // Builds the value from anything |value_type| itself can be constructed
  // from (e.g. an integer literal).
  template <typename T,
            std::enable_if_t<std::is_constructible_v<value_type, T>>* = nullptr>
  explicit SimdInt(T value) : SimdInt(value_type(value)) {}
  explicit SimdInt(const value_type& value) : value_(value) {}

  // Returns a default-constructed SimdInt.
  // NOTE(review): this is zero only if BigInt's default constructor
  // zero-initializes its limbs -- confirm against big_int.h.
  static SimdInt Zero() { return SimdInt(); }
  static SimdInt One() { return SimdInt(1); }
  static SimdInt Max() { return SimdInt(value_type::Max()); }
  // Broadcast() is overloaded on the element width of |value|; each overload
  // is implemented per architecture.
  static SimdInt Broadcast(uint8_t value);
  static SimdInt Broadcast(uint16_t value);
  static SimdInt Broadcast(uint32_t value);
  static SimdInt Broadcast(uint64_t value);
  static SimdInt Random() { return SimdInt(value_type::Random()); }

  const value_type& value() const { return value_; }

  bool IsZero() const { return *this == Zero(); }
  bool IsOne() const { return *this == One(); }
  bool IsMax() const { return *this == Max(); }

  bool operator==(const SimdInt& other) const;
  bool operator!=(const SimdInt& other) const { return !operator==(other); }

  SimdInt operator&(const SimdInt& other) const;
  SimdInt& operator&=(const SimdInt& other) { return *this = *this & other; }

  SimdInt operator|(const SimdInt& other) const;
  SimdInt& operator|=(const SimdInt& other) { return *this = *this | other; }

  SimdInt operator^(const SimdInt& other) const;
  SimdInt& operator^=(const SimdInt& other) { return *this = *this ^ other; }

  // Returns this value XOR-ed with Max(), i.e. the bitwise complement
  // (assuming value_type::Max() has every bit set).
  SimdInt operator~() const { return *this ^ Max(); }

  // Same result as operator~(). Kept so that existing callers continue to
  // compile; prefer operator~() in new code, since this overload returns a
  // SimdInt rather than a bool and is therefore not a logical negation.
  SimdInt operator!() const { return ~*this; }

  SimdInt operator>>(uint32_t count) const;
  SimdInt& operator>>=(uint32_t count) { return *this = *this >> count; }

  SimdInt operator<<(uint32_t count) const;
  SimdInt& operator<<=(uint32_t count) { return *this = *this << count; }

  std::string ToString() const { return value_.ToString(); }
  std::string ToHexString(bool pad_zero = false) const {
    return value_.ToHexString(pad_zero);
  }

 private:
  value_type value_;
};

// Introduces the alias SimdInt<bits> -> SimdInt##bits and declares the explicit
// specializations of every SimdInt member that has no inline definition in the
// class template. The invocation site supplies the trailing semicolon.
// clang-format off
#define SPECIALIZE_SIMD_INT(bits)                                            \
  using SimdInt##bits = SimdInt<bits>;                                       \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::Broadcast(uint8_t value);                     \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::Broadcast(uint16_t value);                    \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::Broadcast(uint32_t value);                    \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::Broadcast(uint64_t value);                    \
                                                                             \
  template <>                                                                \
  bool SimdInt##bits::operator==(const SimdInt##bits& other) const;          \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::operator&(const SimdInt##bits& other) const;  \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::operator|(const SimdInt##bits& other) const;  \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::operator^(const SimdInt##bits& other) const;  \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::operator>>(uint32_t count) const;             \
                                                                             \
  template <>                                                                \
  SimdInt##bits SimdInt##bits::operator<<(uint32_t count) const
// clang-format on

// The 128-bit specializations are declared unconditionally.
SPECIALIZE_SIMD_INT(128);
#if ARCH_CPU_X86_64
// The 256-bit specializations are declared only for x86-64 builds.
SPECIALIZE_SIMD_INT(256);
#if defined(TACHYON_HAS_AVX512)
// The 512-bit specializations additionally require TACHYON_HAS_AVX512.
SPECIALIZE_SIMD_INT(512);
#endif
#endif

// The helper macro is private to this header; do not leak it to includers.
#undef SPECIALIZE_SIMD_INT

} // namespace tachyon::math

#endif // TACHYON_MATH_BASE_SIMD_INT_H_
106 changes: 106 additions & 0 deletions tachyon/math/base/simd_int128_arm64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024 Ulvetanna Inc.
// Use of this source code is governed by a Apache-2.0 style license that
// can be found in the LICENSE.ulvetanna file.

#include <arm_neon.h>

#include <algorithm>

#include "tachyon/base/bit_cast.h"
#include "tachyon/base/compiler_specific.h"
#include "tachyon/base/logging.h"
#include "tachyon/math/base/simd_int.h"

namespace tachyon::math {

namespace {

// Loads 16 bytes starting at the address of |value| into a NEON register
// viewed as sixteen 8-bit lanes.
uint8x16_t ToVector(const SimdInt128& value) {
  const auto* src = reinterpret_cast<const uint8_t*>(&value);
  return vld1q_u8(src);
}

// The FromVector() overloads below store a 128-bit NEON register into the
// memory of a freshly created SimdInt128 and return it. They differ only in
// the lane width of the register type they accept.

SimdInt128 FromVector(uint8x16_t vector) {
  SimdInt128 result;
  auto* dst = reinterpret_cast<uint8_t*>(&result);
  vst1q_u8(dst, vector);
  return result;
}

SimdInt128 FromVector(uint16x8_t vector) {
  SimdInt128 result;
  auto* dst = reinterpret_cast<uint16_t*>(&result);
  vst1q_u16(dst, vector);
  return result;
}

SimdInt128 FromVector(uint32x4_t vector) {
  SimdInt128 result;
  auto* dst = reinterpret_cast<uint32_t*>(&result);
  vst1q_u32(dst, vector);
  return result;
}

SimdInt128 FromVector(uint64x2_t vector) {
  SimdInt128 result;
  auto* dst = reinterpret_cast<uint64_t*>(&result);
  vst1q_u64(dst, vector);
  return result;
}

}  // namespace

// static
// Replicates |value| into all sixteen 8-bit lanes of the result.
template <>
SimdInt128 SimdInt128::Broadcast(uint8_t value) {
  const uint8x16_t lanes = vdupq_n_u8(value);
  return FromVector(lanes);
}

// static
// Replicates |value| into all eight 16-bit lanes of the result.
template <>
SimdInt128 SimdInt128::Broadcast(uint16_t value) {
  const uint16x8_t lanes = vdupq_n_u16(value);
  return FromVector(lanes);
}

// static
// Replicates |value| into all four 32-bit lanes of the result.
template <>
SimdInt128 SimdInt128::Broadcast(uint32_t value) {
  const uint32x4_t lanes = vdupq_n_u32(value);
  return FromVector(lanes);
}

// static
// Replicates |value| into both 64-bit lanes of the result.
template <>
SimdInt128 SimdInt128::Broadcast(uint64_t value) {
  const uint64x2_t lanes = vdupq_n_u64(value);
  return FromVector(lanes);
}

// Equality is delegated to the comparison of the underlying BigInt values;
// no NEON intrinsics are involved here.
template <>
bool SimdInt128::operator==(const SimdInt128& other) const {
  const bool same = (value_ == other.value_);
  return same;
}

// Bitwise AND of the two operands, computed on their 8-bit-lane views.
template <>
SimdInt128 SimdInt128::operator&(const SimdInt128& other) const {
  const uint8x16_t lhs = ToVector(*this);
  const uint8x16_t rhs = ToVector(other);
  return FromVector(vandq_u8(lhs, rhs));
}

// Bitwise OR of the two operands, computed on their 8-bit-lane views.
template <>
SimdInt128 SimdInt128::operator|(const SimdInt128& other) const {
  const uint8x16_t lhs = ToVector(*this);
  const uint8x16_t rhs = ToVector(other);
  return FromVector(vorrq_u8(lhs, rhs));
}

// Bitwise XOR of the two operands, computed on their 8-bit-lane views.
template <>
SimdInt128 SimdInt128::operator^(const SimdInt128& other) const {
  const uint8x16_t lhs = ToVector(*this);
  const uint8x16_t rhs = ToVector(other);
  return FromVector(veorq_u8(lhs, rhs));
}

// Right shift by |count| bits. A zero count returns the value unchanged and a
// count of at least the full width (128) returns a default-constructed
// SimdInt128; every other count is forwarded to BigInt's operator>>.
template <>
SimdInt128 SimdInt128::operator>>(uint32_t count) const {
  if (UNLIKELY(count == 0)) {
    return *this;
  }
  if (UNLIKELY(count >= kBits)) {
    return SimdInt128();
  }
  const value_type shifted = value_ >> count;
  return SimdInt128(shifted);
}

// Left shift by |count| bits. A zero count returns the value unchanged and a
// count of at least the full width (128) returns a default-constructed
// SimdInt128; every other count is forwarded to BigInt's operator<<.
template <>
SimdInt128 SimdInt128::operator<<(uint32_t count) const {
  if (UNLIKELY(count == 0)) {
    return *this;
  }
  if (UNLIKELY(count >= kBits)) {
    return SimdInt128();
  }
  const value_type shifted = value_ << count;
  return SimdInt128(shifted);
}

} // namespace tachyon::math
Loading

0 comments on commit 8284250

Please sign in to comment.