Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix coalesced access checks in matrix_vector_op #372

Merged
merged 9 commits into from
Nov 17, 2021
45 changes: 24 additions & 21 deletions cpp/include/raft/linalg/matrix_vector_op.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020, NVIDIA CORPORATION.
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,11 +17,24 @@
#pragma once

#include <raft/cuda_utils.cuh>
#include <raft/pow2_utils.cuh>
#include <raft/vectorized.cuh>

namespace raft {
namespace linalg {

namespace {
/**
 * Decides whether a matrix may be accessed with vectorized loads/stores of
 * `VecBytes` bytes at a time.
 */
template <size_t VecBytes>
struct AlignedAccess {
  /**
   * @return true when the matrix pointer and the row stride (in bytes) are
   * both VecBytes-aligned, and VecBytes is a whole number of elements.
   */
  template <typename T>
  static inline bool test(const T *matrix, size_t strideBytes) {
    const bool ptrAligned = Pow2<VecBytes>::isAligned(matrix);
    const bool strideAligned = Pow2<VecBytes>::isAligned(strideBytes);
    const bool wholeElems = Pow2<sizeof(T)>::isAligned(VecBytes);
    return ptrAligned && strideAligned && wholeElems;
  }
};
}  // namespace

template <typename Type, int veclen_, typename Lambda, typename IdxType>
__global__ void matrixVectorOpKernel(Type *out, const Type *matrix,
const Type *vector, IdxType D, IdxType N,
Expand Down Expand Up @@ -101,24 +114,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec, IdxType D,
IdxType stride = rowMajor ? D : N;
size_t stride_bytes = stride * sizeof(Type);

auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
};

if (test_aligned_access(16)) {
if (AlignedAccess<16>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(8)) {
} else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(4)) {
} else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(2)) {
} else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (1 / sizeof(Type)) {
} else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else {
Expand Down Expand Up @@ -209,24 +217,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec1,
IdxType stride = rowMajor ? D : N;
size_t stride_bytes = stride * sizeof(Type);

auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
};

if (test_aligned_access(16)) {
if (AlignedAccess<16>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(8)) {
} else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(4)) {
} else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (test_aligned_access(2)) {
} else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else if (1 / sizeof(Type)) {
} else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else {
Expand Down
161 changes: 161 additions & 0 deletions cpp/include/raft/pow2_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cuda_utils.cuh"

namespace raft {

/**
* @brief Fast arithmetics and alignment checks for power-of-two values known at compile time.
*
* @tparam Value_ a compile-time value representable as a power-of-two.
*/
template <auto Value_>
struct Pow2 {
  typedef decltype(Value_) Type;
  static constexpr Type Value = Value_;
  static constexpr Type Log2 = log2(Value);
  // Value is a power of two, so `x & Mask` is `x % Value` for non-negative x.
  static constexpr Type Mask = Value - 1;

  static_assert(std::is_integral<Type>::value, "Value must be integral.");
  static_assert(Value && !(Value & Mask), "Value must be power of two.");

  // The methods below are only enabled for integral types I wide enough that
  // Value survives the round-trip conversion Type -> I -> Type.
#define Pow2_IsRepresentableAs(I) \
  (std::is_integral<I>::value && Type(I(Value)) == Value)

  /**
   * Integer division by Value truncated toward zero
   * (same as `x / Value` in C++).
   *
   * Invariant: `x = Value * quot(x) + rem(x)`
   */
  template <typename I>
  static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> quot(
    I x) noexcept {
    if constexpr (std::is_signed<I>::value)
      // Arithmetic shift rounds toward -inf; the correction term bumps
      // negative, non-exact quotients up by one to truncate toward zero.
      return (x >> I(Log2)) + (x < 0 && (x & I(Mask)));
    if constexpr (std::is_unsigned<I>::value) return x >> I(Log2);
  }

  /**
   * Remainder of integer division by Value truncated toward zero
   * (same as `x % Value` in C++).
   *
   * Invariant: `x = Value * quot(x) + rem(x)`.
   */
  template <typename I>
  static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> rem(
    I x) noexcept {
    if constexpr (std::is_signed<I>::value)
      return x < 0 ? -((-x) & I(Mask)) : (x & I(Mask));
    if constexpr (std::is_unsigned<I>::value) return x & I(Mask);
  }

  /**
   * Integer division by Value truncated toward negative infinity
   * (same as `x // Value` in Python).
   *
   * Invariant: `x = Value * div(x) + mod(x)`.
   *
   * Note, `div` and `mod` for negative values are slightly faster
   * than `quot` and `rem`, but behave slightly different
   * compared to normal C++ operators `/` and `%`.
   */
  template <typename I>
  static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> div(
    I x) noexcept {
    return x >> I(Log2);
  }

  /**
   * x modulo Value operation (remainder of the `div(x)`)
   * (same as `x % Value` in Python).
   *
   * Invariant: `mod(x) >= 0`
   * Invariant: `x = Value * div(x) + mod(x)`.
   *
   * Note, `div` and `mod` for negative values are slightly faster
   * than `quot` and `rem`, but behave slightly different
   * compared to normal C++ operators `/` and `%`.
   */
  template <typename I>
  static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> mod(
    I x) noexcept {
    return x & I(Mask);
  }

#define Pow2_CHECK_TYPE(T)                                               \
  static_assert(std::is_pointer<T>::value || std::is_integral<T>::value, \
                "Only pointer or integral types make sense here")

  /**
   * Tell whether the pointer or integral is Value-aligned.
   * NB: for pointers, the alignment is checked in bytes, not in elements.
   */
  template <typename PtrT>
  static constexpr HDI bool isAligned(PtrT p) noexcept {
    Pow2_CHECK_TYPE(PtrT);
    if constexpr (Pow2_IsRepresentableAs(PtrT)) return mod(p) == 0;
    if constexpr (!Pow2_IsRepresentableAs(PtrT))
      return mod(reinterpret_cast<Type>(p)) == 0;
  }

  /** Tell whether two pointers have the same address modulo Value. */
  template <typename PtrT, typename PtrS>
  static constexpr HDI bool areSameAlignOffsets(PtrT a, PtrS b) noexcept {
    Pow2_CHECK_TYPE(PtrT);
    Pow2_CHECK_TYPE(PtrS);
    Type x, y;
    if constexpr (Pow2_IsRepresentableAs(PtrT))
      x = Type(mod(a));
    else
      x = mod(reinterpret_cast<Type>(a));
    if constexpr (Pow2_IsRepresentableAs(PtrS))
      y = Type(mod(b));
    else
      y = mod(reinterpret_cast<Type>(b));
    return x == y;
  }

  /** Get this or next Value-aligned address (in bytes) or integral. */
  template <typename PtrT>
  static constexpr HDI PtrT roundUp(PtrT p) noexcept {
    Pow2_CHECK_TYPE(PtrT);
    if constexpr (Pow2_IsRepresentableAs(PtrT))
      return p + PtrT(Mask) - mod(p + PtrT(Mask));
    if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
      auto x = reinterpret_cast<Type>(p);
      return reinterpret_cast<PtrT>(x + Mask - mod(x + Mask));
    }
  }

  /** Get this or previous Value-aligned address (in bytes) or integral. */
  template <typename PtrT>
  static constexpr HDI PtrT roundDown(PtrT p) noexcept {
    Pow2_CHECK_TYPE(PtrT);
    if constexpr (Pow2_IsRepresentableAs(PtrT)) return p - mod(p);
    if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
      auto x = reinterpret_cast<Type>(p);
      return reinterpret_cast<PtrT>(x - mod(x));
    }
  }
#undef Pow2_CHECK_TYPE
#undef Pow2_IsRepresentableAs
};

}; // namespace raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_executable(test_raft
test/eigen_solvers.cu
test/handle.cpp
test/integer_utils.cpp
test/pow2_utils.cu
test/label/label.cu
test/label/merge_labels.cu
test/lap/lap.cu
Expand Down
109 changes: 109 additions & 0 deletions cpp/test/pow2_utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <raft/pow2_utils.cuh>

