-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #432 from kroma-network/feat/add-simd-int
feat(math): add simd int
- Loading branch information
Showing
13 changed files
with
698 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// Copyright 2024 Ulvetanna Inc. | ||
// Use of this source code is governed by a Apache-2.0 style license that | ||
// can be found in the LICENSE.ulvetanna file. | ||
|
||
#ifndef TACHYON_MATH_BASE_SIMD_INT_H_ | ||
#define TACHYON_MATH_BASE_SIMD_INT_H_ | ||
|
||
#include <stddef.h> | ||
#include <stdint.h> | ||
|
||
#include <string> | ||
#include <type_traits> | ||
|
||
#include "tachyon/build/build_config.h" | ||
#include "tachyon/export.h" | ||
#include "tachyon/math/base/big_int.h" | ||
|
||
namespace tachyon::math { | ||
|
||
template <size_t Bits> | ||
class SimdInt { | ||
public: | ||
constexpr static size_t kBits = Bits; | ||
constexpr static size_t kLimbNums = Bits / 64; | ||
|
||
using value_type = BigInt<kLimbNums>; | ||
|
||
SimdInt() = default; | ||
template <typename T, | ||
std::enable_if_t<std::is_constructible_v<value_type, T>>* = nullptr> | ||
explicit SimdInt(T value) : SimdInt(value_type(value)) {} | ||
explicit SimdInt(const value_type& value) : value_(value) {} | ||
|
||
static SimdInt Zero() { return SimdInt(); } | ||
static SimdInt One() { return SimdInt(1); } | ||
static SimdInt Max() { return SimdInt(value_type::Max()); } | ||
static SimdInt Broadcast(uint8_t value); | ||
static SimdInt Broadcast(uint16_t value); | ||
static SimdInt Broadcast(uint32_t value); | ||
static SimdInt Broadcast(uint64_t value); | ||
static SimdInt Random() { return SimdInt(value_type::Random()); } | ||
|
||
const value_type& value() const { return value_; } | ||
|
||
bool IsZero() const { return *this == Zero(); } | ||
bool IsOne() const { return *this == One(); } | ||
bool IsMax() const { return *this == Max(); } | ||
|
||
bool operator==(const SimdInt& other) const; | ||
bool operator!=(const SimdInt& other) const { return !operator==(other); } | ||
|
||
SimdInt operator&(const SimdInt& other) const; | ||
SimdInt& operator&=(const SimdInt& other) { return *this = *this & other; } | ||
|
||
SimdInt operator|(const SimdInt& other) const; | ||
SimdInt& operator|=(const SimdInt& other) { return *this = *this | other; } | ||
|
||
SimdInt operator^(const SimdInt& other) const; | ||
SimdInt& operator^=(const SimdInt& other) { return *this = *this ^ other; } | ||
|
||
SimdInt operator!() const { return *this ^ Max(); } | ||
|
||
SimdInt operator>>(uint32_t count) const; | ||
SimdInt& operator>>=(uint32_t count) { return *this = *this >> count; } | ||
|
||
SimdInt operator<<(uint32_t count) const; | ||
SimdInt& operator<<=(uint32_t count) { return *this = *this << count; } | ||
|
||
std::string ToString() const { return value_.ToString(); } | ||
std::string ToHexString(bool pad_zero = false) const { | ||
return value_.ToHexString(pad_zero); | ||
} | ||
|
||
private: | ||
value_type value_; | ||
}; | ||
|
||
// clang-format off | ||
#define SPECIALIZE_SIMD_INT(bits) \ | ||
using SimdInt##bits = SimdInt<bits>; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::Broadcast(uint8_t value); \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::Broadcast(uint16_t value); \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::Broadcast(uint32_t value); \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::Broadcast(uint64_t value); \ | ||
\ | ||
template <> \ | ||
bool SimdInt##bits::operator==(const SimdInt##bits& value) const; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::operator&(const SimdInt##bits& value) const; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::operator|(const SimdInt##bits& value) const; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::operator^(const SimdInt##bits& value) const; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::operator>>(uint32_t count) const; \ | ||
\ | ||
template <> \ | ||
SimdInt##bits SimdInt##bits::operator<<(uint32_t count) const | ||
// clang-format on | ||
|
||
SPECIALIZE_SIMD_INT(128); | ||
#if ARCH_CPU_X86_64 | ||
SPECIALIZE_SIMD_INT(256); | ||
#if defined(TACHYON_HAS_AVX512) | ||
SPECIALIZE_SIMD_INT(512); | ||
#endif | ||
#endif | ||
|
||
#undef SPECIALIZE_SIMD_INT | ||
|
||
} // namespace tachyon::math | ||
|
||
#endif // TACHYON_MATH_BASE_SIMD_INT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright 2024 Ulvetanna Inc. | ||
// Use of this source code is governed by a Apache-2.0 style license that | ||
// can be found in the LICENSE.ulvetanna file. | ||
|
||
#include <arm_neon.h> | ||
|
||
#include <algorithm> | ||
|
||
#include "tachyon/base/bit_cast.h" | ||
#include "tachyon/base/compiler_specific.h" | ||
#include "tachyon/base/logging.h" | ||
#include "tachyon/math/base/simd_int.h" | ||
|
||
namespace tachyon::math { | ||
|
||
namespace { | ||
|
||
uint8x16_t ToVector(const SimdInt128& value) { | ||
return vld1q_u8(reinterpret_cast<const uint8_t*>(&value)); | ||
} | ||
|
||
SimdInt128 FromVector(uint8x16_t vector) { | ||
SimdInt128 ret; | ||
vst1q_u8(reinterpret_cast<uint8_t*>(&ret), vector); | ||
return ret; | ||
} | ||
|
||
SimdInt128 FromVector(uint16x8_t vector) { | ||
SimdInt128 ret; | ||
vst1q_u16(reinterpret_cast<uint16_t*>(&ret), vector); | ||
return ret; | ||
} | ||
|
||
SimdInt128 FromVector(uint32x4_t vector) { | ||
SimdInt128 ret; | ||
vst1q_u32(reinterpret_cast<uint32_t*>(&ret), vector); | ||
return ret; | ||
} | ||
|
||
SimdInt128 FromVector(uint64x2_t vector) { | ||
SimdInt128 ret; | ||
vst1q_u64(reinterpret_cast<uint64_t*>(&ret), vector); | ||
return ret; | ||
} | ||
|
||
} // namespace | ||
|
||
// static | ||
template <> | ||
SimdInt128 SimdInt128::Broadcast(uint8_t value) { | ||
return FromVector(vdupq_n_u8(value)); | ||
} | ||
|
||
// static | ||
template <> | ||
SimdInt128 SimdInt128::Broadcast(uint16_t value) { | ||
return FromVector(vdupq_n_u16(value)); | ||
} | ||
|
||
// static | ||
template <> | ||
SimdInt128 SimdInt128::Broadcast(uint32_t value) { | ||
return FromVector(vdupq_n_u32(value)); | ||
} | ||
|
||
// static | ||
template <> | ||
SimdInt128 SimdInt128::Broadcast(uint64_t value) { | ||
return FromVector(vdupq_n_u64(value)); | ||
} | ||
|
||
template <> | ||
bool SimdInt128::operator==(const SimdInt128& other) const { | ||
return value_ == other.value_; | ||
} | ||
|
||
template <> | ||
SimdInt128 SimdInt128::operator&(const SimdInt128& other) const { | ||
return FromVector(vandq_u8(ToVector(*this), ToVector(other))); | ||
} | ||
|
||
template <> | ||
SimdInt128 SimdInt128::operator|(const SimdInt128& other) const { | ||
return FromVector(vorrq_u8(ToVector(*this), ToVector(other))); | ||
} | ||
|
||
template <> | ||
SimdInt128 SimdInt128::operator^(const SimdInt128& other) const { | ||
return FromVector(veorq_u8(ToVector(*this), ToVector(other))); | ||
} | ||
|
||
template <> | ||
SimdInt128 SimdInt128::operator>>(uint32_t count) const { | ||
if (UNLIKELY(count == 0)) return *this; | ||
if (UNLIKELY(count >= 128)) return SimdInt128(); | ||
return SimdInt128(value_ >> count); | ||
} | ||
|
||
template <> | ||
SimdInt128 SimdInt128::operator<<(uint32_t count) const { | ||
if (UNLIKELY(count == 0)) return *this; | ||
if (UNLIKELY(count >= 128)) return SimdInt128(); | ||
return SimdInt128(value_ << count); | ||
} | ||
|
||
} // namespace tachyon::math |
Oops, something went wrong.