Skip to content

Commit

Permalink
[oneDPL][ranges] + tests for ranges::copy, ranges::copy_if, ranges::m…
Browse files Browse the repository at this point in the history
…erge
  • Loading branch information
MikeDvorskiy committed Apr 5, 2024
1 parent 8c719e1 commit 946e7aa
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 62 deletions.
5 changes: 2 additions & 3 deletions test/parallel_api/ranges/std_ranges.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ main()

test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj);

test_range_algo<data_in_in>{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj);
test_range_algo<data_in_val_n>{}(oneapi::dpl::ranges::search_n, std::ranges::search_n, pred_2, proj);

test_range_algo<data_in_in>{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj, proj);
test_range_algo{}(oneapi::dpl::ranges::search_n, std::ranges::search_n, 3, 5, pred_2, proj);
#endif //_ENABLE_STD_RANGES_TESTING

return TestUtils::done(_ENABLE_STD_RANGES_TESTING);
Expand Down
11 changes: 9 additions & 2 deletions test/parallel_api/ranges/std_ranges_2.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ main()
#if _ENABLE_STD_RANGES_TESTING

using namespace test_std_ranges;

#if 1
test_range_algo{}(oneapi::dpl::ranges::count_if, std::ranges::count_if, pred, proj);
test_range_algo{}(oneapi::dpl::ranges::count, std::ranges::count, 4, proj);

test_range_algo<data_in_in>{}(oneapi::dpl::ranges::equal, std::ranges::equal, pred_2, proj);
test_range_algo<data_in_in>{}(oneapi::dpl::ranges::equal, std::ranges::equal, pred_2, proj, proj);

test_range_algo{}(oneapi::dpl::ranges::is_sorted, std::ranges::is_sorted, std::ranges::less{}, proj);
test_range_algo{}(oneapi::dpl::ranges::is_sorted, std::ranges::is_sorted, std::ranges::greater{}, proj);
Expand All @@ -43,6 +43,13 @@ main()
test_range_algo{}(oneapi::dpl::ranges::max_element, std::ranges::max_element, std::ranges::less{}, proj);
test_range_algo{}(oneapi::dpl::ranges::max_element, std::ranges::max_element, std::ranges::greater{}, proj);

test_range_algo<data_in_out, /*RetTypeCheck*/false>{}(oneapi::dpl::ranges::copy, std::ranges::copy);
test_range_algo<data_in_out, /*RetTypeCheck*/false>{}(oneapi::dpl::ranges::copy_if, std::ranges::copy_if,
pred, proj);
#endif
test_range_algo<data_in_in_out>{}(oneapi::dpl::ranges::merge, std::ranges::merge, std::ranges::less{}, proj, proj);
test_range_algo<data_in_in_out>{}(oneapi::dpl::ranges::merge, std::ranges::merge, std::ranges::greater{}, proj, proj);

#endif //_ENABLE_STD_RANGES_TESTING

return TestUtils::done(_ENABLE_STD_RANGES_TESTING);
Expand Down
129 changes: 72 additions & 57 deletions test/parallel_api/ranges/std_ranges_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ enum TestDataMode
data_in,
data_in_out,
data_in_in,
data_in_in_out,
data_in_val_n,
data_in_in_out
};

template<typename Container, TestDataMode Ranges = data_in, bool RetTypeCheck = true>
Expand All @@ -65,26 +64,25 @@ struct test
operator()(oneapi::dpl::execution::par_unseq, algo, args...);
}

template<typename Policy, typename Algo, typename Checker, typename FunctorOrVal, typename Proj = std::identity,
typename Transform = std::identity>
template<typename Policy, typename Algo, typename Checker, typename Transform>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in>
operator()(Policy&& exec, Algo algo, Checker checker, FunctorOrVal f, Proj proj = {}, Transform tr = {})
operator()(Policy&& exec, Algo algo, Checker checker, Transform tr, auto... args)
{
constexpr int max_n = 10;
int data[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int expected[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};

auto expected_view = tr(std::ranges::subrange(expected, expected + max_n));
auto expected_res = checker(expected_view, f, proj);
auto expected_res = checker(expected_view, args...);
{
Container cont(exec, data, max_n);
typename Container::type& A = cont();

auto res = algo(exec, tr(A), f, proj);
auto res = algo(exec, tr(A), args...);

//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), f, proj))>, "Wrong return type");
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), args...))>, "Wrong return type");

