Skip to content

Commit

Permalink
Skip unit stride consistency check if the dimension size is 1 (#1651)
Browse files Browse the repository at this point in the history
  • Loading branch information
anstaf authored Sep 29, 2021
1 parent 9f4b951 commit 65518e6
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions include/gridtools/storage/adapter/python_sid_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include <pybind11/pybind11.h>

#include <boost/optional.hpp>
#include <boost/variant.hpp>

#include "../../common/array.hpp"
Expand All @@ -39,11 +38,12 @@ namespace gridtools {
template <size_t, class>
struct kind {};

template <int UnitStrideDim>
template <size_t UnitStrideDim>
struct transform_strides_f {
bool m_unit_stride_can_be_used;
template <size_t I, class T, std::enable_if_t<I == UnitStrideDim, int> = 0>
integral_constant<pybind11::ssize_t, 1> operator()(T val) const {
if (val != 1)
if (m_unit_stride_can_be_used && val != 1)
throw std::domain_error("incompatible strides, expected unit stride");
return {};
}
Expand All @@ -54,7 +54,29 @@ namespace gridtools {
}
};

template <class T, size_t Dim, class Kind, int UnitStrideDim>
template <size_t UnitStrideDim,
class Strides,
class Shape,
std::enable_if_t<(UnitStrideDim >= tuple_util::size<std::decay_t<Strides>>::value), int> = 0>
Strides &&assign_unit_stride(Strides &&strides, Shape &&) {
return std::forward<Strides>(strides);
}

template <size_t UnitStrideDim,
class Strides,
class Shape,
std::enable_if_t<(UnitStrideDim < tuple_util::size<std::decay_t<Strides>>::value), int> = 0>
decltype(auto) assign_unit_stride(Strides &&strides, Shape &&shape) {
// Numpy may shuffle array layout if the array sizes are equal to one in some dimensions.
// In this case the static unit stride calculation doesn't match the actual strides.
// Luckily we can ignore that because those strides will not be used (corresponding shape == 1).
// See: https://numpy.org/devdocs/release/1.8.0-notes.html#npy-relaxed-strides-checking
bool unit_stride_can_be_used = tuple_util::get<UnitStrideDim>(std::forward<Shape>(shape)) > 1;
return tuple_util::transform_index(transform_strides_f<UnitStrideDim>{unit_stride_can_be_used},
tuple_util::convert_to<tuple>(std::forward<Strides>(strides)));
}

template <class T, size_t Dim, class Kind, size_t UnitStrideDim>
struct wrapper {
pybind11::buffer_info m_info;

Expand All @@ -68,8 +90,7 @@ namespace gridtools {
assert(obj.m_info.strides[i] % obj.m_info.itemsize == 0);
res[i] = obj.m_info.strides[i] / obj.m_info.itemsize;
}
return tuple_util::transform_index(
transform_strides_f<UnitStrideDim>{}, tuple_util::convert_to<tuple>(res));
return assign_unit_stride<UnitStrideDim>(std::move(res), sid_get_upper_bounds(obj));
}

friend std::array<integral_constant<pybind11::ssize_t, 0>, Dim> sid_get_lower_bounds(wrapper const &) {
Expand All @@ -90,7 +111,7 @@ namespace gridtools {
}
};

template <class T, std::size_t Dim, class Kind = void, int UnitStrideDim = -1>
template <class T, std::size_t Dim, class Kind = void, size_t UnitStrideDim = size_t(-1)>
wrapper<T, Dim, Kind, UnitStrideDim> as_sid(pybind11::buffer const &src) {
static_assert(
std::is_trivially_copyable<T>::value, "as_sid should be instantiated with the trivially copyable type");
Expand Down Expand Up @@ -243,7 +264,7 @@ namespace gridtools {
return res;
}

template <class T, size_t Dim, class Kind = void, int UnitStrideDim = -1>
template <class T, size_t Dim, class Kind = void, size_t UnitStrideDim = size_t(-1)>
auto as_cuda_sid(pybind11::object const &src) {
static_assert(std::is_trivially_copyable<T>::value,
"as_cuda_sid should be instantiated with the trivially copyable type");
Expand Down Expand Up @@ -320,8 +341,7 @@ namespace gridtools {
using sid::property;
return sid::synthetic()
.template set<property::origin>(sid::host_device::simple_ptr_holder<T *>{ptr})
.template set<property::strides>(tuple_util::transform_index(
transform_strides_f<UnitStrideDim>{}, tuple_util::convert_to<tuple>(strides)))
.template set<property::strides>(assign_unit_stride<UnitStrideDim>(std::move(strides), shape))
.template set<property::strides_kind, kind<Dim, Kind>>()
.template set<property::lower_bounds>(array<integral_constant<size_t, 0>, Dim>())
.template set<property::upper_bounds>(shape);
Expand Down

0 comments on commit 65518e6

Please sign in to comment.