From 574ba4d85edd05537e1924e74844d5486b696241 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Mon, 5 Dec 2022 05:40:58 -0800 Subject: [PATCH 01/13] [SYCL] Add support for radix_sorter Co-authored-by: Andrei Fedorov --- sycl/include/sycl/detail/group_sort_impl.hpp | 423 ++++++++++++++++++ .../experimental/group_helpers_sorters.hpp | 100 ++++- sycl/include/sycl/group_algorithm.hpp | 1 - sycl/include/sycl/sycl.hpp | 1 + 4 files changed, 523 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index b82f49c1e0fa1..34025db334dee 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -250,6 +250,429 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp, } } +// traits for ascending functors +template struct is_comp_ascending { + static constexpr bool value = false; +}; +template struct is_comp_ascending> { + static constexpr bool value = true; +}; + +// traits for descending functors +template struct is_comp_descending { + static constexpr bool value = false; +}; +template struct is_comp_descending> { + static constexpr bool value = true; +}; + +// get number of states radix bits can represent +__attribute__((always_inline)) constexpr std::uint32_t +get_states_in_bits(std::uint32_t radix_bits) { + return (1 << radix_bits); +} + +//------------------------------------------------------------------------ +// ordered traits for a given size and integral/float flag +//------------------------------------------------------------------------ + +template struct get_ordered {}; + +template <> struct get_ordered<1, true> { + using _type = uint8_t; + constexpr static std::int8_t mask = 0x80; +}; + +template <> struct get_ordered<2, true> { + using _type = uint16_t; + constexpr static std::int16_t mask = 0x8000; +}; + +template <> struct get_ordered<4, true> { + using _type = uint32_t; + constexpr static std::int32_t mask = 0x80000000; +}; + +template <> struct get_ordered<8, true> { + using _type = uint64_t; + constexpr static std::int64_t mask = 0x8000000000000000; +}; + +template <> struct get_ordered<2, false> { + using _type = uint16_t; + constexpr static std::uint32_t nmask = 0xFFFF; // for negative numbers + constexpr static std::uint32_t pmask = 0x8000; // for positive numbers +}; + +template <> struct get_ordered<4, false> { + using _type = uint32_t; + constexpr static std::uint32_t nmask = 0xFFFFFFFF; // for negative numbers + constexpr static std::uint32_t pmask = 0x80000000; // for positive numbers +}; + +template <> struct get_ordered<8, false> { + using _type = uint64_t; + constexpr static std::uint64_t nmask = + 0xFFFFFFFFFFFFFFFF; // for negative numbers + constexpr static std::uint64_t pmask = + 0x8000000000000000; // for positive numbers +}; + +//------------------------------------------------------------------------ +// ordered type for a given type +//------------------------------------------------------------------------ + +// for unknown/unsupported type we do not have any trait +template struct ordered {}; + +// for unsigned integrals we use the same type +template +struct ordered<_T, std::enable_if_t::value && + std::is_unsigned<_T>::value>> { + using _type = _T; +}; + +// for signed integrals or floatings we map: size -> corresponding unsigned +// integral +template +struct ordered<_T, std::enable_if_t<(std::is_integral<_T>::value && + std::is_signed<_T>::value) || + !std::is_integral<_T>::value>> { + using _type = 
+ typename get_ordered::value>::_type; +}; + +// shorthands +template using ordered_t = typename ordered<_T>::_type; + +//------------------------------------------------------------------------ +// functions for conversion to ordered type +//------------------------------------------------------------------------ + +// for already ordered types (any uints) we use the same type +template +__attribute__((always_inline)) +std::enable_if_t>::value, ordered_t<_T>> +convert_to_ordered(_T value) { + return value; +} + +// converts integral type to ordered (in terms of bitness) type +template +__attribute__((always_inline)) +std::enable_if_t>::value && + std::is_integral<_T>::value, + ordered_t<_T>> +convert_to_ordered(_T value) { + _T result = value ^ get_ordered::mask; + return *reinterpret_cast *>(&result); +} + +// converts floating type to ordered (in terms of bitness) type +template +__attribute__((always_inline)) +std::enable_if_t>::value && + !std::is_integral<_T>::value, + ordered_t<_T>> +convert_to_ordered(_T value) { + ordered_t<_T> uvalue = *reinterpret_cast *>(&value); + // check if value negative + ordered_t<_T> is_negative = uvalue >> (sizeof(_T) * CHAR_BIT - 1); + // for positive: 00..00 -> 00..00 -> 10..00 + // for negative: 00..01 -> 11..11 -> 11..11 + ordered_t<_T> ordered_mask = + (is_negative * get_ordered::nmask) | + get_ordered::pmask; + return uvalue ^ ordered_mask; +} + +//------------------------------------------------------------------------ +// bit pattern functions +//------------------------------------------------------------------------ + +// required for descending comparator support +template struct invert_if { + template + __attribute__((always_inline)) _T operator()(_T value) { + return value; + } +}; + +// invert value if descending comparator is passed +template <> struct invert_if { + template + __attribute__((always_inline)) _T operator()(_T value) { + return ~value; + } + + // invertation for bool type have to be logical, rather than bit + __attribute__((always_inline)) bool operator()(bool value) { return !value; } +}; +// get bit values in a certain bucket of a value +template +__attribute__((always_inline)) std::uint32_t +get_bucket_value(_T value, std::uint32_t radix_iter) { + // invert value if we need to sort in descending order + value = invert_if{}(value); + + // get bucket offset idx from the end of bit type (least significant bits) + std::uint32_t bucket_offset = radix_iter * radix_bits; + + // get offset mask for one bucket, e.g. 
+ // radix_bits=2: 0000 0001 -> 0000 0100 -> 0000 0011 + ordered_t<_T> bucket_mask = (1u << radix_bits) - 1u; + + // get bits under bucket mask + return (value >> bucket_offset) & bucket_mask; +} +template +__attribute__((always_inline)) T get_default_value(std::less) { + return std::numeric_limits::max(); +} + +template +__attribute__((always_inline)) T get_default_value(std::greater) { + return std::numeric_limits::lowest(); +} + +template struct values_assigner { + template + void operator()(IterOut output, size_t idx_out, IterIn input, size_t idx_in) { + output[idx_out] = input[idx_in]; + } + + template + void operator()(IterOut output, size_t idx_out, T value) { + output[idx_out] = value; + } +}; + +template <> struct values_assigner { + template + void operator()(IterOut, size_t, IterIn, size_t) {} + + template void operator()(IterOut, size_t, T) {} +}; + +// The iteration of radix sort for unknown number of elements per work item +template +void perform_radix_iter_joint(GroupT group, const uint32_t items_per_work_item, + const uint32_t radix_iter, const size_t n, + KeysT *keys_input, ValueT *vals_input, + KeysT *keys_output, ValueT *vals_output, + uint32_t *memory, CompareT comp) { + const uint32_t radix_states = get_states_in_bits(radix_bits); + const size_t wgsize = group.get_local_linear_range(); + const size_t idx = group.get_local_linear_id(); + + constexpr bool is_comp_asc = + is_comp_ascending::type>::value; + + // 1.1. Zeroinitialize local memory + + uint32_t *scan_memory = reinterpret_cast(memory); + for (uint32_t state = 0; state < radix_states; ++state) + scan_memory[state * wgsize + idx] = 0; + + sycl::group_barrier(group); + + // 1.2. count values and write result to private count array and count memory + + for (uint32_t i = 0; i < items_per_work_item; ++i) { + const uint32_t val_idx = items_per_work_item * idx + i; + // get value, convert it to ordered (in terms of bitness) + const auto val = convert_to_ordered( + (val_idx < n) ? keys_input[val_idx] : get_default_value(comp)); + // get bit values in a certain bucket of a value + const uint32_t bucket_val = + get_bucket_value(val, radix_iter); + + // increment counter for this bit bucket + if (val_idx < n) + scan_memory[bucket_val * wgsize + idx]++; + } + + sycl::group_barrier(group); + + // 2.1 Scan. Upsweep: reduce over radix states + uint32_t reduced = 0; + for (uint32_t i = 0; i < radix_states; ++i) + reduced += scan_memory[idx * radix_states + i]; + + // 2.2. Exclusive scan: over work items + uint32_t scanned = + sycl::exclusive_scan_over_group(group, reduced, std::plus()); + + // 2.3. Exclusive downsweep: exclusive scan over radix states + for (uint32_t i = 0; i < radix_states; ++i) { + uint32_t value = scan_memory[idx * radix_states + i]; + scan_memory[idx * radix_states + i] = scanned; + scanned += value; + } + + sycl::group_barrier(group); + + uint32_t private_scan_memory[radix_states] = {0}; + + // 3. Reorder + for (uint32_t i = 0; i < items_per_work_item; ++i) { + const uint32_t val_idx = items_per_work_item * idx + i; + // get value, convert it to ordered (in terms of bitness) + auto val = convert_to_ordered((val_idx < n) ? 
keys_input[val_idx] + : get_default_value(comp)); + // get bit values in a certain bucket of a value + uint32_t bucket_val = + get_bucket_value(val, radix_iter); + + uint32_t new_offset_idx = private_scan_memory[bucket_val]++ + + scan_memory[bucket_val * wgsize + idx]; + if (val_idx < n) { + keys_output[new_offset_idx] = keys_input[val_idx]; + values_assigner()(vals_output, new_offset_idx, + vals_input, val_idx); + } + } +} + +// The iteration of radix sort for known number of elements per work item +template +void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, + const uint32_t last_iter, KeysT *keys, + ValsT vals, std::byte *memory) { + const uint32_t radix_states = get_states_in_bits(radix_bits); + const size_t wgsize = group.get_local_linear_range(); + const size_t idx = group.get_local_linear_id(); + + // 1.1. count per witem: create a private array for storing count values + uint32_t count_arr[items_per_work_item] = {0}; + uint32_t ranks[items_per_work_item] = {0}; + + // 1.1. Zeroinitialize local memory + uint32_t *scan_memory = reinterpret_cast(memory); + for (uint32_t i = 0; i < radix_states; ++i) + scan_memory[i * wgsize + idx] = 0; + + sycl::group_barrier(group); + + uint32_t *pointers[items_per_work_item] = {nullptr}; + // 1.2. count values and write result to private count array + for (uint32_t i = 0; i < items_per_work_item; ++i) { + // get value, convert it to ordered (in terms of bitness) + ordered_t val = convert_to_ordered(keys[i]); + // get bit values in a certain bucket of a value + uint32_t bucket_val = + get_bucket_value(val, radix_iter); + pointers[i] = scan_memory + (bucket_val * wgsize + idx); + count_arr[i] = (*pointers[i])++; + } + sycl::group_barrier(group); + + // 2.1 Scan. Upsweep: reduce over radix states + uint32_t reduced = 0; + for (uint32_t i = 0; i < radix_states; ++i) + reduced += scan_memory[idx * radix_states + i]; + + // 2.2. Exclusive scan: over work items + uint32_t scanned = + sycl::exclusive_scan_over_group(group, reduced, std::plus()); + + // 2.3. Exclusive downsweep: exclusive scan over radix states + for (uint32_t i = 0; i < radix_states; ++i) { + uint32_t value = scan_memory[idx * radix_states + i]; + scan_memory[idx * radix_states + i] = scanned; + scanned += value; + } + + sycl::group_barrier(group); + + // 2.4. Fill ranks with offsets + for (uint32_t i = 0; i < items_per_work_item; ++i) + ranks[i] = count_arr[i] + *pointers[i]; + + sycl::group_barrier(group); + + // 3. Reorder + KeysT *keys_temp = reinterpret_cast(memory); + ValsT vals_temp = reinterpret_cast( + memory + wgsize * items_per_work_item * sizeof(KeysT)); + for (uint32_t i = 0; i < items_per_work_item; ++i) { + keys_temp[ranks[i]] = keys[i]; + values_assigner()(vals_temp, ranks[i], vals, i); + } + + sycl::group_barrier(group); + + // 4. 
Copy back to input + for (uint32_t i = 0; i < items_per_work_item; ++i) { + std::size_t shift = idx * items_per_work_item + i; + if constexpr (!is_blocked) { + if (radix_iter == last_iter - 1) + shift = i * wgsize + idx; + } + keys[i] = keys_temp[shift]; + values_assigner()(vals, i, vals_temp, shift); + } +} + +template +void private_sort(GroupT group, KeysT *keys, ValsT *values, const size_t n, + CompareT comp, std::byte *scratch, const uint32_t first_bit, + const uint32_t last_bit) { + const size_t wgsize = group.get_local_linear_range(); + constexpr uint32_t radix_states = get_states_in_bits(radix_bits); + const uint32_t first_iter = first_bit / radix_bits; + const uint32_t last_iter = last_bit / radix_bits; + + KeysT *keys_input = keys; + ValsT *vals_input = values; + const uint32_t runtime_items_per_work_item = (n - 1) / wgsize + 1; + + // set pointers to unaligned memory + uint32_t *scan_memory = reinterpret_cast(scratch); + KeysT *keys_output = reinterpret_cast( + scratch + radix_states * wgsize * sizeof(uint32_t)); + // Adding 4 bytes extra space for keys due to specifics of some hardware + // architectures. + ValsT *vals_output = reinterpret_cast( + keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t)); + + for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) { + perform_radix_iter_joint( + group, runtime_items_per_work_item, radix_iter, n, keys_input, + vals_input, keys_output, vals_output, scan_memory, comp); + + sycl::group_barrier(group); + + std::swap(keys_input, keys_output); + std::swap(vals_input, vals_output); + } +} + +template +void private_memory_sort(Group group, T *keys, U *values, Compare comp, + std::byte *scratch, const uint32_t first_bit, + const uint32_t last_bit) { + (void)comp; + constexpr bool is_comp_asc = + is_comp_ascending::type>::value; + const uint32_t first_iter = first_bit / radix_bits; + const uint32_t last_iter = last_bit / radix_bits; + + for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) { + perform_radix_iter_static_size( + group, radix_iter, last_iter, keys, values, scratch); + sycl::group_barrier(group); + } +} + } // namespace detail } // __SYCL_INLINE_VER_NAMESPACE(_V1) } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index 6144ff1a9f129..54283129f9313 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -8,8 +8,10 @@ #pragma once -#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +#if !defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0 #include +#include +#include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { @@ -92,6 +94,102 @@ template > class default_sorter { } }; +enum class sorting_order { ascending, descending }; + +namespace detail { + +template +struct ConvertToComp { + using Type = std::less; +}; + +template struct ConvertToComp { + using Type = std::greater; +}; +} // namespace detail + + +template +class radix_sorter { + + std::byte *scratch = nullptr; + uint32_t first_bit = 0; + uint32_t last_bit = 0; + std::size_t scratch_size = 0; + + static constexpr uint32_t bits = BitsPerPass; + +public: + template + radix_sorter(sycl::span scratch_, + const std::bitset mask = + std::bitset( + std::numeric_limits::max())) + : scratch(scratch_.data()), scratch_size(scratch_.size()) { + 
static_assert((std::is_arithmetic::value || + std::is_same::value), + "radix sort is not usable"); + + first_bit = 0; + while (first_bit < mask.size() && !mask[first_bit]) + ++first_bit; + + last_bit = first_bit; + while (last_bit < mask.size() && mask[last_bit]) + ++last_bit; + } + + template + void operator()(GroupT g, PtrT first, PtrT last) { + (void)g; + (void)first; + (void)last; +#ifdef __SYCL_DEVICE_ONLY__ + sycl::detail::private_sort( + g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0, + typename detail::ConvertToComp::Type{}, scratch, + first_bit, last_bit); +#endif + } + + template ValT operator()(GroupT g, ValT val) { + (void)g; + (void)val; +#ifdef __SYCL_DEVICE_ONLY__ + ValT result[]{val}; + sycl::detail::private_memory_sort( + g, result, /*empty*/ result, + typename detail::ConvertToComp::Type{}, scratch, + first_bit, last_bit); + return result[0]; +#else + return ValT{}; +#endif + } + + static constexpr std::size_t memory_required(sycl::memory_scope scope, + std::size_t range_size) { + // Scope is not important so far + (void)scope; + return range_size * sizeof(ValT) + + (1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t); + } + + // memory_helpers + template + static constexpr size_t memory_required(sycl::memory_scope scope, + sycl::range local_range) { + // Scope is not important so far + (void)scope; + return std::max(local_range.size() * sizeof(ValT), + local_range.size() * (1 << bits) * sizeof(uint32_t)); + } +}; + } // namespace experimental } // namespace oneapi } // namespace ext diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index 05e335cb9200a..72b5bc712e2d9 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -15,7 +15,6 @@ #include #include #include -#include #include #include #include diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp index 5a3e82cfacbf4..0f55034ab366d 100644 --- a/sycl/include/sycl/sycl.hpp +++ b/sycl/include/sycl/sycl.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include From 77e2b46175995d6d5ba09292bded03abc71cb178 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Mon, 5 Dec 2022 12:05:14 -0800 Subject: [PATCH 02/13] more support --- sycl/include/sycl/detail/group_sort_impl.hpp | 226 +++++++++---------- 1 file changed, 105 insertions(+), 121 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index 34025db334dee..1e45287a90daa 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -251,67 +251,58 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp, } // traits for ascending functors -template struct is_comp_ascending { +template struct IsCompAscending { static constexpr bool value = false; }; -template struct is_comp_ascending> { - static constexpr bool value = true; -}; - -// traits for descending functors -template struct is_comp_descending { - static constexpr bool value = false; -}; -template struct is_comp_descending> { +template struct IsCompAscending> { static constexpr bool value = true; }; // get number of states radix bits can represent -__attribute__((always_inline)) constexpr std::uint32_t -get_states_in_bits(std::uint32_t radix_bits) { +constexpr std::uint32_t get_states_in_bits(std::uint32_t radix_bits) { return (1 << radix_bits); } //------------------------------------------------------------------------ -// 
ordered traits for a given size and integral/float flag +// Ordered traits for a given size and integral/float flag //------------------------------------------------------------------------ -template struct get_ordered {}; +template struct GetOrdered {}; -template <> struct get_ordered<1, true> { - using _type = uint8_t; +template <> struct GetOrdered<1, true> { + using Type = uint8_t; constexpr static std::int8_t mask = 0x80; }; -template <> struct get_ordered<2, true> { - using _type = uint16_t; +template <> struct GetOrdered<2, true> { + using Type = uint16_t; constexpr static std::int16_t mask = 0x8000; }; -template <> struct get_ordered<4, true> { - using _type = uint32_t; +template <> struct GetOrdered<4, true> { + using Type = uint32_t; constexpr static std::int32_t mask = 0x80000000; }; -template <> struct get_ordered<8, true> { - using _type = uint64_t; +template <> struct GetOrdered<8, true> { + using Type = uint64_t; constexpr static std::int64_t mask = 0x8000000000000000; }; -template <> struct get_ordered<2, false> { - using _type = uint16_t; +template <> struct GetOrdered<2, false> { + using Type = uint16_t; constexpr static std::uint32_t nmask = 0xFFFF; // for negative numbers constexpr static std::uint32_t pmask = 0x8000; // for positive numbers }; -template <> struct get_ordered<4, false> { - using _type = uint32_t; +template <> struct GetOrdered<4, false> { + using Type = uint32_t; constexpr static std::uint32_t nmask = 0xFFFFFFFF; // for negative numbers constexpr static std::uint32_t pmask = 0x80000000; // for positive numbers }; -template <> struct get_ordered<8, false> { - using _type = uint64_t; +template <> struct GetOrdered<8, false> { + using Type = uint64_t; constexpr static std::uint64_t nmask = 0xFFFFFFFFFFFFFFFF; // for negative numbers constexpr static std::uint64_t pmask = @@ -319,70 +310,67 @@ template <> struct get_ordered<8, false> { }; //------------------------------------------------------------------------ -// ordered type for a given type +// Ordered type for a given type //------------------------------------------------------------------------ // for unknown/unsupported type we do not have any trait -template struct ordered {}; +template struct Ordered {}; // for unsigned integrals we use the same type -template -struct ordered<_T, std::enable_if_t::value && - std::is_unsigned<_T>::value>> { - using _type = _T; +template +struct Ordered::value && + std::is_unsigned::value>> { + using Type = ValT; }; // for signed integrals or floatings we map: size -> corresponding unsigned // integral -template -struct ordered<_T, std::enable_if_t<(std::is_integral<_T>::value && - std::is_signed<_T>::value) || - !std::is_integral<_T>::value>> { - using _type = - typename get_ordered::value>::_type; +template +struct Ordered::value && + std::is_signed::value) || + !std::is_integral::value>> { + using Type = + typename GetOrdered::value>::Type; }; // shorthands -template using ordered_t = typename ordered<_T>::_type; +template using OrderedT = typename Ordered::Type; //------------------------------------------------------------------------ -// functions for conversion to ordered type +// functions for conversion to Ordered type //------------------------------------------------------------------------ -// for already ordered types (any uints) we use the same type -template -__attribute__((always_inline)) -std::enable_if_t>::value, ordered_t<_T>> -convert_to_ordered(_T value) { +// for already Ordered types (any uints) we use the same type +template 
+std::enable_if_t>, OrderedT> +convertToOrdered(ValT value) { return value; } -// converts integral type to ordered (in terms of bitness) type -template -__attribute__((always_inline)) -std::enable_if_t>::value && - std::is_integral<_T>::value, - ordered_t<_T>> -convert_to_ordered(_T value) { - _T result = value ^ get_ordered::mask; - return *reinterpret_cast *>(&result); +// converts integral type to Ordered (in terms of bitness) type +template +std::enable_if_t>::value && + std::is_integral::value, + OrderedT> +convertToOrdered(ValT value) { + ValT result = value ^ GetOrdered::mask; + return *reinterpret_cast *>(&result); } -// converts floating type to ordered (in terms of bitness) type -template -__attribute__((always_inline)) -std::enable_if_t>::value && - !std::is_integral<_T>::value, - ordered_t<_T>> -convert_to_ordered(_T value) { - ordered_t<_T> uvalue = *reinterpret_cast *>(&value); +// converts floating type to Ordered (in terms of bitness) type +template +std::enable_if_t>::value && + !std::is_integral::value, + OrderedT> +convertToOrdered(ValT value) { + OrderedT uvalue = *reinterpret_cast *>(&value); // check if value negative - ordered_t<_T> is_negative = uvalue >> (sizeof(_T) * CHAR_BIT - 1); + OrderedT is_negative = uvalue >> (sizeof(ValT) * CHAR_BIT - 1); // for positive: 00..00 -> 00..00 -> 10..00 // for negative: 00..01 -> 11..11 -> 11..11 - ordered_t<_T> ordered_mask = - (is_negative * get_ordered::nmask) | - get_ordered::pmask; + OrderedT ordered_mask = + (is_negative * GetOrdered::nmask) | + GetOrdered::pmask; return uvalue ^ ordered_mask; } @@ -391,83 +379,79 @@ convert_to_ordered(_T value) { //------------------------------------------------------------------------ // required for descending comparator support -template struct invert_if { - template - __attribute__((always_inline)) _T operator()(_T value) { - return value; - } +template struct InvertIf { + template ValT operator()(ValT value) { return value; } }; // invert value if descending comparator is passed -template <> struct invert_if { - template - __attribute__((always_inline)) _T operator()(_T value) { - return ~value; - } +template <> struct InvertIf { + template ValT operator()(ValT value) { return ~value; } // invertation for bool type have to be logical, rather than bit - __attribute__((always_inline)) bool operator()(bool value) { return !value; } + bool operator()(bool value) { return !value; } }; + // get bit values in a certain bucket of a value -template -__attribute__((always_inline)) std::uint32_t -get_bucket_value(_T value, std::uint32_t radix_iter) { +template +std::uint32_t +getBucketValue(ValT value, std::uint32_t radix_iter) { // invert value if we need to sort in descending order - value = invert_if{}(value); + value = InvertIf{}(value); // get bucket offset idx from the end of bit type (least significant bits) std::uint32_t bucket_offset = radix_iter * radix_bits; // get offset mask for one bucket, e.g. 
// radix_bits=2: 0000 0001 -> 0000 0100 -> 0000 0011 - ordered_t<_T> bucket_mask = (1u << radix_bits) - 1u; + OrderedT bucket_mask = (1u << radix_bits) - 1u; // get bits under bucket mask return (value >> bucket_offset) & bucket_mask; } -template -__attribute__((always_inline)) T get_default_value(std::less) { - return std::numeric_limits::max(); +template ValT getDefaultValue(std::less) { + return std::numeric_limits::max(); } -template -__attribute__((always_inline)) T get_default_value(std::greater) { - return std::numeric_limits::lowest(); +template ValT getDefaultValue(std::greater) { + return std::numeric_limits::lowest(); } -template struct values_assigner { - template - void operator()(IterOut output, size_t idx_out, IterIn input, size_t idx_in) { +template struct ValuesAssigner { + template + void operator()(IterOutT output, size_t idx_out, IterInT input, + size_t idx_in) { output[idx_out] = input[idx_in]; } - template - void operator()(IterOut output, size_t idx_out, T value) { + template + void operator()(IterOutT output, size_t idx_out, ValT value) { output[idx_out] = value; } }; -template <> struct values_assigner { - template - void operator()(IterOut, size_t, IterIn, size_t) {} +template <> struct ValuesAssigner { + template + void operator()(IterOutT, size_t, IterInT, size_t) {} - template void operator()(IterOut, size_t, T) {} + template + void operator()(IterOutT, size_t, ValT) {} }; // The iteration of radix sort for unknown number of elements per work item template -void perform_radix_iter_joint(GroupT group, const uint32_t items_per_work_item, - const uint32_t radix_iter, const size_t n, - KeysT *keys_input, ValueT *vals_input, - KeysT *keys_output, ValueT *vals_output, - uint32_t *memory, CompareT comp) { +void perform_radix_iter_dynamic_size(GroupT group, + const uint32_t items_per_work_item, + const uint32_t radix_iter, const size_t n, + KeysT *keys_input, ValueT *vals_input, + KeysT *keys_output, ValueT *vals_output, + uint32_t *memory, CompareT comp) { const uint32_t radix_states = get_states_in_bits(radix_bits); const size_t wgsize = group.get_local_linear_range(); const size_t idx = group.get_local_linear_id(); constexpr bool is_comp_asc = - is_comp_ascending::type>::value; + IsCompAscending::type>::value; // 1.1. Zeroinitialize local memory @@ -481,12 +465,12 @@ void perform_radix_iter_joint(GroupT group, const uint32_t items_per_work_item, for (uint32_t i = 0; i < items_per_work_item; ++i) { const uint32_t val_idx = items_per_work_item * idx + i; - // get value, convert it to ordered (in terms of bitness) - const auto val = convert_to_ordered( - (val_idx < n) ? keys_input[val_idx] : get_default_value(comp)); + // get value, convert it to Ordered (in terms of bitness) + const auto val = convertToOrdered((val_idx < n) ? keys_input[val_idx] + : getDefaultValue(comp)); // get bit values in a certain bucket of a value const uint32_t bucket_val = - get_bucket_value(val, radix_iter); + getBucketValue(val, radix_iter); // increment counter for this bit bucket if (val_idx < n) @@ -518,18 +502,18 @@ void perform_radix_iter_joint(GroupT group, const uint32_t items_per_work_item, // 3. Reorder for (uint32_t i = 0; i < items_per_work_item; ++i) { const uint32_t val_idx = items_per_work_item * idx + i; - // get value, convert it to ordered (in terms of bitness) - auto val = convert_to_ordered((val_idx < n) ? keys_input[val_idx] - : get_default_value(comp)); + // get value, convert it to Ordered (in terms of bitness) + auto val = convertToOrdered((val_idx < n) ? 
keys_input[val_idx] + : getDefaultValue(comp)); // get bit values in a certain bucket of a value uint32_t bucket_val = - get_bucket_value(val, radix_iter); + getBucketValue(val, radix_iter); uint32_t new_offset_idx = private_scan_memory[bucket_val]++ + scan_memory[bucket_val * wgsize + idx]; if (val_idx < n) { keys_output[new_offset_idx] = keys_input[val_idx]; - values_assigner()(vals_output, new_offset_idx, + ValuesAssigner()(vals_output, new_offset_idx, vals_input, val_idx); } } @@ -560,11 +544,11 @@ void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, uint32_t *pointers[items_per_work_item] = {nullptr}; // 1.2. count values and write result to private count array for (uint32_t i = 0; i < items_per_work_item; ++i) { - // get value, convert it to ordered (in terms of bitness) - ordered_t val = convert_to_ordered(keys[i]); + // get value, convert it to Ordered (in terms of bitness) + Ordered_t val = convertToOrdered(keys[i]); // get bit values in a certain bucket of a value uint32_t bucket_val = - get_bucket_value(val, radix_iter); + getBucketValue(val, radix_iter); pointers[i] = scan_memory + (bucket_val * wgsize + idx); count_arr[i] = (*pointers[i])++; } @@ -600,7 +584,7 @@ void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, memory + wgsize * items_per_work_item * sizeof(KeysT)); for (uint32_t i = 0; i < items_per_work_item; ++i) { keys_temp[ranks[i]] = keys[i]; - values_assigner()(vals_temp, ranks[i], vals, i); + ValuesAssigner()(vals_temp, ranks[i], vals, i); } sycl::group_barrier(group); @@ -613,7 +597,7 @@ void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, shift = i * wgsize + idx; } keys[i] = keys_temp[shift]; - values_assigner()(vals, i, vals_temp, shift); + ValuesAssigner()(vals, i, vals_temp, shift); } } @@ -642,7 +626,7 @@ void private_sort(GroupT group, KeysT *keys, ValsT *values, const size_t n, keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t)); for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) { - perform_radix_iter_joint( + perform_radix_iter_dynamic_size( group, runtime_items_per_work_item, radix_iter, n, keys_input, vals_input, keys_output, vals_output, scan_memory, comp); @@ -661,7 +645,7 @@ void private_memory_sort(Group group, T *keys, U *values, Compare comp, const uint32_t last_bit) { (void)comp; constexpr bool is_comp_asc = - is_comp_ascending::type>::value; + IsCompAscending::type>::value; const uint32_t first_iter = first_bit / radix_bits; const uint32_t last_iter = last_bit / radix_bits; From 00c246424cb5b4d017c038d33a916394e800a17a Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 04:28:05 -0800 Subject: [PATCH 03/13] more support --- sycl/include/sycl/detail/group_sort_impl.hpp | 42 +++++++++---------- .../experimental/group_helpers_sorters.hpp | 4 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index 1e45287a90daa..36e39aa0e2d9f 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -259,7 +259,7 @@ template struct IsCompAscending> { }; // get number of states radix bits can represent -constexpr std::uint32_t get_states_in_bits(std::uint32_t radix_bits) { +constexpr std::uint32_t getStatesInBits(std::uint32_t radix_bits) { return (1 << radix_bits); } @@ -440,13 +440,13 @@ template <> struct ValuesAssigner { // The iteration of radix sort for unknown 
number of elements per work item template -void perform_radix_iter_dynamic_size(GroupT group, - const uint32_t items_per_work_item, - const uint32_t radix_iter, const size_t n, - KeysT *keys_input, ValueT *vals_input, - KeysT *keys_output, ValueT *vals_output, - uint32_t *memory, CompareT comp) { - const uint32_t radix_states = get_states_in_bits(radix_bits); +void performRadixIterDynamicSize(GroupT group, + const uint32_t items_per_work_item, + const uint32_t radix_iter, const size_t n, + KeysT *keys_input, ValueT *vals_input, + KeysT *keys_output, ValueT *vals_output, + uint32_t *memory, CompareT comp) { + const uint32_t radix_states = getStatesInBits(radix_bits); const size_t wgsize = group.get_local_linear_range(); const size_t idx = group.get_local_linear_id(); @@ -523,10 +523,10 @@ void perform_radix_iter_dynamic_size(GroupT group, template -void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, - const uint32_t last_iter, KeysT *keys, - ValsT vals, std::byte *memory) { - const uint32_t radix_states = get_states_in_bits(radix_bits); +void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter, + const uint32_t last_iter, KeysT *keys, + ValsT vals, std::byte *memory) { + const uint32_t radix_states = getStatesInBits(radix_bits); const size_t wgsize = group.get_local_linear_range(); const size_t idx = group.get_local_linear_id(); @@ -545,7 +545,7 @@ void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, // 1.2. count values and write result to private count array for (uint32_t i = 0; i < items_per_work_item; ++i) { // get value, convert it to Ordered (in terms of bitness) - Ordered_t val = convertToOrdered(keys[i]); + OrderedT val = convertToOrdered(keys[i]); // get bit values in a certain bucket of a value uint32_t bucket_val = getBucketValue(val, radix_iter); @@ -604,11 +604,11 @@ void perform_radix_iter_static_size(GroupT group, const uint32_t radix_iter, template -void private_sort(GroupT group, KeysT *keys, ValsT *values, const size_t n, - CompareT comp, std::byte *scratch, const uint32_t first_bit, - const uint32_t last_bit) { +void privateSort(GroupT group, KeysT *keys, ValsT *values, const size_t n, + CompareT comp, std::byte *scratch, const uint32_t first_bit, + const uint32_t last_bit) { const size_t wgsize = group.get_local_linear_range(); - constexpr uint32_t radix_states = get_states_in_bits(radix_bits); + constexpr uint32_t radix_states = getStatesInBits(radix_bits); const uint32_t first_iter = first_bit / radix_bits; const uint32_t last_iter = last_bit / radix_bits; @@ -626,7 +626,7 @@ void private_sort(GroupT group, KeysT *keys, ValsT *values, const size_t n, keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t)); for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) { - perform_radix_iter_dynamic_size( + performRadixIterDynamicSize( group, runtime_items_per_work_item, radix_iter, n, keys_input, vals_input, keys_output, vals_output, scan_memory, comp); @@ -640,7 +640,7 @@ void private_sort(GroupT group, KeysT *keys, ValsT *values, const size_t n, template -void private_memory_sort(Group group, T *keys, U *values, Compare comp, +void privateMemorySort(Group group, T *keys, U *values, Compare comp, std::byte *scratch, const uint32_t first_bit, const uint32_t last_bit) { (void)comp; @@ -650,8 +650,8 @@ void private_memory_sort(Group group, T *keys, U *values, Compare comp, const uint32_t last_iter = last_bit / radix_bits; for (uint32_t radix_iter = first_iter; radix_iter < 
last_iter; ++radix_iter) { - perform_radix_iter_static_size( + performRadixIterStaticSize( group, radix_iter, last_iter, keys, values, scratch); sycl::group_barrier(group); } diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index 54283129f9313..ac4251667ae6e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -146,7 +146,7 @@ class radix_sorter { (void)first; (void)last; #ifdef __SYCL_DEVICE_ONLY__ - sycl::detail::private_sort( g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0, typename detail::ConvertToComp::Type{}, scratch, @@ -159,7 +159,7 @@ class radix_sorter { (void)val; #ifdef __SYCL_DEVICE_ONLY__ ValT result[]{val}; - sycl::detail::private_memory_sort( g, result, /*empty*/ result, From dc38d9a5705f9899c769094cc6cb8c1cfc749997 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 04:39:27 -0800 Subject: [PATCH 04/13] more support --- sycl/include/sycl/detail/group_sort_impl.hpp | 11 +++++------ .../ext/oneapi/experimental/group_helpers_sorters.hpp | 8 +++----- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index 36e39aa0e2d9f..271a7b5ad915c 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -319,7 +319,7 @@ template struct Ordered {}; // for unsigned integrals we use the same type template struct Ordered::value && - std::is_unsigned::value>> { + std::is_unsigned::value>> { using Type = ValT; }; @@ -393,8 +393,7 @@ template <> struct InvertIf { // get bit values in a certain bucket of a value template -std::uint32_t -getBucketValue(ValT value, std::uint32_t radix_iter) { +std::uint32_t getBucketValue(ValT value, std::uint32_t radix_iter) { // invert value if we need to sort in descending order value = InvertIf{}(value); @@ -514,7 +513,7 @@ void performRadixIterDynamicSize(GroupT group, if (val_idx < n) { keys_output[new_offset_idx] = keys_input[val_idx]; ValuesAssigner()(vals_output, new_offset_idx, - vals_input, val_idx); + vals_input, val_idx); } } } @@ -641,8 +640,8 @@ template void privateMemorySort(Group group, T *keys, U *values, Compare comp, - std::byte *scratch, const uint32_t first_bit, - const uint32_t last_bit) { + std::byte *scratch, const uint32_t first_bit, + const uint32_t last_bit) { (void)comp; constexpr bool is_comp_asc = IsCompAscending::type>::value; diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index ac4251667ae6e..63ba85e93699e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -108,7 +108,6 @@ template struct ConvertToComp { }; } // namespace detail - template class radix_sorter { @@ -146,8 +145,7 @@ class radix_sorter { (void)first; (void)last; #ifdef __SYCL_DEVICE_ONLY__ - sycl::detail::privateSort( + sycl::detail::privateSort( g, first, /*empty*/ first, (last - first) > 0 ? 
(last - first) : 0, typename detail::ConvertToComp::Type{}, scratch, first_bit, last_bit); @@ -160,8 +158,8 @@ class radix_sorter { #ifdef __SYCL_DEVICE_ONLY__ ValT result[]{val}; sycl::detail::privateMemorySort( + /*is_blocked=*/true, + /*items_per_work_item=*/1, bits>( g, result, /*empty*/ result, typename detail::ConvertToComp::Type{}, scratch, first_bit, last_bit); From 0d81f5785e7ca080420d677d0c5a22ee10e1f546 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 04:43:28 -0800 Subject: [PATCH 05/13] more support --- sycl/include/sycl/sycl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp index 9a0cbd211d76d..14b9361714e49 100644 --- a/sycl/include/sycl/sycl.hpp +++ b/sycl/include/sycl/sycl.hpp @@ -26,11 +26,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include From c1ac89b318c9d1781970816c9fcbdae57ad6d8e4 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 05:26:03 -0800 Subject: [PATCH 06/13] Fix --- sycl/include/sycl/detail/group_sort_impl.hpp | 1 + .../sycl/ext/oneapi/experimental/group_helpers_sorters.hpp | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index 271a7b5ad915c..e99b5b4e39cd2 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #ifdef __SYCL_DEVICE_ONLY__ diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index 63ba85e93699e..a1dffbce993ef 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -11,7 +11,6 @@ #if !defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0 #include #include -#include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { From ce8b972de87763510e7b0c20764b3fd069bcd3b7 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 05:26:22 -0800 Subject: [PATCH 07/13] Fix from another PR --- sycl/unittests/helpers/PiImage.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sycl/unittests/helpers/PiImage.hpp b/sycl/unittests/helpers/PiImage.hpp index 9038bca859077..e1dbec3edfb56 100644 --- a/sycl/unittests/helpers/PiImage.hpp +++ b/sycl/unittests/helpers/PiImage.hpp @@ -157,6 +157,14 @@ template class PiArray { bool MEntriesNeedUpdate = false; }; +#ifdef __cpp_deduction_guides +template +PiArray(std::vector) -> PiArray; + +template +PiArray(std::initializer_list) -> PiArray; +#endif // __cpp_deduction_guides + /// Convenience wrapper for pi_device_binary_property_set. class PiPropertySet { public: From 8ac386148c3dc0482e578d01b12fdd23096a0807 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 05:39:34 -0800 Subject: [PATCH 08/13] Revert "Fix from another PR" This reverts commit ce8b972de87763510e7b0c20764b3fd069bcd3b7. 
--- sycl/unittests/helpers/PiImage.hpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sycl/unittests/helpers/PiImage.hpp b/sycl/unittests/helpers/PiImage.hpp index e1dbec3edfb56..9038bca859077 100644 --- a/sycl/unittests/helpers/PiImage.hpp +++ b/sycl/unittests/helpers/PiImage.hpp @@ -157,14 +157,6 @@ template class PiArray { bool MEntriesNeedUpdate = false; }; -#ifdef __cpp_deduction_guides -template -PiArray(std::vector) -> PiArray; - -template -PiArray(std::initializer_list) -> PiArray; -#endif // __cpp_deduction_guides - /// Convenience wrapper for pi_device_binary_property_set. class PiPropertySet { public: From c9c1d907a843a7f0017aa3a6e93ad40e2a612289 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 6 Dec 2022 05:50:51 -0800 Subject: [PATCH 09/13] more support --- sycl/include/sycl/detail/group_sort_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index e99b5b4e39cd2..b9680534f4359 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -11,8 +11,8 @@ #pragma once #include -#include #include +#include #include #ifdef __SYCL_DEVICE_ONLY__ From 1d02d00f4a9ed9e97365213136a2ac7dd92d2936 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Fri, 9 Dec 2022 07:43:12 -0800 Subject: [PATCH 10/13] apply comments --- sycl/include/sycl/detail/group_sort_impl.hpp | 16 +++++++++++----- .../experimental/group_helpers_sorters.hpp | 11 +++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index b9680534f4359..f3e3ae70fcaff 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -11,6 +11,7 @@ #pragma once #include +#include #include #include #include @@ -327,14 +328,17 @@ struct Ordered::value && // for signed integrals or floatings we map: size -> corresponding unsigned // integral template -struct Ordered::value && - std::is_signed::value) || - !std::is_integral::value>> { +struct Ordered< + ValT, std::enable_if_t< + (std::is_integral::value && std::is_signed::value) || + std::is_floating_point::value || + std::is_same::value || + std::is_same::value>> { using Type = typename GetOrdered::value>::Type; }; -// shorthands +// shorthand template using OrderedT = typename Ordered::Type; //------------------------------------------------------------------------ @@ -361,7 +365,9 @@ convertToOrdered(ValT value) { // converts floating type to Ordered (in terms of bitness) type template std::enable_if_t>::value && - !std::is_integral::value, + std::is_floating_point::value || + std::is_same::value || + std::is_same::value, OrderedT> convertToOrdered(ValT value) { OrderedT uvalue = *reinterpret_cast *>(&value); diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index 7d43ce52e3a1f..48882357f6e6a 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -126,7 +126,8 @@ class radix_sorter { std::numeric_limits::max())) : scratch(scratch_.data()), scratch_size(scratch_.size()) { static_assert((std::is_arithmetic::value || - std::is_same::value), + std::is_same::value || + std::is_same::value), "radix sort is not usable"); first_bit = 0; @@ -148,6 +149,10 @@ 
class radix_sorter { g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0, typename detail::ConvertToComp::Type{}, scratch, first_bit, last_bit); +#else + throw sycl::exception( + std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()), + "radix_sorter is not supported on host device."); #endif } @@ -164,7 +169,9 @@ class radix_sorter { first_bit, last_bit); return result[0]; #else - return ValT{}; + throw sycl::exception( + std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()), + "radix_sorter is not supported on host device."); #endif } From ab824b6d8383ca65aa186085f6e364b843266700 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Fri, 9 Dec 2022 08:51:36 -0800 Subject: [PATCH 11/13] apply comments --- sycl/include/sycl/detail/group_sort_impl.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index f3e3ae70fcaff..59eda28bf7788 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -365,9 +365,9 @@ convertToOrdered(ValT value) { // converts floating type to Ordered (in terms of bitness) type template std::enable_if_t>::value && - std::is_floating_point::value || - std::is_same::value || - std::is_same::value, + (std::is_floating_point::value || + std::is_same::value || + std::is_same::value), OrderedT> convertToOrdered(ValT value) { OrderedT uvalue = *reinterpret_cast *>(&value); From 07d0ed9a278c990289aa54f18af3e93add4ac447 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Sun, 11 Dec 2022 10:46:51 -0800 Subject: [PATCH 12/13] Addres comments --- sycl/include/sycl/detail/group_sort_impl.hpp | 29 +++++++++---------- .../experimental/group_helpers_sorters.hpp | 6 ++-- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index 59eda28bf7788..a6afc3c2ca2a2 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -320,22 +320,21 @@ template struct Ordered {}; // for unsigned integrals we use the same type template -struct Ordered::value && - std::is_unsigned::value>> { +struct Ordered && + std::is_unsigned_v>> { using Type = ValT; }; // for signed integrals or floatings we map: size -> corresponding unsigned // integral template -struct Ordered< - ValT, std::enable_if_t< - (std::is_integral::value && std::is_signed::value) || - std::is_floating_point::value || - std::is_same::value || - std::is_same::value>> { +struct Ordered && std::is_signed_v) || + std::is_floating_point_v || + std::is_same_v || + std::is_same_v>> { using Type = - typename GetOrdered::value>::Type; + typename GetOrdered>::Type; }; // shorthand @@ -354,8 +353,8 @@ convertToOrdered(ValT value) { // converts integral type to Ordered (in terms of bitness) type template -std::enable_if_t>::value && - std::is_integral::value, +std::enable_if_t> && + std::is_integral_v, OrderedT> convertToOrdered(ValT value) { ValT result = value ^ GetOrdered::mask; @@ -364,10 +363,10 @@ convertToOrdered(ValT value) { // converts floating type to Ordered (in terms of bitness) type template -std::enable_if_t>::value && - (std::is_floating_point::value || - std::is_same::value || - std::is_same::value), +std::enable_if_t> && + (std::is_floating_point_v || + std::is_same_v || + std::is_same_v), OrderedT> convertToOrdered(ValT value) { OrderedT uvalue = *reinterpret_cast *>(&value); diff 
--git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp index fd10c223e841e..1c1f5988dd83a 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp @@ -123,9 +123,9 @@ class radix_sorter { std::bitset( std::numeric_limits::max())) : scratch(scratch_.data()), scratch_size(scratch_.size()) { - static_assert((std::is_arithmetic::value || - std::is_same::value || - std::is_same::value), + static_assert((std::is_arithmetic_v || + std::is_same_v || + std::is_same_v), "radix sort is not usable"); first_bit = 0; From 7a6956ff3ddf89a82cac11022b43252b02bead96 Mon Sep 17 00:00:00 2001 From: "Romanov, Vlad" Date: Tue, 13 Dec 2022 04:43:00 -0800 Subject: [PATCH 13/13] Address comments --- sycl/include/sycl/detail/group_sort_impl.hpp | 239 +++++++++--------- .../experimental/group_helpers_sorters.hpp | 44 ++-- .../ext/oneapi/experimental/group_sort.hpp | 8 +- 3 files changed, 141 insertions(+), 150 deletions(-) diff --git a/sycl/include/sycl/detail/group_sort_impl.hpp b/sycl/include/sycl/detail/group_sort_impl.hpp index a6afc3c2ca2a2..5b9d735d60445 100644 --- a/sycl/include/sycl/detail/group_sort_impl.hpp +++ b/sycl/include/sycl/detail/group_sort_impl.hpp @@ -27,11 +27,11 @@ namespace detail { // following two functions could be useless if std::[lower|upper]_bound worked // well template -std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last, - const Value &value, Compare comp) { - std::size_t n = last - first; - std::size_t cur = n; - std::size_t it; +size_t lower_bound(Acc acc, size_t first, size_t last, const Value &value, + Compare comp) { + size_t n = last - first; + size_t cur = n; + size_t it; while (n > 0) { it = first; cur = n / 2; @@ -45,9 +45,8 @@ std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last, } template -std::size_t upper_bound(Acc acc, const std::size_t first, - const std::size_t last, const Value &value, - Compare comp) { +size_t upper_bound(Acc acc, const size_t first, const size_t last, + const Value &value, Compare comp) { return detail::lower_bound(acc, first, last, value, [comp](auto x, auto y) { return !comp(y, x); }); } @@ -74,7 +73,7 @@ struct GetValueType> { // since we couldn't assign data to raw memory, it's better to use placement // for first assignment template -void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) { +void set_value(Acc ptr, const size_t idx, const T &val, bool is_first) { if (is_first) { ::new (ptr + idx) T(val); } else { @@ -83,23 +82,23 @@ void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) { } template -void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, - const std::size_t start_1, const std::size_t end_1, - const std::size_t end_2, const std::size_t start_out, Compare comp, - const std::size_t chunk, bool is_first) { - const std::size_t start_2 = end_1; +void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1, + const size_t start_1, const size_t end_1, const size_t end_2, + const size_t start_out, Compare comp, const size_t chunk, + bool is_first) { + const size_t start_2 = end_1; // Borders of the sequences to merge within this call - const std::size_t local_start_1 = - sycl::min(static_cast(offset + start_1), end_1); - const std::size_t local_end_1 = - sycl::min(static_cast(local_start_1 + chunk), end_1); - const std::size_t local_start_2 = - 
sycl::min(static_cast(offset + start_2), end_2); - const std::size_t local_end_2 = - sycl::min(static_cast(local_start_2 + chunk), end_2); - - const std::size_t local_size_1 = local_end_1 - local_start_1; - const std::size_t local_size_2 = local_end_2 - local_start_2; + const size_t local_start_1 = + sycl::min(static_cast(offset + start_1), end_1); + const size_t local_end_1 = + sycl::min(static_cast(local_start_1 + chunk), end_1); + const size_t local_start_2 = + sycl::min(static_cast(offset + start_2), end_2); + const size_t local_end_2 = + sycl::min(static_cast(local_start_2 + chunk), end_2); + + const size_t local_size_1 = local_end_1 - local_start_1; + const size_t local_size_2 = local_end_2 - local_start_2; // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st // to improve performance @@ -109,15 +108,15 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, // Reduce the range for searching within the 2nd sequence and handle bound // items find left border in 2nd sequence const auto local_l_item_1 = in_acc1[local_start_1]; - std::size_t l_search_bound_2 = + size_t l_search_bound_2 = detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp); - const std::size_t l_shift_1 = local_start_1 - start_1; - const std::size_t l_shift_2 = l_search_bound_2 - start_2; + const size_t l_shift_1 = local_start_1 - start_1; + const size_t l_shift_2 = l_search_bound_2 - start_2; set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1, is_first); - std::size_t r_search_bound_2{}; + size_t r_search_bound_2{}; // find right border in 2nd sequence if (local_size_1 > 1) { const auto local_r_item_1 = in_acc1[local_end_1 - 1]; @@ -131,15 +130,15 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, } // Handle intermediate items - for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) { + for (size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) { const auto intermediate_item_1 = in_acc1[idx]; // we shouldn't seek in whole 2nd sequence. 
Just for the part where the // 1st sequence should be l_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2, intermediate_item_1, comp); - const std::size_t shift_1 = idx - start_1; - const std::size_t shift_2 = l_search_bound_2 - start_2; + const size_t shift_1 = idx - start_1; + const size_t shift_2 = l_search_bound_2 - start_2; set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1, is_first); @@ -150,22 +149,22 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, // Reduce the range for searching within the 1st sequence and handle bound // items find left border in 1st sequence const auto local_l_item_2 = in_acc1[local_start_2]; - std::size_t l_search_bound_1 = + size_t l_search_bound_1 = detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp); - const std::size_t l_shift_1 = l_search_bound_1 - start_1; - const std::size_t l_shift_2 = local_start_2 - start_2; + const size_t l_shift_1 = l_search_bound_1 - start_1; + const size_t l_shift_2 = local_start_2 - start_2; set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2, is_first); - std::size_t r_search_bound_1{}; + size_t r_search_bound_1{}; // find right border in 1st sequence if (local_size_2 > 1) { const auto local_r_item_2 = in_acc1[local_end_2 - 1]; r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1, local_r_item_2, comp); - const std::size_t r_shift_1 = r_search_bound_1 - start_1; - const std::size_t r_shift_2 = local_end_2 - 1 - start_2; + const size_t r_shift_1 = r_search_bound_1 - start_1; + const size_t r_shift_2 = local_end_2 - 1 - start_2; set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2, is_first); @@ -179,8 +178,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, l_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1, intermediate_item_2, comp); - const std::size_t shift_1 = l_search_bound_1 - start_1; - const std::size_t shift_2 = idx - start_2; + const size_t shift_1 = l_search_bound_1 - start_1; + const size_t shift_2 = idx - start_2; set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2, is_first); @@ -189,12 +188,12 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1, } template -void bubble_sort(Iter first, const std::size_t begin, const std::size_t end, +void bubble_sort(Iter first, const size_t begin, const size_t end, Compare comp) { if (begin < end) { - for (std::size_t i = begin; i < end; ++i) { + for (size_t i = begin; i < end; ++i) { // Handle intermediate items - for (std::size_t idx = i + 1; idx < end; ++idx) { + for (size_t idx = i + 1; idx < end; ++idx) { if (comp(first[idx], first[i])) { detail::swap_tuples(first[i], first[idx]); } @@ -204,12 +203,12 @@ void bubble_sort(Iter first, const std::size_t begin, const std::size_t end, } template -void merge_sort(Group group, Iter first, const std::size_t n, Compare comp, +void merge_sort(Group group, Iter first, const size_t n, Compare comp, std::byte *scratch) { using T = typename GetValueType::type; - const std::size_t idx = group.get_local_linear_id(); - const std::size_t local = group.get_local_range().size(); - const std::size_t chunk = (n - 1) / local + 1; + const size_t idx = group.get_local_linear_id(); + const size_t local = group.get_local_range().size(); + const size_t chunk = (n - 1) / local + 1; // we need to sort within work item first bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp); @@ -218,13 +217,13 
@@ -204,12 +203,12 @@ void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,
 }
 
 template <typename Group, typename Iter, typename Compare>
-void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
+void merge_sort(Group group, Iter first, const size_t n, Compare comp,
                 std::byte *scratch) {
   using T = typename GetValueType<Iter>::type;
-  const std::size_t idx = group.get_local_linear_id();
-  const std::size_t local = group.get_local_range().size();
-  const std::size_t chunk = (n - 1) / local + 1;
+  const size_t idx = group.get_local_linear_id();
+  const size_t local = group.get_local_range().size();
+  const size_t chunk = (n - 1) / local + 1;
 
   // we need to sort within work item first
   bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
@@ -218,13 +217,13 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
   T *temp = reinterpret_cast<T *>(scratch);
   bool data_in_temp = false;
   bool is_first = true;
-  std::size_t sorted_size = 1;
+  size_t sorted_size = 1;
   while (sorted_size * chunk < n) {
-    const std::size_t start_1 =
+    const size_t start_1 =
         sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
-    const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
-    const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
-    const std::size_t offset = chunk * (idx % sorted_size);
+    const size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
+    const size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
+    const size_t offset = chunk * (idx % sorted_size);
 
     if (!data_in_temp) {
       merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
@@ -243,7 +242,7 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
 
   // copy back if data is in a temporary storage
   if (data_in_temp) {
-    for (std::size_t i = 0; i < chunk; ++i) {
+    for (size_t i = 0; i < chunk; ++i) {
       if (idx * chunk + i < n) {
         first[idx * chunk + i] = temp[idx * chunk + i];
       }
@@ -261,7 +260,7 @@ template <typename T> struct IsCompAscending<std::less<T>> {
 };
 
 // get number of states radix bits can represent
-constexpr std::uint32_t getStatesInBits(std::uint32_t radix_bits) {
+constexpr uint32_t getStatesInBits(uint32_t radix_bits) {
   return (1 << radix_bits);
 }
 
@@ -269,46 +268,44 @@ constexpr std::uint32_t getStatesInBits(std::uint32_t radix_bits) {
 // Ordered traits for a given size and integral/float flag
 //------------------------------------------------------------------------
 
-template <std::size_t size, bool is_integral> struct GetOrdered {};
+template <size_t size, bool is_integral> struct GetOrdered {};
 
 template <> struct GetOrdered<1, true> {
   using Type = uint8_t;
-  constexpr static std::int8_t mask = 0x80;
+  constexpr static int8_t mask = 0x80;
 };
 
 template <> struct GetOrdered<2, true> {
   using Type = uint16_t;
-  constexpr static std::int16_t mask = 0x8000;
+  constexpr static int16_t mask = 0x8000;
 };
 
 template <> struct GetOrdered<4, true> {
   using Type = uint32_t;
-  constexpr static std::int32_t mask = 0x80000000;
+  constexpr static int32_t mask = 0x80000000;
 };
 
 template <> struct GetOrdered<8, true> {
   using Type = uint64_t;
-  constexpr static std::int64_t mask = 0x8000000000000000;
+  constexpr static int64_t mask = 0x8000000000000000;
 };
 
 template <> struct GetOrdered<2, false> {
   using Type = uint16_t;
-  constexpr static std::uint32_t nmask = 0xFFFF; // for negative numbers
-  constexpr static std::uint32_t pmask = 0x8000; // for positive numbers
+  constexpr static uint32_t nmask = 0xFFFF; // for negative numbers
+  constexpr static uint32_t pmask = 0x8000; // for positive numbers
 };
 
 template <> struct GetOrdered<4, false> {
   using Type = uint32_t;
-  constexpr static std::uint32_t nmask = 0xFFFFFFFF; // for negative numbers
-  constexpr static std::uint32_t pmask = 0x80000000; // for positive numbers
+  constexpr static uint32_t nmask = 0xFFFFFFFF; // for negative numbers
+  constexpr static uint32_t pmask = 0x80000000; // for positive numbers
 };
 
 template <> struct GetOrdered<8, false> {
   using Type = uint64_t;
-  constexpr static std::uint64_t nmask =
-      0xFFFFFFFFFFFFFFFF; // for negative numbers
-  constexpr static std::uint64_t pmask =
-      0x8000000000000000; // for positive numbers
+  constexpr static uint64_t nmask = 0xFFFFFFFFFFFFFFFF; // for negative numbers
+  constexpr static uint64_t pmask = 0x8000000000000000; // for positive numbers
 };
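The nmask/pmask pairs above encode the usual order-preserving transform for
IEEE floats: a negative value has all of its bits flipped, a positive value
only its sign bit, after which plain unsigned comparison matches the
floating-point order (NaNs aside). A host-side sketch for the 32-bit case, not
part of the patch:

    #include <cstdint>
    #include <cstring>

    std::uint32_t to_ordered(float f) {
      std::uint32_t u;
      std::memcpy(&u, &f, sizeof u); // bit-cast without aliasing UB
      // negative -> XOR nmask (flip all bits); positive -> XOR pmask (sign).
      return (u & 0x80000000u) ? u ^ 0xFFFFFFFFu : u ^ 0x80000000u;
    }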
 
 //------------------------------------------------------------------------
@@ -320,21 +317,22 @@ template <typename ValT, typename Enabled = void> struct Ordered {};
 
 // for unsigned integrals we use the same type
 template <typename ValT>
-struct Ordered<ValT, std::enable_if_t<std::is_integral_v<ValT> &&
-                                      std::is_unsigned_v<ValT>>> {
+struct Ordered<ValT, std::enable_if_t<std::is_integral<ValT>::value &&
+                                      std::is_unsigned<ValT>::value>> {
   using Type = ValT;
 };
 
 // for signed integrals or floatings we map: size -> corresponding unsigned
 // integral
 template <typename ValT>
-struct Ordered<ValT, std::enable_if_t<(std::is_integral_v<ValT> &&
-                                       std::is_signed_v<ValT>) ||
-                                      std::is_floating_point_v<ValT> ||
-                                      std::is_same_v<ValT, sycl::half> ||
-                                      std::is_same_v<ValT, sycl::ext::oneapi::bfloat16>>> {
+struct Ordered<
+    ValT, std::enable_if_t<
+              (std::is_integral<ValT>::value && std::is_signed<ValT>::value) ||
+              std::is_floating_point<ValT>::value ||
+              std::is_same<ValT, sycl::half>::value ||
+              std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value>> {
   using Type =
-      typename GetOrdered<sizeof(ValT), std::is_integral_v<ValT>>::Type;
+      typename GetOrdered<sizeof(ValT), std::is_integral<ValT>::value>::Type;
 };
 
 // shorthand
@@ -353,8 +351,8 @@ convertToOrdered(ValT value) {
 
 // converts integral type to Ordered (in terms of bitness) type
 template <typename ValT>
-std::enable_if_t<!std::is_same_v<ValT, OrderedT<ValT>> &&
-                     std::is_integral_v<ValT>,
+std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
+                     std::is_integral<ValT>::value,
                  OrderedT<ValT>>
 convertToOrdered(ValT value) {
   ValT result = value ^ GetOrdered<sizeof(ValT), true>::mask;
@@ -363,10 +361,10 @@ convertToOrdered(ValT value) {
 
 // converts floating type to Ordered (in terms of bitness) type
 template <typename ValT>
-std::enable_if_t<!std::is_same_v<ValT, OrderedT<ValT>> &&
-                     (std::is_floating_point_v<ValT> ||
-                      std::is_same_v<ValT, sycl::half> ||
-                      std::is_same_v<ValT, sycl::ext::oneapi::bfloat16>),
+std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
+                     (std::is_floating_point<ValT>::value ||
+                      std::is_same<ValT, sycl::half>::value ||
+                      std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
                  OrderedT<ValT>>
 convertToOrdered(ValT value) {
   OrderedT<ValT> uvalue = *reinterpret_cast<OrderedT<ValT> *>(&value);
@@ -398,13 +396,13 @@ template <> struct InvertIf<true> {
 };
 
 // get bit values in a certain bucket of a value
-template <std::uint32_t radix_bits, bool is_comp_asc, typename ValT>
-std::uint32_t getBucketValue(ValT value, std::uint32_t radix_iter) {
+template <uint32_t radix_bits, bool is_comp_asc, typename ValT>
+uint32_t getBucketValue(ValT value, uint32_t radix_iter) {
   // invert value if we need to sort in descending order
   value = InvertIf<!is_comp_asc>{}(value);
 
   // get bucket offset idx from the end of bit type (least significant bits)
-  std::uint32_t bucket_offset = radix_iter * radix_bits;
+  uint32_t bucket_offset = radix_iter * radix_bits;
 
   // get offset mask for one bucket, e.g.
   // radix_bits=2: 0000 0001 -> 0000 0100 -> 0000 0011
@@ -413,12 +411,11 @@ std::uint32_t getBucketValue(ValT value, std::uint32_t radix_iter) {
   // get bits under bucket mask
   return (value >> bucket_offset) & bucket_mask;
 }
-template <typename ValT> ValT getDefaultValue(std::less<ValT>) {
-  return std::numeric_limits<ValT>::max();
-}
-
-template <typename ValT> ValT getDefaultValue(std::greater<ValT>) {
-  return std::numeric_limits<ValT>::lowest();
+template <typename ValT> ValT getDefaultValue(bool is_comp_asc) {
+  if (is_comp_asc)
+    return std::numeric_limits<ValT>::max();
+  else
+    return std::numeric_limits<ValT>::lowest();
 }
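getBucketValue above selects one radix_bits-wide digit, counting digits from
the least significant bits upward. The equivalent standalone computation
(illustrative names, not part of the patch):

    #include <cstdint>

    std::uint32_t bucket_of(std::uint32_t ordered_key,
                            std::uint32_t radix_iter) {
      constexpr std::uint32_t radix_bits = 4; // 16 radix states per pass
      const std::uint32_t offset = radix_iter * radix_bits;
      const std::uint32_t mask = (1u << radix_bits) - 1; // 0b1111
      return (ordered_key >> offset) & mask;
    }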
 
 template <bool is_key_value_sort> struct ValuesAssigner {
@@ -443,23 +440,19 @@ template <> struct ValuesAssigner<false> {
 };
 
 // The iteration of radix sort for unknown number of elements per work item
-template <bool is_key_value_sort, std::uint32_t radix_bits, typename KeysT,
-          typename ValueT, typename GroupT, typename CompareT>
+template <bool is_key_value_sort, bool is_comp_asc, uint32_t radix_bits,
+          typename KeysT, typename ValueT, typename GroupT>
 void performRadixIterDynamicSize(GroupT group,
                                  const uint32_t items_per_work_item,
                                  const uint32_t radix_iter, const size_t n,
                                  KeysT *keys_input, ValueT *vals_input,
                                  KeysT *keys_output, ValueT *vals_output,
-                                 uint32_t *memory, CompareT comp) {
+                                 uint32_t *memory) {
   const uint32_t radix_states = getStatesInBits(radix_bits);
   const size_t wgsize = group.get_local_linear_range();
   const size_t idx = group.get_local_linear_id();
 
-  constexpr bool is_comp_asc =
-      IsCompAscending<typename std::decay<CompareT>::type>::value;
-
   // 1.1. Zero-initialize local memory
-
   uint32_t *scan_memory = reinterpret_cast<uint32_t *>(memory);
   for (uint32_t state = 0; state < radix_states; ++state)
     scan_memory[state * wgsize + idx] = 0;
 
@@ -467,12 +460,12 @@ void performRadixIterDynamicSize(GroupT group,
   sycl::group_barrier(group);
 
   // 1.2. count values and write result to private count array and count memory
-
   for (uint32_t i = 0; i < items_per_work_item; ++i) {
     const uint32_t val_idx = items_per_work_item * idx + i;
     // get value, convert it to Ordered (in terms of bitness)
-    const auto val = convertToOrdered((val_idx < n) ? keys_input[val_idx]
-                                                    : getDefaultValue(comp));
+    const auto val =
+        convertToOrdered((val_idx < n) ? keys_input[val_idx]
+                                       : getDefaultValue<KeysT>(is_comp_asc));
 
     // get bit values in a certain bucket of a value
     const uint32_t bucket_val =
        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
@@ -508,8 +501,9 @@ void performRadixIterDynamicSize(GroupT group,
   for (uint32_t i = 0; i < items_per_work_item; ++i) {
     const uint32_t val_idx = items_per_work_item * idx + i;
     // get value, convert it to Ordered (in terms of bitness)
-    auto val = convertToOrdered((val_idx < n) ? keys_input[val_idx]
-                                              : getDefaultValue(comp));
+    auto val =
+        convertToOrdered((val_idx < n) ? keys_input[val_idx]
+                                       : getDefaultValue<KeysT>(is_comp_asc));
 
     // get bit values in a certain bucket of a value
     uint32_t bucket_val =
        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
@@ -525,9 +519,9 @@ void performRadixIterDynamicSize(GroupT group,
 }
 
 // The iteration of radix sort for known number of elements per work item
-template <bool is_key_value_sort, bool is_blocked, std::uint32_t radix_bits,
-          typename KeysT, typename ValsT, typename GroupT, typename CompareT>
+template <bool is_key_value_sort, bool is_comp_asc, bool is_blocked,
+          uint32_t radix_bits, typename KeysT, typename ValsT, typename GroupT>
 void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
                                 const uint32_t last_iter, KeysT *keys,
                                 ValsT vals, std::byte *memory) {
@@ -585,7 +579,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
 
   // 3. Reorder
   KeysT *keys_temp = reinterpret_cast<KeysT *>(memory);
-  ValsT vals_temp = reinterpret_cast<ValsT>(
+  ValsT *vals_temp = reinterpret_cast<ValsT *>(
       memory + wgsize * items_per_work_item * sizeof(KeysT));
   for (uint32_t i = 0; i < items_per_work_item; ++i) {
     keys_temp[ranks[i]] = keys[i];
@@ -596,7 +590,7 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
 
   // 4. Copy back to input
   for (uint32_t i = 0; i < items_per_work_item; ++i) {
-    std::size_t shift = idx * items_per_work_item + i;
+    size_t shift = idx * items_per_work_item + i;
     if constexpr (!is_blocked) {
       if (radix_iter == last_iter - 1)
         shift = i * wgsize + idx;
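Both perform* functions follow the classic counting-sort structure per digit:
histogram, exclusive scan into start offsets, stable scatter. A serial
single-threaded equivalent (illustrative, not part of the patch) may make the
phases the work-group versions distribute across items easier to follow:

    #include <cstdint>
    #include <vector>

    void radix_pass(std::vector<std::uint32_t> &keys,
                    std::uint32_t radix_iter) {
      constexpr std::uint32_t radix_bits = 4;
      constexpr std::uint32_t radix_states = 1u << radix_bits;
      auto digit = [&](std::uint32_t k) {
        return (k >> (radix_iter * radix_bits)) & (radix_states - 1u);
      };
      std::vector<std::uint32_t> offset(radix_states, 0);
      for (std::uint32_t k : keys) // 1. count digit occurrences
        ++offset[digit(k)];
      std::uint32_t sum = 0;       // 2. exclusive scan -> start offsets
      for (std::uint32_t &c : offset) {
        const std::uint32_t tmp = c;
        c = sum;
        sum += tmp;
      }
      std::vector<std::uint32_t> out(keys.size());
      for (std::uint32_t k : keys) // 3. stable reorder by current digit
        out[offset[digit(k)]++] = k;
      keys.swap(out);
    }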
@@ -606,12 +600,12 @@ void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
   }
 }
 
-template <bool is_key_value_sort, std::uint32_t radix_bits, typename GroupT,
-          typename KeysT, typename ValsT, typename CompareT>
-void privateSort(GroupT group, KeysT *keys, ValsT *values, const size_t n,
-                 CompareT comp, std::byte *scratch, const uint32_t first_bit,
-                 const uint32_t last_bit) {
+template <bool is_key_value_sort, bool is_comp_asc, uint32_t radix_bits,
+          typename GroupT, typename KeysT, typename ValsT>
+void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
+                        const size_t n, std::byte *scratch,
+                        const uint32_t first_bit, const uint32_t last_bit) {
   const size_t wgsize = group.get_local_linear_range();
   constexpr uint32_t radix_states = getStatesInBits(radix_bits);
   const uint32_t first_iter = first_bit / radix_bits;
@@ -631,9 +625,9 @@ void privateSort(GroupT group, KeysT *keys, ValsT *values, const size_t n,
       keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t));
 
   for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
-    performRadixIterDynamicSize<is_key_value_sort, radix_bits>(
+    performRadixIterDynamicSize<is_key_value_sort, is_comp_asc, radix_bits>(
         group, runtime_items_per_work_item, radix_iter, n, keys_input,
-        vals_input, keys_output, vals_output, scan_memory, comp);
+        vals_input, keys_output, vals_output, scan_memory);
 
     sycl::group_barrier(group);
 
@@ -642,15 +636,12 @@ void privateSort(GroupT group, KeysT *keys, ValsT *values, const size_t n,
   }
 }
 
-template <bool is_key_value_sort, bool is_blocked, std::uint32_t radix_bits,
-          typename Group, typename T, typename U, typename Compare>
-void privateMemorySort(Group group, T *keys, U *values, Compare comp,
-                       std::byte *scratch, const uint32_t first_bit,
-                       const uint32_t last_bit) {
-  (void)comp;
-  constexpr bool is_comp_asc =
-      IsCompAscending<typename std::decay<Compare>::type>::value;
+template <bool is_key_value_sort, bool is_comp_asc, bool is_blocked,
+          uint32_t radix_bits, typename GroupT, typename T, typename U>
+void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
+                       const uint32_t first_bit, const uint32_t last_bit) {
+
   const uint32_t first_iter = first_bit / radix_bits;
   const uint32_t last_iter = last_bit / radix_bits;
 
diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
index 1c1f5988dd83a..7a7d5283bd8ec 100644
--- a/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
+++ b/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp
@@ -17,7 +17,7 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
 namespace ext::oneapi::experimental {
 
 // ---- group helpers
-template <typename Group, std::size_t Extent> class group_with_scratchpad {
+template <typename Group, size_t Extent> class group_with_scratchpad {
   Group g;
   sycl::span<std::byte, Extent> scratch;
 
@@ -32,10 +32,10 @@ template <typename Group, std::size_t Extent> class group_with_scratchpad {
 template <typename Compare = std::less<>> class default_sorter {
   Compare comp;
   std::byte *scratch;
-  std::size_t scratch_size;
+  size_t scratch_size;
 
 public:
-  template <std::size_t Extent>
+  template <size_t Extent>
   default_sorter(sycl::span<std::byte, Extent> scratch_,
                  Compare comp_ = Compare())
       : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
 
@@ -61,7 +61,7 @@ template <typename Compare = std::less<>> class default_sorter {
 #ifdef __SYCL_DEVICE_ONLY__
     auto range_size = g.get_local_range().size();
     if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
-      std::size_t local_id = g.get_local_linear_id();
+      size_t local_id = g.get_local_linear_id();
       T *temp = reinterpret_cast<T *>(scratch);
       ::new (temp + local_id) T(val);
       sycl::detail::merge_sort(g, temp, range_size, comp,
@@ -79,14 +79,14 @@ template <typename Compare = std::less<>> class default_sorter {
   }
 
   template <typename T>
-  static constexpr std::size_t memory_required(sycl::memory_scope,
-                                               std::size_t range_size) {
+  static constexpr size_t memory_required(sycl::memory_scope,
+                                          size_t range_size) {
     return range_size * sizeof(T) + alignof(T);
   }
 
   template <typename T, int Dim>
-  static constexpr std::size_t memory_required(sycl::memory_scope scope,
-                                               sycl::range<Dim> r) {
+  static constexpr size_t memory_required(sycl::memory_scope scope,
+                                          sycl::range<Dim> r) {
     return 2 * memory_required<T>(scope, r.size());
   }
 };
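A usage sketch for the class above, assuming a queue q, a buffer buf of n
elements, a work-group size wg_size, and a scratch_size obtained from
default_sorter<>::memory_required (all names illustrative; not part of the
patch):

    #include <sycl/sycl.hpp>

    namespace oneapi_exp = sycl::ext::oneapi::experimental;
    q.submit([&](sycl::handler &cgh) {
      sycl::accessor keys{buf, cgh};
      sycl::local_accessor<std::byte, 1> scratch({scratch_size}, cgh);
      cgh.parallel_for(sycl::nd_range<1>{{wg_size}, {wg_size}},
                       [=](sycl::nd_item<1> it) {
                         oneapi_exp::group_with_scratchpad gh{
                             it.get_group(),
                             sycl::span{&scratch[0], scratch_size}};
                         // joint_sort picks default_sorter over the scratchpad
                         oneapi_exp::joint_sort(gh, &keys[0], &keys[0] + n);
                       });
    });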
@@ -112,20 +112,20 @@ class radix_sorter {
   std::byte *scratch = nullptr;
   uint32_t first_bit = 0;
   uint32_t last_bit = 0;
-  std::size_t scratch_size = 0;
+  size_t scratch_size = 0;
 
   static constexpr uint32_t bits = BitsPerPass;
 
 public:
-  template <std::size_t Extent>
+  template <size_t Extent>
   radix_sorter(sycl::span<std::byte, Extent> scratch_,
                const std::bitset<sizeof(ValT) * CHAR_BIT> mask =
                    std::bitset<sizeof(ValT) * CHAR_BIT>(
                        std::numeric_limits<unsigned long long>::max()))
       : scratch(scratch_.data()), scratch_size(scratch_.size()) {
-    static_assert((std::is_arithmetic_v<ValT> ||
-                   std::is_same_v<ValT, sycl::half> ||
-                   std::is_same_v<ValT, sycl::ext::oneapi::bfloat16>),
+    static_assert((std::is_arithmetic<ValT>::value ||
+                   std::is_same<ValT, sycl::half>::value ||
+                   std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
                   "radix sort is not usable");
 
     first_bit = 0;
@@ -143,10 +143,11 @@ class radix_sorter {
     (void)first;
     (void)last;
 #ifdef __SYCL_DEVICE_ONLY__
-    sycl::detail::privateSort<false, bits>(
+    sycl::detail::privateDynamicSort<false,
+                                     OrderT == sorting_order::ascending, bits>(
         g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0,
-        typename detail::ConvertToComp<ValT, OrderT>::Type{}, scratch,
-        first_bit, last_bit);
+        scratch, first_bit, last_bit);
 #else
     throw sycl::exception(
         std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
@@ -159,12 +160,11 @@ class radix_sorter {
     (void)val;
 #ifdef __SYCL_DEVICE_ONLY__
     ValT result[]{val};
-    sycl::detail::privateMemorySort<false, true, bits>(
-        g, result, /*empty*/ result,
-        typename detail::ConvertToComp<ValT, OrderT>::Type{}, scratch,
-        first_bit, last_bit);
+    sycl::detail::privateStaticSort<false, OrderT == sorting_order::ascending,
+                                    true, bits>(
+        g, result, /*empty*/ result, scratch, first_bit, last_bit);
     return result[0];
 #else
     throw sycl::exception(
         std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
 #endif
   }
 
-  static constexpr std::size_t memory_required(sycl::memory_scope scope,
-                                               std::size_t range_size) {
+  static constexpr size_t memory_required(sycl::memory_scope scope,
+                                          size_t range_size) {
     // Scope is not important so far
     (void)scope;
     return range_size * sizeof(ValT) +
diff --git a/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp b/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
index beea117bbc9d0..d1b7a4fefd1a5 100644
--- a/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
+++ b/sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
@@ -81,7 +81,7 @@ sort_over_group(Group group, T value, Sorter sorter) {
 #endif
 }
 
-template <typename Group, typename T, typename Compare, std::size_t Extent>
+template <typename Group, typename T, typename Compare, size_t Extent>
 typename std::enable_if<!detail::is_sorter<Compare, Group, T>::value, T>::type
 sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                 T value, Compare comp) {
@@ -90,7 +90,7 @@ sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                          experimental::default_sorter<Compare>(exec.get_memory(), comp));
 }
 
-template <typename Group, typename T, std::size_t Extent>
+template <typename Group, typename T, size_t Extent>
 typename std::enable_if<detail::is_generic_group_v<std::decay_t<Group>>, T>::type
 sort_over_group(experimental::group_with_scratchpad<Group, Extent> exec,
                 T value) {
@@ -116,7 +116,7 @@ joint_sort(Group group, Iter first, Iter last, Sorter sorter) {
 #endif
 }
 
-template <typename Group, typename Iter, typename Compare, std::size_t Extent>
+template <typename Group, typename Iter, typename Compare, size_t Extent>
 typename std::enable_if<!detail::is_sorter<Compare, Group, Iter>::value,
                         void>::type
 joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
@@ -125,7 +125,7 @@ joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
              experimental::default_sorter<Compare>(exec.get_memory(), comp));
 }
 
-template <typename Group, typename Iter, std::size_t Extent>
+template <typename Group, typename Iter, size_t Extent>
 typename std::enable_if<detail::is_generic_group_v<std::decay_t<Group>>,
                         void>::type
 joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
           Iter last) {