auto bres = ret_in_val(expected_res, expected_view.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres, (std::string("wrong return value from algo with ranges: ") + typeid(Algo).name() +
Expand All @@ -95,31 +93,29 @@ struct test
EXPECT_EQ_N(expected, data, max_n, (std::string("wrong effect algo with ranges: ")
+ typeid(Algo).name() + typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}

template<typename Policy, typename Algo, typename Checker, typename Functor, typename Proj = std::identity,
typename Transform = std::identity>
template<typename Policy, typename Algo, typename Checker, typename Transform>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_out>
operator()(Policy&& exec, Algo algo, Checker checker, Functor f, Proj proj = {}, Transform tr = {})
operator()(Policy&& exec, Algo algo, Checker checker, Transform tr, auto... args)
{
constexpr int max_n = 10;
int data_in[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int data_out[max_n] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int expected[max_n] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

auto src_view = tr(std::ranges::subrange(data_in, data_in + max_n));
auto expected_res = checker(src_view, expected, f, proj);
auto expected_res = checker(src_view, expected, args...);
{
Container cont_in(exec, data_in, max_n);
Container cont_out(exec, data_out, max_n);

typename Container::type& A = cont_in();
typename Container::type& B = cont_out();

auto res = algo(exec, tr(A), B, f, proj);
auto res = algo(exec, tr(A), B, args...);

//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), B, f, proj))>, "Wrong return type");
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), B, args...))>, "Wrong return type");

auto bres_in = ret_in_val(expected_res, src_view.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo with input range: ") + typeid(Algo).name()).c_str());
Expand All @@ -132,67 +128,69 @@ struct test
EXPECT_EQ_N(expected, data_out, max_n, (std::string("wrong effect algo with ranges: ") + typeid(Algo).name()).c_str());
}

template<typename Policy, typename Algo, typename Checker, typename Functor, typename Proj = std::identity,
typename Transform = std::identity>
template<typename Policy, typename Algo, typename Checker, typename Transform>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_in>
operator()(Policy&& exec, Algo algo, Checker checker, Functor f, Proj proj = {}, Transform tr = {})
operator()(Policy&& exec, Algo algo, Checker checker, Transform tr, auto... args)
{
constexpr int max_n = 10;
int data_in1[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int data_in2[max_n] = {0, 0, 2, 3, 4, 5, 0, 0, 0, 0};

auto src_view1 = tr(std::ranges::subrange(data_in1, data_in1 + max_n));
auto src_view2 = tr(std::ranges::subrange(data_in2, data_in2 + max_n));
auto expected_res = checker(src_view1, src_view2, f, proj, proj);
auto expected_res = checker(src_view1, src_view2, args...);
{
Container cont_in1(exec, data_in1, max_n);
Container cont_in2(exec, data_in2, max_n);

typename Container::type& A = cont_in1();
typename Container::type& B = cont_in2();

auto res = algo(exec, tr(A), tr(B), f, proj, proj);
auto res = algo(exec, tr(A), tr(B), args...);

if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), tr(B), f, proj, proj))>, "Wrong return type");
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), tr(B), args...))>, "Wrong return type");

auto bres_in = ret_in_val(expected_res, src_view1.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo: ") + typeid(Algo).name() +
typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}
}

template<typename Policy, typename Algo, typename Checker, typename FunctorOrVal, typename Proj = std::identity,
typename Transform = std::identity>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_val_n>
operator()(Policy&& exec, Algo algo, Checker checker, FunctorOrVal f, Proj proj = {}, Transform tr = {})
template<typename Policy, typename Algo, typename Checker, typename Transform>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_in_out>
operator()(Policy&& exec, Algo algo, Checker checker, Transform tr, auto... args)
{
constexpr int max_n = 10;
int data[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9};
int expected[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9};
int val = 5, n = 3;
int data_in1[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int data_in2[max_n] = {0, 0, 2, 3, 4, 5, 6, 6, 6, 6};
constexpr int max_n_out = max_n*2;
int data_out[max_n_out] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; //TODO: size
int expected[max_n_out] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

auto expected_view = tr(std::ranges::subrange(expected, expected + max_n));
auto expected_res = checker(expected_view, n, val, f, proj);
auto src_view1 = tr(std::ranges::subrange(data_in1, data_in1 + max_n));
auto src_view2 = tr(std::ranges::subrange(data_in2, data_in2 + max_n));
auto expected_res = checker(src_view1, src_view2, expected, args...);
{
Container cont(exec, data, max_n);
typename Container::type& A = cont();
Container cont_in1(exec, data_in1, max_n);
Container cont_in2(exec, data_in2, max_n);
Container cont_out(exec, data_out, max_n_out);

auto res = algo(exec, tr(A), n, val, f, proj);
typename Container::type& A = cont_in1();
typename Container::type& B = cont_in2();
typename Container::type& С = cont_out();

auto res = algo(exec, tr(A), tr(B), С, args...);

//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), n, val, f, proj))>, "Wrong return type");
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), tr(B), С.begin(), args...))>, "Wrong return type");

