-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- introduce reduction framework within gridtools - implement `naive` and `gpu` reduction backends
- Loading branch information
Showing
20 changed files
with
1,170 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
#pragma once | ||
|
||
/*** | ||
* `ct_dispatch` performs compile time dispatch on runtime value | ||
* | ||
* Usage: | ||
* | ||
* Say you have a function that accepts an integer in compile time: | ||
* | ||
* template <size_t N> int foo() { ... } | ||
* | ||
* And you need something like: | ||
* | ||
* auto bar(int n) { | ||
* switch(n) { | ||
* case 0: | ||
* return foo<0>(); | ||
* case 1: | ||
* return foo<1>(); | ||
* case 2: | ||
* return foo<2>(); | ||
* case 3: | ||
* return foo<3>(); | ||
* } | ||
* } | ||
* | ||
* You can use `ct_dispatch` here to reduce the boilerplate: | ||
* | ||
* auto bar(int n) { | ||
* return ct_dispatch<4>([](auto n) { | ||
* return foo<decltype(n)::value>();) | ||
* }, n); | ||
* } | ||
* | ||
*/ | ||
|
||
#include <cassert> | ||
#include <cstdlib> | ||
#include <type_traits> | ||
#include <utility> | ||
|
||
namespace gridtools { | ||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 1, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n == 0); | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 2, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 2); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 3, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 3); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 4, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 4); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
case 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 3>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 5, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 5); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
case 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
case 3: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 3>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 4>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 6, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 6); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
case 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
case 3: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 3>()); | ||
case 4: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 4>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 5>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 7, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 7); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
case 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
case 3: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 3>()); | ||
case 4: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 4>()); | ||
case 5: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 5>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 6>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<Lim == 8, int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < 8); | ||
switch (n) { | ||
case 0: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 0>()); | ||
case 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 1>()); | ||
case 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 2>()); | ||
case 3: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 3>()); | ||
case 4: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 4>()); | ||
case 5: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 5>()); | ||
case 6: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 6>()); | ||
default: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, 7>()); | ||
} | ||
} | ||
|
||
template <size_t Lim, class Sink, std::enable_if_t<(Lim > 8), int> = 0> | ||
auto ct_dispatch(Sink &&sink, size_t n) { | ||
assert(n < Lim); | ||
switch (n) { | ||
case Lim - 8: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 8>()); | ||
case Lim - 7: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 7>()); | ||
case Lim - 6: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 6>()); | ||
case Lim - 5: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 5>()); | ||
case Lim - 4: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 4>()); | ||
case Lim - 3: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 3>()); | ||
case Lim - 2: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 2>()); | ||
case Lim - 1: | ||
return std::forward<Sink>(sink)(std::integral_constant<size_t, Lim - 1>()); | ||
default: | ||
return ct_dispatch<Lim - 8>(std::forward<Sink>(sink), n); | ||
} | ||
} | ||
} // namespace gridtools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cassert> | ||
#include <cmath> | ||
#include <limits> | ||
#include <numeric> | ||
#include <type_traits> | ||
|
||
namespace gridtools { | ||
#if __cplusplus < 201703 | ||
template <class T> | ||
constexpr T gcd(T m, T n) { | ||
static_assert(!std::is_signed<T>::value, ""); | ||
return n == 0 ? m : gcd<T>(n, m % n); | ||
} | ||
|
||
template <class T, class U> | ||
constexpr std::common_type_t<T, U> gcd(T m, U n) { | ||
static_assert(std::is_integral<T>() && std::is_integral<U>(), "Arguments to gcd must be integer types"); | ||
static_assert(!std::is_same<std::remove_cv_t<T>, bool>(), "First argument to gcd cannot be bool"); | ||
static_assert(!std::is_same<std::remove_cv_t<U>, bool>(), "Second argument to gcd cannot be bool"); | ||
using res_t = std::common_type_t<T, U>; | ||
using ures_t = std::make_unsigned_t<res_t>; | ||
return static_cast<res_t>(gcd(static_cast<ures_t>(std::abs(m)), static_cast<ures_t>(std::abs(n)))); | ||
} | ||
|
||
template <class T, class U> | ||
constexpr std::common_type_t<T, U> lcm(T m, U n) { | ||
static_assert(std::is_integral<T>() && std::is_integral<U>(), "Arguments to lcm must be integer types"); | ||
static_assert(!std::is_same<std::remove_cv_t<T>, bool>(), "First argument to gcd cannot be bool"); | ||
static_assert(!std::is_same<std::remove_cv_t<U>, bool>(), "Second argument to gcd cannot be bool"); | ||
if (m == 0 || n == 0) | ||
return 0; | ||
using res_t = std::common_type_t<T, U>; | ||
res_t a = std::abs(m) / gcd(m, n); | ||
res_t b = std::abs(n); | ||
assert(std::numeric_limits<res_t>::max() / a > b); | ||
return a * b; | ||
} | ||
#else | ||
using std::gcd; | ||
using std::lcm; | ||
#endif | ||
} // namespace gridtools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2019, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
#pragma once | ||
|
||
#include "reduction/frontend.hpp" | ||
#include "reduction/functions.hpp" |
Oops, something went wrong.