Skip to content

Commit

Permalink
fn: allow execution of stencils with 0d domain (#1754)
Browse files Browse the repository at this point in the history
The total number of points in the limit of zero dimensions is 1. This allows 0d stencils to be executed. In GT4Py field view, this gives stencils with 0d output fields the expected behavior.
  • Loading branch information
havogt authored Apr 20, 2023
1 parent 062a37c commit 43d826c
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 22 deletions.
16 changes: 15 additions & 1 deletion include/gridtools/common/integral_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "host_device.hpp"

namespace gridtools {

// This predicate checks if the class has `std::integral_constant` as a public base.
// Note that it is not the same as `class is an instantiation of gridtools::integral_constant`.
// Also it is not the same as `class has integral nested value_type type and value static member`.
Expand All @@ -41,6 +40,21 @@ namespace gridtools {
constexpr GT_FUNCTION operator T() const noexcept { return V; }
};

// Maps a type to the plain integral type it stands for:
//   - a built-in integral type maps to itself;
//   - a type accepted by `is_integral_constant` maps to its nested `value_type`.
// The primary template is declared but never defined, so any other type has no `type` member.
template <typename, typename = void>
struct to_integral_type;

// Built-in integral types: identity mapping.
template <typename Int>
struct to_integral_type<Int, std::enable_if_t<std::is_integral<Int>::value>> {
    using type = Int;
};

// Integral-constant-like types: expose the wrapped value's type.
template <typename Constant>
struct to_integral_type<Constant, std::enable_if_t<is_integral_constant<Constant>::value>> {
    using type = typename Constant::value_type;
};

// Convenience alias for `typename to_integral_type<T>::type`.
template <typename T>
using to_integral_type_t = typename to_integral_type<T>::type;

// This predicate checks if the class has `gridtools::integral_constant` as a public base, which has arithmetic
// operators defined.
template <class, class = void>
Expand Down
2 changes: 1 addition & 1 deletion include/gridtools/common/stride_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace gridtools {

template <class Sizes>
auto total_size(Sizes const &sizes) {
return tuple_util::fold([](auto l, auto r) { return l * r; }, sizes);
return tuple_util::fold([](auto l, auto r) { return l * r; }, integral_constant<int, 1>{}, sizes);
}
} // namespace stride_util
} // namespace gridtools
26 changes: 17 additions & 9 deletions include/gridtools/thread_pool/dummy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,38 @@

#pragma once

#include "../common/integral_constant.hpp"

namespace gridtools {
namespace thread_pool {
struct dummy {
friend auto thread_pool_get_thread_num(dummy) { return 0; }
friend auto thread_pool_get_max_threads(dummy) { return 1; }

template <class F, class I>
template <class F, class I, class I_t = to_integral_type_t<I>>
friend void thread_pool_parallel_for_loop(dummy, F const &f, I lim) {
for (I i = 0; i < lim; ++i)
for (I_t i = 0; i < lim; ++i)
f(i);
}

template <class F, class I, class J>
template <class F, class I, class J, class I_t = to_integral_type_t<I>, class J_t = to_integral_type_t<J>>
friend void thread_pool_parallel_for_loop(dummy, F const &f, I i_lim, J j_lim) {
for (J j = 0; j < j_lim; ++j)
for (I i = 0; i < i_lim; ++i)
for (J_t j = 0; j < j_lim; ++j)
for (I_t i = 0; i < i_lim; ++i)
f(i, j);
}

template <class F, class I, class J, class K>
template <class F,
class I,
class J,
class K,
class I_t = to_integral_type_t<I>,
class J_t = to_integral_type_t<J>,
class K_t = to_integral_type_t<K>>
friend void thread_pool_parallel_for_loop(dummy, F const &f, I i_lim, J j_lim, K k_lim) {
for (K k = 0; k < k_lim; ++k)
for (J j = 0; j < j_lim; ++j)
for (I i = 0; i < i_lim; ++i)
for (K_t k = 0; k < k_lim; ++k)
for (J_t j = 0; j < j_lim; ++j)
for (I_t i = 0; i < i_lim; ++i)
f(i, j, k);
}
};
Expand Down
27 changes: 18 additions & 9 deletions include/gridtools/thread_pool/omp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,47 @@

#pragma once

#include "../common/integral_constant.hpp"

#if defined(_OPENMP) || defined(GT_HIP_OPENMP_WORKAROUND)
#include <omp.h>
#endif

namespace gridtools {
namespace thread_pool {

struct omp {
#if defined(_OPENMP) || defined(GT_HIP_OPENMP_WORKAROUND)
friend auto thread_pool_get_thread_num(omp) { return omp_get_thread_num(); }
friend auto thread_pool_get_max_threads(omp) { return omp_get_max_threads(); }

template <class F, class I>
template <class F, class I, class I_t = to_integral_type_t<I>>
friend void thread_pool_parallel_for_loop(omp, F const &f, I lim) {
#pragma omp parallel for
for (I i = 0; i < lim; ++i)
for (I_t i = 0; i < lim; ++i)
f(i);
}

template <class F, class I, class J>
template <class F, class I, class J, class I_t = to_integral_type_t<I>, class J_t = to_integral_type_t<J>>
friend void thread_pool_parallel_for_loop(omp, F const &f, I i_lim, J j_lim) {
#pragma omp parallel for collapse(2)
for (J j = 0; j < j_lim; ++j)
for (I i = 0; i < i_lim; ++i)
for (J_t j = 0; j < j_lim; ++j)
for (I_t i = 0; i < i_lim; ++i)
f(i, j);
}

template <class F, class I, class J, class K>
template <class F,
class I,
class J,
class K,
class I_t = to_integral_type_t<I>,
class J_t = to_integral_type_t<J>,
class K_t = to_integral_type_t<K>>
friend void thread_pool_parallel_for_loop(omp, F const &f, I i_lim, J j_lim, K k_lim) {
#pragma omp parallel for collapse(3)
for (K k = 0; k < k_lim; ++k)
for (J j = 0; j < j_lim; ++j)
for (I i = 0; i < i_lim; ++i)
for (K_t k = 0; k < k_lim; ++k)
for (J_t j = 0; j < j_lim; ++j)
for (I_t i = 0; i < i_lim; ++i)
f(i, j, k);
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/fn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ gridtools_add_fn_regression_test(fn_copy SOURCES fn_copy.cpp PERFTEST)
gridtools_add_fn_regression_test(fn_unstructured_nabla SOURCES fn_unstructured_nabla.cpp PERFTEST)
gridtools_add_fn_regression_test(fn_tridiagonal_solve SOURCES fn_tridiagonal_solve.cpp PERFTEST)
gridtools_add_fn_regression_test(fn_cartesian_vertical_advection SOURCES fn_cartesian_vertical_advection.cpp PERFTEST)
gridtools_add_fn_regression_test(fn_empty_domain SOURCES fn_empty_domain.cpp)
gridtools_add_fn_regression_test(fn_domain SOURCES fn_domain.cpp)
gridtools_add_fn_regression_test(fn_vertical_indirection SOURCES fn_vertical_indirection.cpp)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace {

struct empty_stencil {
GT_FUNCTION constexpr auto operator()() const {
return []() { return 0.0f; };
return []() { return 1.0f; };
}
};

Expand All @@ -32,6 +32,21 @@ namespace {
auto domain = unstructured_domain(tuple{0, 0}, tuple{0, 0});
auto backend = make_backend(fn_backend_t{}, domain);
backend.stencil_executor()().arg(out).assign(0_c, empty_stencil{}).execute();

using float_t = typename TypeParam::float_t;
ASSERT_EQ(float_t(0), out->host_view()({}));
}

// Regression test for a zero-dimensional domain: the domain is built from empty tuples, and after executing
// `empty_stencil` into a zero-initialized output the test expects the output to read 1.0f.
GT_REGRESSION_TEST(zero_dimensional_domain_stencil, test_environment<>, fn_backend_t) {
// Output storage whose initializer returns 0 for any index.
auto out = TypeParam::make_storage([](...) { return 0; });

// Executes the stencil once at the origin; this seems to be the natural limit of removing all loops/dimensions
// in an expression like `for (d0 : range_d0) for (d1 : range_d1) ... { out(d0, d1, ...) = 1.0f; }`.
auto domain = unstructured_domain(tuple{}, tuple{});
auto backend = make_backend(fn_backend_t{}, domain);
// Assign the result of `empty_stencil` to argument 0 (`out`) and run it over the domain.
backend.stencil_executor()().arg(out).assign(0_c, empty_stencil{}).execute();

// The host view is indexed with an empty index (`{}`); the value must have been overwritten by the stencil.
ASSERT_EQ(1.0f, out->host_view()({}));
}

struct empty_column : fwd {
Expand All @@ -50,6 +65,9 @@ namespace {
auto domain = unstructured_domain(tuple{0, 0}, tuple{0, 0});
auto backend = make_backend(fn_backend_t{}, domain);
backend.vertical_executor()().arg(out).assign(0_c, empty_column{}, 0.0f).execute();

using float_t = typename TypeParam::float_t;
ASSERT_EQ(float_t(0), out->host_view()({}));
}

} // namespace
4 changes: 4 additions & 0 deletions tests/unit_tests/common/test_integral_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,9 @@ namespace gridtools {
static_assert(is_integral_constant<std::integral_constant<int, 42>>::value);
static_assert(is_gr_integral_constant<integral_constant<int, 42>>::value);
static_assert(!is_gr_integral_constant<std::integral_constant<int, 42>>::value);

static_assert(std::is_same_v<to_integral_type_t<integral_constant<int, 42>>, int>);
static_assert(std::is_same_v<to_integral_type_t<std::integral_constant<int, 42>>, int>);
static_assert(std::is_same_v<to_integral_type_t<int>, int>);
} // namespace
} // namespace gridtools

0 comments on commit 43d826c

Please sign in to comment.