From 5709f99a4ec67766c3b239a336dd0f0261d2c1d6 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 2 Nov 2023 20:19:16 -0700 Subject: [PATCH 1/2] Adds SequentialSearchReduction functor to search reductions --- .../libtensor/include/kernels/reductions.hpp | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index f3754e8820..74834af679 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -3401,6 +3401,125 @@ struct LogSumExpOverAxis0TempsContigFactory // Argmax and Argmin +/* Sequential search reduction */ + +template +struct SequentialSearchReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialSearchReduction(const argT *inp, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), idx_reduction_op_(idx_reduction_op), + idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const py::ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const py::ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + argT red_val(identity_); + outT idx_val(idx_identity_); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + const py::ssize_t inp_reduction_offset = + inp_reduced_dims_indexer_(m); + const py::ssize_t inp_offset = + inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == red_val) { + idx_val = idx_reduction_op_(idx_val, static_cast(m)); + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so check + if (less_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v) { + if (val < red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val < red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v) { + if (val > red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val > red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + } + } + out_[out_iter_offset] = idx_val; + } +}; + /* = Search reduction using reduce_over_group*/ template &); +template +class search_seq_strided_krn; + template class custom_search_over_group_temps_strided_krn; +template +class search_seq_contig_krn; + template (); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 4; // max_max_wg prevents running out of resources on CPU size_t max_wg = @@ -4419,6 +4584,39 @@ sycl::event search_axis1_over_group_temps_contig_impl( const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU size_t max_wg = @@ -4801,6 +4999,43 @@ sycl::event search_axis0_over_group_temps_contig_impl( const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + using KernelName = + class search_seq_contig_krn; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for( + iter_range, + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU size_t max_wg = From 119d43d565e86e5055dbd85416ed8d9df1bb2e47 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 3 Nov 2023 07:18:45 -0700 Subject: [PATCH 2/2] Search reductions use correct branch for float16 constexpr branch logic accounted for floating point types but not sycl::half, which meant NaNs were not propagating for float16 data --- .../libtensor/include/kernels/reductions.hpp | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 74834af679..6651483c6c 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -3476,7 +3476,9 @@ struct SequentialSearchReduction idx_val = static_cast(m); } } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val < red_val || std::isnan(val)) { red_val = val; idx_val = static_cast(m); @@ -3501,7 +3503,9 @@ struct SequentialSearchReduction idx_val = static_cast(m); } } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val > red_val || std::isnan(val)) { red_val = val; idx_val = static_cast(m); @@ -3789,7 +3793,9 @@ struct CustomSearchReduction } } } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val < local_red_val || std::isnan(val)) { local_red_val = val; if constexpr (!First) { @@ -3833,7 +3839,9 @@ struct CustomSearchReduction } } } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val > local_red_val || std::isnan(val)) { local_red_val = val; if constexpr (!First) { @@ -3876,7 +3884,9 @@ struct CustomSearchReduction ? local_idx : idx_identity_; } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { // equality does not hold for NaNs, so check here local_idx = (red_val_over_wg == local_red_val || std::isnan(local_red_val))