namespace raft {

// Typed fixture: Val is the compile-time power-of-two under test, TargetT is
// the integral type fed to the Pow2<Val> helpers.
template <auto Val, typename TargetT>
struct Pow2Test : public ::testing::Test {
  typedef Pow2<Val> P;
  // Edge-case inputs (values straddling Val, type limits), filled in SetUp.
  std::vector<TargetT> data;

  void SetUp() override {
    std::vector<TargetT> pos = {0, 1, 2, 7, 15, 16, 17, 31, 35, 1024, 1623};
    data.insert(data.end(), pos.begin(), pos.end());
    if constexpr (std::is_signed<TargetT>::value) {
      // Negative samples only make sense for signed target types.
      std::vector<TargetT> neg = {-0, -1, -2, -5, -15, -16, -17, -156};
      data.insert(data.end(), neg.begin(), neg.end());
    }
    data.push_back(std::numeric_limits<TargetT>::min());
    data.push_back(std::numeric_limits<TargetT>::max());
  }

  // quot/rem must agree exactly with the C++ `/` and `%` operators.
  void quotRem() {
    for (auto x : data) {
      ASSERT_EQ(P::quot(x), x / P::Value) << " where x = " << x;
      ASSERT_EQ(P::rem(x), x % P::Value) << " where x = " << x;
      ASSERT_EQ(x, P::quot(x) * P::Value + P::rem(x));
    }
  }