auto bres = ret_in_val(expected_res, expected_view.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres, (std::string("wrong return value from algo with ranges: ") + typeid(Algo).name()).c_str());
auto bres_in = ret_in_val(expected_res, src_view1.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo: ") + typeid(Algo).name() +
typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}

//check result
EXPECT_EQ_N(expected, data, max_n, (std::string("wrong effect algo with ranges: ")
+ typeid(Algo).name() + typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
EXPECT_EQ_N(expected, data_out, max_n_out, (std::string("wrong effect algo with ranges: ") + typeid(Algo).name()).c_str());
}

private:

template<typename, typename = void>
Expand All @@ -209,14 +207,27 @@ struct test
static constexpr
bool check_in<T, std::void_t<decltype(std::declval<T>().in)>> = true;

template<typename, typename = void>
static constexpr bool check_in1{};

template<typename T>
static constexpr
bool check_in1<T, std::void_t<decltype(std::declval<T>().in1)>> = true;

template<typename, typename = void>
static constexpr bool check_in2{};

template<typename T>
static constexpr
bool check_in2<T, std::void_t<decltype(std::declval<T>().in2)>> = true;

template<typename, typename = void>
static constexpr bool check_out{};

template<typename T>
static constexpr
bool check_out<T, std::void_t<decltype(std::declval<T>().out)>> = true;


template<typename, typename = void>
static constexpr bool is_range{};

Expand All @@ -229,6 +240,10 @@ struct test
{
if constexpr (check_in<Ret>)
return std::distance(begin, ret.in);
else if constexpr (check_in1<Ret>)
return std::distance(begin, ret.in1);
else if constexpr (check_in2<Ret>)
return std::distance(begin, ret.in2);
else if constexpr (is_iterator<Ret>)
return std::distance(begin, ret);
else if constexpr(is_range<Ret>)
Expand Down Expand Up @@ -363,7 +378,7 @@ using usm_span = usm_subrange_impl<std::span<int>>;
template<TestDataMode TestDataMode = data_in, bool RetTypeCheck = true, bool ForwardRangeCheck = true>
struct test_range_algo
{
void operator()(auto algo, auto checker, auto f, auto proj)
void operator()(auto algo, auto checker, auto... args)
{

auto subrange_view = [](auto&& v) { return std::ranges::subrange(v); };
Expand All @@ -375,24 +390,24 @@ struct test_range_algo
};

if constexpr(ForwardRangeCheck)
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, forward_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, forward_view, args...);

test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, subrange_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, span_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_subrange, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_span, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, subrange_view, args...);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, span_view, args...);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, std::views::all, args...);
test<host_subrange, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, std::views::all, args...);
test<host_span, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, std::views::all, args...);

#if 1//_ONEDPL_HETERO_BACKEND
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj, oneapi::dpl::views::all);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj, subrange_view);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, std::identity{}, span_view);
test<usm_subrange, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);
test<usm_span, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, std::identity{}, args...);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, oneapi::dpl::views::all, args...);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, subrange_view, args...);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, span_view, args...);
test<usm_subrange, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, std::identity{}, args...);
test<usm_span, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, std::identity{}, args...);

#if 0 //sycl buffer
test<sycl_buffer, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, std::identity{}, oneapi::dpl::views::all);
test<sycl_buffer, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, oneapi::dpl::views::all, f, args...);
#endif

#endif //_ONEDPL_HETERO_BACKEND
Expand Down

0 comments on commit 946e7aa

Please sign in to comment.