diff --git a/include/oneapi/dpl/pstl/algorithm_impl.h b/include/oneapi/dpl/pstl/algorithm_impl.h index c04897f1..61b0ca2c 100644 --- a/include/oneapi/dpl/pstl/algorithm_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_impl.h @@ -2421,6 +2421,37 @@ __pattern_sort(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAcce }); } +// Add separate patterns for std::ranges::sort due to std::indirectly_swappable requirement, +// which implies the use of std::ranges::iter_swap, which can be customized externally +#if _ONEDPL_CPP20_RANGES_PRESENT +template +void +__pattern_sort_ranges(_Tag, _ExecutionPolicy&&, _RandomAccessIterator __first, _RandomAccessIterator __last, + _Compare __comp) noexcept +{ + static_assert(__is_serial_tag_v<_Tag> || __is_parallel_forward_tag_v<_Tag>); + + std::ranges::sort(__first, __last, __comp); +} + +template +void +__pattern_sort_ranges(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator __first, + _RandomAccessIterator __last, _Compare __comp) +{ + using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag; + + __internal::__except_handler([&]() { + __par_backend::__parallel_stable_sort( + __backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), __first, __last, __comp, + [](_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare __comp) { + std::ranges::sort(__first, __last, __comp); + }, + __last - __first); + }); +} +#endif // _ONEDPL_CPP20_RANGES_PRESENT + //------------------------------------------------------------------------ // stable_sort //------------------------------------------------------------------------ diff --git a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h index 5fd5d608..c148304f 100644 --- a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h @@ -343,7 +343,7 @@ __pattern_sort_ranges(_Tag __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __c auto __comp_2 = [__comp, __proj](auto&& __val1, auto&& __val2) { return std::invoke(__comp, std::invoke(__proj, std::forward(__val1)), std::invoke(__proj, std::forward(__val2)));}; - oneapi::dpl::__internal::__pattern_sort(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r), + oneapi::dpl::__internal::__pattern_sort_ranges(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r), std::ranges::begin(__r) + std::ranges::size(__r), __comp_2); return std::ranges::borrowed_iterator_t<_R>(std::ranges::begin(__r) + std::ranges::size(__r)); diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h index f8b56bef..ed55f0f7 100644 --- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h @@ -1043,7 +1043,7 @@ sort(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj) { const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng); - oneapi::dpl::__internal::__ranges::__pattern_sort(__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), + oneapi::dpl::__internal::__ranges::__pattern_sort_cpp17_ranges(__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all(::std::forward<_Range>(__rng)), __comp, __proj); } diff --git a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h index 0e2a333a..4bc67f18 100644 --- a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h +++ b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h @@ -749,7 +749,7 @@ __pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1& template void -__pattern_sort(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj) +__pattern_sort_cpp17_ranges(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj) { if (__rng.size() >= 2) __par_backend_hetero::__parallel_stable_sort(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), @@ -762,7 +762,7 @@ template __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp, _Proj __proj) { - oneapi::dpl::__internal::__ranges::__pattern_sort(__tag, std::forward<_ExecutionPolicy>(__exec), + oneapi::dpl::__internal::__ranges::__pattern_sort_cpp17_ranges(__tag, std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::__ranges::views::all(__r), __comp, __proj); return std::ranges::borrowed_iterator_t<_R>(std::ranges::begin(__r) + std::ranges::size(__r)); diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index d7fde2e9..9e4f116d 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -29,6 +29,7 @@ #include #include +#include "../../utils.h" #include "../../iterator_impl.h" #include "../../execution_impl.h" #include "../../utils_ranges.h" @@ -1914,7 +1915,21 @@ __parallel_stable_sort(oneapi::dpl::__internal::__device_backend_tag __backend_t return __parallel_radix_sort<__internal::__is_comp_ascending<::std::decay_t<_Compare>>::value>( __backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), __proj); } -#endif + +#if _ONEDPL_CPP20_RANGES_PRESENT +template < + typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj, + ::std::enable_if_t< + __is_radix_sort_usable_for_type, _Compare>::value, int> = 0> +auto +__parallel_stable_sort_ranges(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec, + _Range&& __rng, _Compare, _Proj __proj) +{ + return __parallel_radix_sort<__internal::__is_comp_ascending<::std::decay_t<_Compare>>::value>( + __backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), __proj); +} +#endif // _ONEDPL_CPP20_RANGES_PRESENT +#endif // _USE_RADIX_SORT template < typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj, @@ -1924,9 +1939,25 @@ auto __parallel_stable_sort(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj) { - return __parallel_sort_impl(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), - oneapi::dpl::__internal::__compare<_Compare, _Proj>{__comp, __proj}); + return __parallel_sort_impl( + __backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), + oneapi::dpl::__internal::__compare<_Compare, _Proj>{__comp, __proj}); +} + +#if _ONEDPL_CPP20_RANGES_PRESENT +template < + typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj, + ::std::enable_if_t< + !__is_radix_sort_usable_for_type, _Compare>::value, int> = 0> +auto +__parallel_stable_sort_ranges(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec, + _Range&& __rng, _Compare __comp, _Proj __proj) +{ + return __parallel_sort_impl( + __backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), + oneapi::dpl::__internal::__compare<_Compare, _Proj>{__comp, __proj}); } +#endif // _ONEDPL_CPP20_RANGES_PRESENT //------------------------------------------------------------------------ // parallel_partial_sort - async pattern diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h index 19a4f25b..61cc7f5d 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge_sort.h @@ -23,8 +23,12 @@ #include // std::min, std::max_element #include // std::decay_t, std::integral_constant +#if _ONEDPL_CPP20_RANGES_PRESENT +# include // std::ranges::iter_swap +#endif + #include "sycl_defs.h" // __dpl_sycl::__local_accessor, __dpl_sycl::__group_barrier -#include "../../utils.h" // __dpl_bit_floor, __dpl_bit_ceil +#include "../../utils.h" // __dpl_bit_floor, __dpl_bit_ceil, __classic_sort_policy, __ranges_sort_policy #include "parallel_backend_sycl_merge.h" // __find_start_point, __serial_merge namespace oneapi @@ -34,6 +38,7 @@ namespace dpl namespace __par_backend_hetero { +template struct __subgroup_bubble_sorter { template @@ -48,8 +53,17 @@ struct __subgroup_bubble_sorter auto& __second_item = __storage_acc[j]; if (__comp(__second_item, __first_item)) { - using std::swap; - swap(__first_item, __second_item); + if constexpr (std::is_same_v<_SortPolicy, oneapi::dpl::__internal::__classic_sort_policy>) + { + using std::swap; + swap(__first_item, __second_item); + } +#if _ONEDPL_CPP20_RANGES_PRESENT + else + { + std::ranges::iter_swap(&__first_item, &__second_item); + } +#endif } } } @@ -103,7 +117,7 @@ struct __group_merge_path_sorter } }; -template +template struct __leaf_sorter { using _Tp = oneapi::dpl::__internal::__value_t<_Range>; @@ -111,7 +125,7 @@ struct __leaf_sorter using _StorageAcc = __dpl_sycl::__local_accessor<_Tp>; // TODO: select a better sub-group sorter depending on sort stability, // a type (e.g. it can be trivially copied for shuffling within a sub-group) - using _SubGroupSorter = __subgroup_bubble_sorter; + using _SubGroupSorter = __subgroup_bubble_sorter<_SortPolicy>; using _GroupSorter = __group_merge_path_sorter; static std::uint32_t @@ -316,11 +330,11 @@ class __sort_global_kernel; template class __sort_copy_back_kernel; -template +template auto __submit_selecting_leaf(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp) { - using _Leaf = __leaf_sorter, _Compare>; + using _Leaf = __leaf_sorter<_SortPolicy, std::decay_t<_Range>, _Compare>; using _Tp = oneapi::dpl::__internal::__value_t<_Range>; using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; @@ -380,20 +394,22 @@ __submit_selecting_leaf(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __co std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng), __comp, __leaf); }; -template +template auto __parallel_sort_impl(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp) { + static_assert(std::is_same_v<_SortPolicy, oneapi::dpl::__internal::__classic_sort_policy> || + std::is_same_v<_SortPolicy, oneapi::dpl::__internal::__ranges_sort_policy>); if (__rng.size() <= std::numeric_limits::max()) { - return __submit_selecting_leaf(std::forward<_ExecutionPolicy>(__exec), - std::forward<_Range>(__rng), __comp); + return __submit_selecting_leaf(std::forward<_ExecutionPolicy>(__exec), + std::forward<_Range>(__rng), __comp); } else { - return __submit_selecting_leaf(std::forward<_ExecutionPolicy>(__exec), - std::forward<_Range>(__rng), __comp); + return __submit_selecting_leaf(std::forward<_ExecutionPolicy>(__exec), + std::forward<_Range>(__rng), __comp); } } diff --git a/include/oneapi/dpl/pstl/utils.h b/include/oneapi/dpl/pstl/utils.h index 856a5eb2..fd70912a 100644 --- a/include/oneapi/dpl/pstl/utils.h +++ b/include/oneapi/dpl/pstl/utils.h @@ -784,6 +784,12 @@ union __lazy_ctor_storage } }; +// Helpers to distinguish between std::sort and std::ranges::sort calls on a backend level +// It is used to select a correct swap method: +// std::swap for std::sort and std::ranges::swap for std::ranges::sort +struct __ranges_sort_policy{}; +struct __classic_sort_policy{}; + } // namespace __internal } // namespace dpl } // namespace oneapi