  // div/mod follow floor-division semantics: mod is non-negative and the
  // division invariant x == div(x)*Value + mod(x) holds.
  void divMod() {
    for (auto x : data) {
      ASSERT_GE(P::mod(x), 0) << " where x = " << x;
      ASSERT_EQ(x, P::div(x) * P::Value + P::mod(x));
    }
  }

  void round() {
    for (auto x : data) {
      // Guard against overflow/underflow before exercising roundUp/roundDown
      // near the limits of TargetT.
      if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value))
        ASSERT_GE(P::roundUp(x), x);
      if (x >= std::numeric_limits<TargetT>::min() + TargetT(P::Value))
        ASSERT_LE(P::roundDown(x), x);
      ASSERT_EQ(x - P::roundDown(x), P::mod(x)) << " where x = " << x;
      ASSERT_EQ(P::mod(P::roundUp(x) + P::mod(x) - x), 0)
        << " where x = " << x;
    }
  }

  void alignment() {
    for (auto x : data) {
      ASSERT_TRUE(P::areSameAlignOffsets(x, x));
      if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value)) {
        ASSERT_TRUE(P::areSameAlignOffsets(x, x + TargetT(P::Value)));
        int aligned_count = 0;
        int same_aligned_count = 0;
        // In any window of length Value, exactly one offset is aligned and
        // exactly one shares x's alignment offset.
        for (int i = 0; i < int(P::Value); i++) {
          aligned_count += P::isAligned(x + i);
          same_aligned_count += P::areSameAlignOffsets(x, x + i);
        }
        ASSERT_EQ(aligned_count, 1) << " where x = " << x;
        ASSERT_EQ(same_aligned_count, 1) << " where x = " << x;
      }
    }
  }
};

// Instantiates the whole check suite for one fixture typedef.
// Bug fix: the quotRem test previously called divMod() by mistake, so
// P::quot / P::rem were never exercised; each TEST_F now invokes the fixture
// method of the same name.
#define TEST_IT(T)                      \
  TEST_F(T, quotRem) { quotRem(); }     \
  TEST_F(T, divMod) { divMod(); }       \
  TEST_F(T, round) { round(); }         \
  TEST_F(T, alignment) { alignment(); }

// Cross signed/unsigned Value types with signed/unsigned argument types,
// covering both the degenerate Value = 1 and larger powers of two.
typedef Pow2Test<16, int> Pow2_i32_i32_16;
typedef Pow2Test<1UL, uint64_t> Pow2_u64_u64_1;
typedef Pow2Test<128UL, int> Pow2_u64_i32_128;
typedef Pow2Test<32LL, uint16_t> Pow2_ll_u16_32;
typedef Pow2Test<16, uint64_t> Pow2_i32_u64_16;
TEST_IT(Pow2_i32_i32_16);
TEST_IT(Pow2_u64_u64_1);
TEST_IT(Pow2_u64_i32_128);
TEST_IT(Pow2_ll_u16_32);
TEST_IT(Pow2_i32_u64_16);

// Sanity-check roundUp/roundDown on pointer arguments: rounding must never
// move an address in the wrong direction.
TEST(Pow2, pointers) {
  typedef Pow2<32UL> P;
  for (ptrdiff_t i = 0; i <= ptrdiff_t(P::Value); i++) {
    // Synthesize addresses with every alignment offset; the pointers are used
    // for address arithmetic only and never dereferenced.
    auto *p = reinterpret_cast<float *>(16345 + i);
    ASSERT_GE(P::roundUp(p), p);
    ASSERT_LE(P::roundDown(p), p);
  }
}

} // namespace raft