diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index f3754e8820..6651483c6c 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory // Argmax and Argmin +/* Sequential search reduction */ + +template +struct SequentialSearchReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialSearchReduction(const argT *inp, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), idx_reduction_op_(idx_reduction_op), + idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const py::ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const py::ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + argT red_val(identity_); + outT idx_val(idx_identity_); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + const py::ssize_t inp_reduction_offset = + inp_reduced_dims_indexer_(m); + const py::ssize_t inp_offset = + inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == red_val) { + idx_val = idx_reduction_op_(idx_val, static_cast(m)); + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so check + if (less_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val < red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val < red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val > red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val > red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + } + } + out_[out_iter_offset] = idx_val; + } +}; + /* = Search reduction using reduce_over_group*/ template ) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val < local_red_val || std::isnan(val)) { local_red_val = val; if constexpr (!First) { @@ -3714,7 +3839,9 @@ struct CustomSearchReduction } } } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { if (val > local_red_val || std::isnan(val)) { local_red_val = val; if constexpr (!First) { @@ -3757,7 +3884,9 @@ struct CustomSearchReduction ? local_idx : idx_identity_; } - else if constexpr (std::is_floating_point_v) { + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { // equality does not hold for NaNs, so check here local_idx = (red_val_over_wg == local_red_val || std::isnan(local_red_val)) @@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)( py::ssize_t, const std::vector &); +template +class search_seq_strided_krn; + template class custom_search_over_group_temps_strided_krn; +template +class search_seq_contig_krn; + template (); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 4; // max_max_wg prevents running out of resources on CPU size_t max_wg = @@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl( const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{}; + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU size_t max_wg = @@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl( const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + if (reduction_nelems < wg) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + using KernelName = + class search_seq_contig_krn; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for( + iter_range, + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + constexpr size_t preferred_reductions_per_wi = 8; // max_max_wg prevents running out of resources on CPU size_t max_wg =