Skip to content

Commit

Permalink
Use sid::make_loop in sid::make_unrolled_loop if unroll factor is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
fthaler committed Sep 25, 2024
1 parent ef4fed7 commit ffd070e
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions include/gridtools/sid/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,20 +638,25 @@ namespace gridtools {
return {};
}

template <class Key, int UnrollFactor, class NumSteps, class Step = integral_constant<int, 1>>
template <class Key,
int UnrollFactor,
class NumSteps,
class Step = integral_constant<int, 1>,
std::enable_if_t<(UnrollFactor > 1), int> = 0>
constexpr GT_FUNCTION auto make_unrolled_loop(NumSteps num_steps, Step step = {}) {
using u = integral_constant<int, UnrollFactor>;
return [step,
unrolled = make_loop<Key>(num_steps / u(), step * u()),
epilogue = make_loop<Key>(num_steps % u(), step),
epilogue_start = step * ((num_steps / u()) * u())](auto &&fun) {
return [unrolled = unrolled([step, fun=std::forward<decltype(fun)>(fun)](auto &&ptr, auto const strides) {
::gridtools::host_device::for_each<meta::make_indices_c<UnrollFactor>>([&](auto) {
fun(std::forward<decltype(ptr)>(ptr), strides);
shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), step);
});
shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), -step * u());
}),
return [unrolled =
unrolled([step, fun = std::forward<decltype(fun)>(fun)](auto &&ptr, auto const strides) {
::gridtools::host_device::for_each<meta::make_indices_c<UnrollFactor>>([&](auto) {
fun(std::forward<decltype(ptr)>(ptr), strides);
shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), step);
});
shift(std::forward<decltype(ptr)>(ptr), get_stride<Key>(strides), -step * u());
}),
epilogue = epilogue(std::forward<decltype(fun)>(fun)),
epilogue_start](auto &&ptr, auto const &strides) {
unrolled(std::forward<decltype(ptr)>(ptr), strides);
Expand All @@ -662,6 +667,15 @@ namespace gridtools {
};
}

/**
 *  Overload of `make_unrolled_loop` that is selected when `UnrollFactor == 1`.
 *
 *  It does no unrolling of its own: the call is forwarded as-is to `make_loop<Key>`
 *  with the same `num_steps` and `step`.
 *  The `enable_if_t` constraint makes this overload participate in overload resolution
 *  only for `UnrollFactor == 1`, keeping it disjoint from the `UnrollFactor > 1` overload.
 *
 *  @tparam Key           passed to `make_loop` as its template argument
 *  @tparam UnrollFactor  must be exactly 1 for this overload to be viable
 *  @tparam NumSteps      type of `num_steps`
 *  @tparam Step          type of `step`; defaults to `integral_constant<int, 1>`
 *  @param num_steps      forwarded unchanged to `make_loop`
 *  @param step           forwarded unchanged to `make_loop`; value-initialized `Step` if omitted
 *  @return the result of `make_loop<Key>(num_steps, step)`
 */
template <class Key,
int UnrollFactor,
class NumSteps,
class Step = integral_constant<int, 1>,
std::enable_if_t<(UnrollFactor == 1), int> = 0>
constexpr GT_FUNCTION auto make_unrolled_loop(NumSteps num_steps, Step step = {}) {
return make_loop<Key>(num_steps, step);
}

/**
* A helper that allows to use `SID`s with C++11 range based loop
*
Expand Down

0 comments on commit ffd070e

Please sign in to comment.