diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 456eebdbaa..9a2493421e 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -49,8 +49,8 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp ) set(_clang_prefix "") if (WIN32) @@ -60,6 +60,7 @@ set_source_files_properties( ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math") if (UNIX) set_source_files_properties( diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index f0930004ec..3473d5cde5 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -160,7 +160,7 @@ tanh, trunc, ) -from ._reduction import sum +from ._reduction import argmax, argmin, max, min, prod, sum from ._testing import allclose __all__ = [ @@ -309,4 +309,9 @@ "allclose", "repeat", "tile", + "max", + "min", + "argmax", + "argmin", + "prod", ] diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index d9bd6b5b2b..aac1c84677 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -52,18 +52,107 @@ def _default_reduction_dtype(inp_dt, q): return res_dt -def sum(arr, axis=None, dtype=None, keepdims=False): +def _reduction_over_axis( + x, + axis, + dtype, + keepdims, + _reduction_fn, + _dtype_supported, + _default_reduction_type_fn, + _identity=None, +): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + nd = x.ndim + if axis is None: + axis = tuple(range(nd)) + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + red_nd = len(axis) + perm = [i for i in range(nd) if i not in axis] + list(axis) + arr2 = dpt.permute_dims(x, perm) + res_shape = arr2.shape[: nd - red_nd] + q = x.sycl_queue + inp_dt = x.dtype + if dtype is None: + res_dt = _default_reduction_type_fn(inp_dt, q) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) + + res_usm_type = x.usm_type + if x.size == 0: + if _identity is None: + raise ValueError("reduction does not support zero-size arrays") + else: + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res_shape = tuple(res_shape[i] for i in inv_perm) + return dpt.full( + res_shape, + _identity, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=q, + ) + if red_nd == 0: + return dpt.astype(x, res_dt, copy=False) + + host_tasks_list = [] + if _dtype_supported(inp_dt, res_dt, res_usm_type, q): + res = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e, _ = _reduction_fn( + src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q + ) + host_tasks_list.append(ht_e) + else: + if dtype is None: + raise RuntimeError( + "Automatically determined reduction data type does not " + "have direct implementation" + ) + tmp_dt = 
_default_reduction_dtype(inp_dt, q)
+        tmp = dpt.empty(
+            res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
+        )
+        ht_e_tmp, r_e = _reduction_fn(
+            src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
+        )
+        host_tasks_list.append(ht_e_tmp)
+        res = dpt.empty(
+            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
+        )
+        ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=tmp, dst=res, sycl_queue=q, depends=[r_e]
+        )
+        host_tasks_list.append(ht_e)
+
+    if keepdims:
+        res_shape = res_shape + (1,) * red_nd
+        inv_perm = sorted(range(nd), key=lambda d: perm[d])
+        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
+    dpctl.SyclEvent.wait_for(host_tasks_list)
+
+    return res
+
+
+def sum(x, axis=None, dtype=None, keepdims=False):
     """sum(x, axis=None, dtype=None, keepdims=False)
 
-    Calculates the sum of the input array `x`.
+    Calculates the sum of elements in the input array `x`.
 
     Args:
         x (usm_ndarray):
             input array.
-        axis (Optional[int, Tuple[int,...]]):
+        axis (Optional[int, Tuple[int, ...]]):
             axis or axes along which sums must be computed. If a tuple
             of unique integers, sums are computed over multiple axes.
-            If `None`, the sum if computed over the entire array.
+            If `None`, the sum is computed over the entire array.
             Default: `None`.
         dtype (Optional[dtype]):
             data type of the returned array. If `None`, the default data
@@ -101,9 +190,84 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
             array has the data type as described in the `dtype` parameter
             description above.
     """
-    if not isinstance(arr, dpt.usm_ndarray):
-        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(arr)}")
-    nd = arr.ndim
+    return _reduction_over_axis(
+        x,
+        axis,
+        dtype,
+        keepdims,
+        ti._sum_over_axis,
+        ti._sum_over_axis_dtype_supported,
+        _default_reduction_dtype,
+        _identity=0,
+    )
+
+
+def prod(x, axis=None, dtype=None, keepdims=False):
+    """prod(x, axis=None, dtype=None, keepdims=False)
+
+    Calculates the product of elements in the input array `x`.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int, Tuple[int, ...]]):
+            axis or axes along which products must be computed. If a tuple
+            of unique integers, products are computed over multiple axes.
+            If `None`, the product is computed over the entire array.
+            Default: `None`.
+        dtype (Optional[dtype]):
+            data type of the returned array. If `None`, the default data
+            type is inferred from the "kind" of the input array data type.
+            * If `x` has a real-valued floating-point data type,
+              the returned array will have the default real-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has signed integral data type, the returned array
+              will have the default signed integral type for the device
+              where input array `x` is allocated.
+            * If `x` has unsigned integral data type, the returned array
+              will have the default unsigned integral type for the device
+              where input array `x` is allocated.
+            * If `x` has a complex-valued floating-point data type,
+              the returned array will have the default complex-valued
+              floating-point data type for the device where input
+              array `x` is allocated.
+            * If `x` has a boolean data type, the returned array will
+              have the default signed integral type for the device
+              where input array `x` is allocated.
+            If the data type (either specified or resolved) differs from the
+            data type of `x`, the input array elements are cast to the
+            specified data type before computing the product. Default: `None`.
+ keepdims (Optional[bool]): + if `True`, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if `False`, the reduced axes are not included in + the returned array. Default: `False`. + Returns: + usm_ndarray: + an array containing the products. If the product was computed over + the entire array, a zero-dimensional array is returned. The returned + array has the data type as described in the `dtype` parameter + description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + ti._prod_over_axis, + ti._prod_over_axis_dtype_supported, + _default_reduction_dtype, + _identity=1, + ) + + +def _comparison_over_axis(x, axis, keepdims, _reduction_fn): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + + nd = x.ndim if axis is None: axis = tuple(range(nd)) if not isinstance(axis, (tuple, list)): @@ -111,63 +275,201 @@ def sum(arr, axis=None, dtype=None, keepdims=False): axis = normalize_axis_tuple(axis, nd, "axis") red_nd = len(axis) perm = [i for i in range(nd) if i not in axis] + list(axis) - arr2 = dpt.permute_dims(arr, perm) - res_shape = arr2.shape[: nd - red_nd] - q = arr.sycl_queue - inp_dt = arr.dtype - if dtype is None: - res_dt = _default_reduction_dtype(inp_dt, q) - else: - res_dt = dpt.dtype(dtype) - res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) - - res_usm_type = arr.usm_type - if arr.size == 0: - if keepdims: - res_shape = res_shape + (1,) * red_nd - inv_perm = sorted(range(nd), key=lambda d: perm[d]) - res_shape = tuple(res_shape[i] for i in inv_perm) - return dpt.zeros( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) + x_tmp = dpt.permute_dims(x, perm) + res_shape = x_tmp.shape[: nd - red_nd] + exec_q = x.sycl_queue + res_dt = x.dtype + res_usm_type = x.usm_type + if x.size == 0: + raise ValueError("reduction does not support zero-size arrays") if red_nd == 0: - return dpt.astype(arr, res_dt, copy=False) + return x - host_tasks_list = [] - if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q): - res = dpt.empty( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) - ht_e, _ = ti._sum_over_axis( - src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q - ) - host_tasks_list.append(ht_e) + res = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + ) + hev, _ = _reduction_fn( + src=x_tmp, + trailing_dims_to_reduce=red_nd, + dst=res, + sycl_queue=exec_q, + ) + + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) + hev.wait() + return res + + +def max(x, axis=None, keepdims=False): + """max(x, axis=None, dtype=None, keepdims=False) + + Calculates the maximum value of the input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which maxima must be computed. If a tuple + of unique integers, the maxima are computed over multiple axes. + If `None`, the max is computed over the entire array. + Default: `None`. + keepdims (Optional[bool]): + if `True`, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. 
Otherwise, if `False`, the reduced axes are not included in + the returned array. Default: `False`. + Returns: + usm_ndarray: + an array containing the maxima. If the max was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the same data type as `x`. + """ + return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis) + + +def min(x, axis=None, keepdims=False): + """min(x, axis=None, dtype=None, keepdims=False) + + Calculates the minimum value of the input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which minima must be computed. If a tuple + of unique integers, the minima are computed over multiple axes. + If `None`, the min is computed over the entire array. + Default: `None`. + keepdims (Optional[bool]): + if `True`, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if `False`, the reduced axes are not included in + the returned array. Default: `False`. + Returns: + usm_ndarray: + an array containing the minima. If the min was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the same data type as `x`. + """ + return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis) + + +def _search_over_axis(x, axis, keepdims, _reduction_fn): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + + nd = x.ndim + if axis is None: + axis = tuple(range(nd)) + elif isinstance(axis, int): + axis = (axis,) else: - if dtype is None: - raise RuntimeError( - "Automatically determined reduction data type does not " - "have direct implementation" - ) - tmp_dt = _default_reduction_dtype(inp_dt, q) - tmp = dpt.empty( - res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q - ) - ht_e_tmp, r_e = ti._sum_over_axis( - src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q - ) - host_tasks_list.append(ht_e_tmp) - res = dpt.empty( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + raise TypeError( + f"`axis` argument expected `int` or `None`, got {type(axis)}" ) - ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray( - src=tmp, dst=res, sycl_queue=q, depends=[r_e] + axis = normalize_axis_tuple(axis, nd, "axis") + red_nd = len(axis) + perm = [i for i in range(nd) if i not in axis] + list(axis) + x_tmp = dpt.permute_dims(x, perm) + res_shape = x_tmp.shape[: nd - red_nd] + exec_q = x.sycl_queue + res_dt = ti.default_device_index_type(exec_q.sycl_device) + res_usm_type = x.usm_type + if x.size == 0: + raise ValueError("reduction does not support zero-size arrays") + if red_nd == 0: + return dpt.zeros( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q ) - host_tasks_list.append(ht_e) + + res = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + ) + hev, _ = _reduction_fn( + src=x_tmp, + trailing_dims_to_reduce=red_nd, + dst=res, + sycl_queue=exec_q, + ) if keepdims: res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) - dpctl.SyclEvent.wait_for(host_tasks_list) - + hev.wait() return res + + +def argmax(x, axis=None, keepdims=False): + """argmax(x, axis=None, dtype=None, keepdims=False) + + Returns the indices of the maximum values of the input array `x` along 
a
+    specified axis.
+
+    When the maximum value occurs multiple times, the indices corresponding to
+    the first occurrence are returned.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int]):
+            axis along which to search. If `None`, returns the index of the
+            maximum value of the flattened array.
+            Default: `None`.
+        keepdims (Optional[bool]):
+            if `True`, the reduced axes (dimensions) are included in the result
+            as singleton dimensions, so that the returned array remains
+            compatible with the input arrays according to Array Broadcasting
+            rules. Otherwise, if `False`, the reduced axes are not included in
+            the returned array. Default: `False`.
+    Returns:
+        usm_ndarray:
+            an array containing the indices of the first occurrence of the
+            maximum values. If the entire array was searched, a
+            zero-dimensional array is returned. The returned array has the
+            default array index data type for the device of `x`.
+    """
+    return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis)
+
+
+def argmin(x, axis=None, keepdims=False):
+    """argmin(x, axis=None, keepdims=False)
+
+    Returns the indices of the minimum values of the input array `x` along a
+    specified axis.
+
+    When the minimum value occurs multiple times, the indices corresponding to
+    the first occurrence are returned.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        axis (Optional[int]):
+            axis along which to search. If `None`, returns the index of the
+            minimum value of the flattened array.
+            Default: `None`.
+        keepdims (Optional[bool]):
+            if `True`, the reduced axes (dimensions) are included in the result
+            as singleton dimensions, so that the returned array remains
+            compatible with the input arrays according to Array Broadcasting
+            rules. Otherwise, if `False`, the reduced axes are not included in
+            the returned array. Default: `False`.
+    Returns:
+        usm_ndarray:
+            an array containing the indices of the first occurrence of the
+            minimum values. If the entire array was searched, a
+            zero-dimensional array is returned. The returned array has the
+            default array index data type for the device of `x`.
+ """ + return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis) diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index 7dfc956492..7cb97cd4f9 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -24,6 +24,7 @@ #pragma once #include +#include #include #include #include @@ -32,6 +33,7 @@ #include #include "pybind11/pybind11.h" +#include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_dispatch.hpp" @@ -39,6 +41,7 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; namespace dpctl { @@ -47,6 +50,14 @@ namespace tensor namespace kernels { +template struct can_use_reduce_over_group +{ + static constexpr bool value = + sycl::has_known_identity::value && + !std::is_same_v && !std::is_same_v && + !std::is_same_v>; +}; + template (inp_[inp_offset]); + red_val = reduction_op_(red_val, val); } out_[out_iter_offset] = red_val; @@ -153,7 +166,7 @@ struct ReductionOverGroupWithAtomicFunctor const size_t reduction_lid = it.get_local_id(0); const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg - // work-items sums over input with indices + // work-items operate over input with indices // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg // + reduction_lid // for 0 <= m < reductions_per_wi @@ -191,11 +204,17 @@ struct ReductionOverGroupWithAtomicFunctor sycl::memory_scope::device, sycl::access::address_space::global_space> res_ref(out_[out_iter_offset]); - if constexpr (std::is_same_v> || - std::is_same_v>) - { + if constexpr (su_ns::IsPlus::value) { res_ref += red_val_over_wg; } + else if constexpr (std::is_same_v>) + { + res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (std::is_same_v>) + { + res_ref.fetch_min(red_val_over_wg); + } else { outT read_val = res_ref.load(); outT new_val{}; @@ -207,7 +226,103 @@ struct ReductionOverGroupWithAtomicFunctor } }; -typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( +/* === Reduction, using custom_reduce_over_group, and sycl::atomic_ref === */ + +template +struct CustomReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + SlmT local_mem, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 
<= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl(inp_[inp_offset]); + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } +}; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( sycl::queue &, size_t, size_t, @@ -223,27 +338,51 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( const std::vector &); template -class sum_reduction_over_group_with_atomics_krn; +class reduction_over_group_with_atomics_krn; + +template +class custom_reduction_over_group_with_atomics_krn; -template -class sum_reduction_over_group_with_atomics_init_krn; +template +class reduction_over_group_with_atomics_init_krn; template -class sum_reduction_seq_strided_krn; +class reduction_seq_strided_krn; template -class sum_reduction_seq_contig_krn; +class reduction_seq_contig_krn; template -class sum_reduction_axis0_over_group_with_atomics_contig_krn; +class reduction_axis0_over_group_with_atomics_contig_krn; + +template +class custom_reduction_axis0_over_group_with_atomics_contig_krn; template -class sum_reduction_axis1_over_group_with_atomics_contig_krn; +class reduction_axis1_over_group_with_atomics_contig_krn; + +template +class custom_reduction_axis1_over_group_with_atomics_contig_krn; using dpctl::tensor::sycl_utils::choose_workgroup_size; -template -sycl::event sum_reduction_over_group_with_atomics_strided_impl( +template +sycl::event reduction_over_group_with_atomics_strided_impl( sycl::queue &exec_q, size_t iter_nelems, // number of reductions (num. 
of rows in a matrix // when reducing over rows) @@ -263,8 +402,7 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( const argTy *arg_tp = reinterpret_cast(arg_cp); resTy *res_tp = reinterpret_cast(res_cp); - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = resTy{0}; + constexpr resTy identity_val = su_ns::Identity::value; const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); @@ -285,7 +423,7 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, reduction_shape_stride}; - cgh.parallel_for>( sycl::range<1>(iter_nelems), @@ -308,8 +446,8 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); using InitKernelName = - class sum_reduction_over_group_with_atomics_init_krn; + class reduction_over_group_with_atomics_init_krn; cgh.depends_on(depends); cgh.parallel_for( @@ -347,18 +485,38 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class sum_reduction_over_group_with_atomics_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupWithAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_reduction_over_group_with_atomics_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } }); return comp_ev; @@ -367,7 +525,7 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( // Contig -typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)( +typedef sycl::event (*reduction_contig_impl_fn_ptr)( sycl::queue &, size_t, size_t, @@ -379,8 +537,8 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)( const std::vector &); /* @brief Reduce rows in a matrix */ -template -sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl( +template +sycl::event reduction_axis1_over_group_with_atomics_contig_impl( sycl::queue &exec_q, size_t iter_nelems, // number of reductions (num. 
of rows in a matrix // when reducing over rows) @@ -397,8 +555,7 @@ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl( iter_arg_offset + reduction_arg_offset; resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = resTy{0}; + constexpr resTy identity_val = su_ns::Identity::value; const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); @@ -422,7 +579,7 @@ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl( NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; - cgh.parallel_for>( sycl::range<1>(iter_nelems), @@ -470,28 +627,47 @@ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = - class sum_reduction_axis1_over_group_with_atomics_contig_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = + class reduction_axis1_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupWithAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class + custom_reduction_axis1_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } }); - return comp_ev; } } /* @brief Reduce rows in a matrix */ -template -sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl( +template +sycl::event reduction_axis0_over_group_with_atomics_contig_impl( sycl::queue &exec_q, size_t iter_nelems, // number of reductions (num. 
of cols in a matrix // when reducing over cols) @@ -508,8 +684,8 @@ sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl( iter_arg_offset + reduction_arg_offset; resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = resTy{0}; + constexpr resTy identity_val = su_ns::Identity::value; + ; const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); @@ -551,21 +727,40 @@ sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = - class sum_reduction_axis0_over_group_with_atomics_contig_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = + class reduction_axis0_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupWithAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class + custom_reduction_axis0_over_group_with_atomics_contig_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } }); - return comp_ev; } } @@ -618,7 +813,7 @@ struct ReductionOverGroupNoAtomicFunctor const size_t reduction_batch_id = it.get_group(0) / iter_gws_; const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; - // work-items sums over input with indices + // work-items operates over input with indices // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg // + reduction_lid // for 0 <= m < reductions_per_wi @@ -658,11 +853,110 @@ struct ReductionOverGroupNoAtomicFunctor } }; -template -class sum_reduction_over_group_temps_krn; +/* = Reduction, using custom_reduce_over_group and not using atomic_ref*/ + +template +struct CustomReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + ReductionOp reduction_op, + const outT &identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + SlmT local_mem, + size_t reduction_size, + size_t iteration_size, + size_t 
reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl(inp_[inp_offset]); + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; -template -sycl::event sum_reduction_over_group_temps_strided_impl( +template +class reduction_over_group_temps_krn; + +template +class custom_reduction_over_group_temps_krn; + +template +sycl::event reduction_over_group_temps_strided_impl( sycl::queue &exec_q, size_t iter_nelems, // number of reductions (num. 
of rows in a matrix // when reducing over rows) @@ -682,19 +976,21 @@ sycl::event sum_reduction_over_group_temps_strided_impl( const argTy *arg_tp = reinterpret_cast(arg_cp); resTy *res_tp = reinterpret_cast(res_cp); - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = resTy{0}; + constexpr resTy identity_val = su_ns::Identity::value; const sycl::device &d = exec_q.get_device(); const auto &sg_sizes = d.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); constexpr size_t preferrered_reductions_per_wi = 4; - size_t max_wg = d.get_info(); + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info()); size_t reductions_per_wi(preferrered_reductions_per_wi); if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { - // reduction only requires 1 work-group, can output directly to res + // reduction only requries 1 work-group, can output directly to res sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -722,19 +1018,38 @@ sycl::event sum_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class sum_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupNoAtomicFunctor( - arg_tp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, reductions_per_wi)); - }); + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(arg_tp, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); return comp_ev; } else { @@ -773,9 +1088,10 @@ sycl::event sum_reduction_over_group_temps_strided_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - // Only 2*iter_nd entries describing shape and strides of iterated - // dimensions of input array from iter_shape_and_strides are going - // to be accessed by inp_indexer + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer InputIndexerT inp_indexer(iter_nd, iter_arg_offset, iter_shape_and_strides); ResIndexerT noop_tmp_indexer{}; @@ -789,17 +1105,37 @@ sycl::event sum_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = 
sycl::range<1>{wg}; - using KernelName = class sum_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupNoAtomicFunctor( - arg_tp, partially_reduced_tmp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, reduction_nelems, - iter_nelems, preferrered_reductions_per_wi)); + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg_tp, partially_reduced_tmp, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + local_memory, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } }); size_t remaining_reduction_nelems = reduction_groups; @@ -817,34 +1153,34 @@ sycl::event sum_reduction_over_group_temps_strided_impl( assert(reduction_groups_ > 1); // keep reducing - sycl::event partial_reduction_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(dependent_ev); - - using InputIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - - InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; - ResIndexerT res_iter_indexer{}; - - InputOutputIterIndexerT in_out_iter_indexer{ - inp_indexer, res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; - - auto globalRange = - sycl::range<1>{iter_nelems * reduction_groups_ * wg}; - auto localRange = sycl::range<1>{wg}; - - using KernelName = class sum_reduction_over_group_temps_krn< + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + 
using KernelName = class reduction_over_group_temps_krn< resTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( @@ -856,7 +1192,25 @@ sycl::event sum_reduction_over_group_temps_strided_impl( in_out_iter_indexer, reduction_indexer, remaining_reduction_nelems, iter_nelems, preferrered_reductions_per_wi)); - }); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class custom_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + local_memory, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); remaining_reduction_nelems = reduction_groups_; std::swap(temp_arg, temp2_arg); @@ -900,18 +1254,37 @@ sycl::event sum_reduction_over_group_temps_strided_impl( sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class sum_reduction_over_group_temps_krn< - argTy, resTy, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupNoAtomicFunctor( - temp_arg, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - remaining_reduction_nelems, iter_nelems, - reductions_per_wi)); + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>(temp_arg, res_tp, ReductionOpT(), + identity_val, in_out_iter_indexer, + reduction_indexer, + remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } }); sycl::event cleanup_host_task_event = @@ -931,31 +1304,332 @@ sycl::event sum_reduction_over_group_temps_strided_impl( } } -/* @brief Types supported by plus-reduction code based on atomic_ref */ +/* @brief Types supported by comparison-reduction code based on atomic_ref */ template -struct TypePairSupportDataForSumReductionAtomic +struct TypePairSupportDataForCompReductionAtomic { /* value if true a kernel for must be instantiated, false * otherwise */ static constexpr bool is_defined = std::disjunction< // disjunction is C++17 // feature, supported - // by DPC++ input bool - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // 
input int8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - // input uint8 - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, - td_ns::TypePairDefinedEntry, + // by DPC++ + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForCompReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MaxOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = 
su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForCompReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +// Sum + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -1105,9 +1779,10 @@ struct 
SumOverAxisAtomicStridedFactory if constexpr (TypePairSupportDataForSumReductionAtomic< srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; return dpctl::tensor::kernels:: - sum_reduction_over_group_with_atomics_strided_impl; + reduction_over_group_with_atomics_strided_impl; } else { return nullptr; @@ -1122,8 +1797,10 @@ struct SumOverAxisTempsStridedFactory { if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; return dpctl::tensor::kernels:: - sum_reduction_over_group_temps_strided_impl; + reduction_over_group_temps_strided_impl; } else { return nullptr; @@ -1139,9 +1816,10 @@ struct SumOverAxis1AtomicContigFactory if constexpr (TypePairSupportDataForSumReductionAtomic< srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; return dpctl::tensor::kernels:: - sum_reduction_axis1_over_group_with_atomics_contig_impl; + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; } else { return nullptr; @@ -1157,9 +1835,1188 @@ struct SumOverAxis0AtomicContigFactory if constexpr (TypePairSupportDataForSumReductionAtomic< srcTy, dstTy>::is_defined) { + using ReductionOpT = sycl::plus; return dpctl::tensor::kernels:: - sum_reduction_axis0_over_group_with_atomics_contig_impl; + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +// Product + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 
+ // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-throug + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +// Argmax and Argmin + +/* = Search reduction using reduce_over_group*/ + +template +struct SearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = 
nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + SearchReduction(const argT *data, + argT *vals, + const outT *inds, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if constexpr (std::is_integral_v) { + local_idx = + (red_val_over_wg == local_red_val) ? local_idx : idx_identity_; + } + else { + local_idx = + (red_val_over_wg == local_red_val || + std::isnan(red_val_over_wg) || std::isnan(local_red_val)) + ? 
local_idx + : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +/* = Search reduction using custom_reduce_over_group*/ + +template +struct CustomSearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; + size_t reductions_per_wi = 16; + +public: + CustomSearchReduction(const argT *data, + argT *vals, + outT *inds, + outT *res, + ReductionOp reduction_op, + const argT &identity_val, + IdxReductionOp idx_reduction_op, + const outT &idx_identity_val, + InputOutputIterIndexerT arg_res_iter_indexer, + InputRedIndexerT arg_reduced_dims_indexer, + SlmT local_mem, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t iter_gid = it.get_group(0) % iter_gws_; + const size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (size_t m = 0; m < reductions_per_wi; ++m) { + size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always 
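SearchReduction carries a (value, index) pair per work-item: values are combined with the Maximum/Minimum functor, equal values keep the smaller index (IndexOpT is sycl::minimum, so the first occurrence wins), and for floating-point inputs a NaN takes precedence so its index propagates. A simplified serial model of that argmax contract follows; it illustrates the intended result, not the device algorithm or every multi-NaN corner case.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
std::size_t argmax_model(const std::vector<T> &v)
{
    T best_val = v[0];
    std::size_t best_idx = 0;
    for (std::size_t i = 1; i < v.size(); ++i) {
        const T val = v[i];
        if (val == best_val) {
            // IndexOpT = minimum: keep the earliest index among equal values
            best_idx = std::min(best_idx, i);
        }
        else if (std::isnan(val) || val > best_val) {
            // Maximum functor: NaN propagates, otherwise strict greater-than
            best_val = val;
            best_idx = i;
        }
    }
    return best_idx;
}

int main()
{
    std::vector<double> a{1.0, 5.0, 5.0, 2.0};
    std::cout << argmax_model(a) << "\n"; // 1: first of the tied maxima

    std::vector<double> b{1.0, std::nan(""), 7.0};
    std::cout << argmax_model(b) << "\n"; // 1: NaN takes precedence
}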
returns false for NaNs, so + // check + if (less_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v) { + if (val < local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v) { + if (val > local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + // equality does not hold for NaNs, so check here + local_idx = (red_val_over_wg == local_red_val || + std::isnan(std::real(local_red_val)) || + std::isnan(std::imag(local_red_val))) + ? local_idx + : idx_identity_; + } + else if constexpr (std::is_floating_point_v) { + // equality does not hold for NaNs, so check here + local_idx = + (red_val_over_wg == local_red_val || std::isnan(local_red_val)) + ? local_idx + : idx_identity_; + } + else { + local_idx = + red_val_over_wg == local_red_val ? 
local_idx : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +typedef sycl::event (*search_reduction_strided_impl_fn_ptr)( + sycl::queue, + size_t, + size_t, + const char *, + char *, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + const std::vector &); + +template +class search_reduction_over_group_temps_krn; + +template +class search_custom_reduction_over_group_temps_krn; + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +sycl::event search_reduction_over_group_temps_strided_impl( + sycl::queue exec_q, + size_t iter_nelems, // number of reductions (num. of rows in a matrix + // when reducing over rows) + size_t reduction_nelems, // size of each reduction (length of rows, i.e. + // number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const py::ssize_t *iter_shape_and_strides, + py::ssize_t iter_arg_offset, + py::ssize_t iter_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + constexpr argTy identity_val = su_ns::Identity::value; + constexpr resTy idx_identity_val = su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferrered_reductions_per_wi = 4; + // max_max_wg prevents running out of resources on CPU + size_t max_wg = std::min( + size_t(2048), d.get_info()); + + size_t reductions_per_wi(preferrered_reductions_per_wi); + if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) { + // reduction only requries 1 work-group, can output directly to res + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, + iter_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, 
+ reduction_nelems, iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class search_custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, nullptr, nullptr, res_tp, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, local_memory, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + }); + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + } + + argTy *partially_reduced_vals_tmp = sycl::malloc_device( + iter_nelems * (reduction_groups + second_iter_reduction_groups_), + exec_q); + argTy *partially_reduced_vals_tmp2 = nullptr; + + if (partially_reduced_vals_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + } + + sycl::event first_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of iterated + // dimensions of input array from iter_shape_and_strides are going + // to be accessed by inp_indexer + InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + ResIndexerT noop_tmp_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, true, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class 
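When more than one work-group is needed, the code above repeatedly shrinks the problem by a factor of reductions_per_wi * wg, writing partial results to temporaries and swapping buffers until a single work-group suffices. The sizing arithmetic in isolation, with illustrative constants:

#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t wg = 256;              // work-group size chosen from sub-group sizes
    const std::size_t reductions_per_wi = 4; // preferrered_reductions_per_wi in the source
    const std::size_t max_wg = 2048;         // CPU resource cap used above

    std::size_t remaining = 10'000'000;      // reduction_nelems
    int pass = 0;
    while (remaining > reductions_per_wi * max_wg) {
        std::size_t groups =
            (remaining + reductions_per_wi * wg - 1) / (reductions_per_wi * wg);
        std::cout << "pass " << pass++ << ": " << remaining
                  << " elements -> " << groups << " partial results\n";
        remaining = groups; // partial results become the next pass's input
    }
    std::cout << "final pass reduces " << remaining
              << " elements inside one work-group\n";
}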
search_custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, true, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferrered_reductions_per_wi * max_wg) { + size_t reduction_groups_ = + (remaining_reduction_nelems + + preferrered_reductions_per_wi * wg - 1) / + (preferrered_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class search_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, + false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, preferrered_reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class search_custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, + false, false>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + preferrered_reductions_per_wi)); + } + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + 
dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /*s trides */ iter_shape_and_strides + + 2 * iter_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + if constexpr (can_use_reduce_over_group::value) + { + using KernelName = class search_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, false, true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + SearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, remaining_reduction_nelems, + iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = + class search_custom_reduction_over_group_temps_krn< + argTy, resTy, ReductionOpT, IndexOpT, + InputOutputIterIndexerT, ReductionIndexerT, SlmT, false, + true>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + CustomSearchReduction( + vals_temp_arg, nullptr, temp_arg, res_tp, + ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + } + }); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + sycl::context ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, partially_reduced_tmp, partially_reduced_vals_tmp] { + sycl::free(partially_reduced_tmp, ctx); + sycl::free(partially_reduced_vals_tmp, ctx); + }); + }); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +struct TypePairSupportDataForSearchReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< // disjunction is C++17 + // feature, supported + // by DPC++ input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // 
fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ArgmaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_reduction_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_reduction_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSearchReductionTemps< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_reduction_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_reduction_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } } else { return nullptr; diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index 2fc7b02efa..0d4240c516 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -26,14 +26,79 @@ #include #include #include +#include #include +#include "math_utils.hpp" + namespace dpctl { namespace tensor { namespace sycl_utils { +namespace detail +{ + +template struct TypeList; + +template struct TypeList +{ + using head = Head; + using tail = TypeList; +}; + +using NullTypeList = TypeList<>; +template +struct IsNullTypeList : std::conditional_t, + std::true_type, + std::false_type> +{ +}; + +// recursively check if type is contained in given TypeList +template +struct IsContained + : std::conditional_t< + std::is_same_v>, + std::true_type, + IsContained> +{ +}; + +template <> struct TypeList<> +{ +}; + +// std::false_type when last case has been checked for membership +template struct IsContained : std::false_type +{ +}; + +template struct IsComplex : std::false_type +{ +}; +template struct IsComplex> : std::true_type +{ +}; + +} // namespace detail + +template +using sycl_ops = detail::TypeList, + sycl::bit_or, + sycl::bit_xor, + sycl::bit_and, + sycl::maximum, + sycl::minimum, + sycl::multiplies>; + +template struct IsSyclOp +{ + static constexpr bool value = + detail::IsContained>>::value || + detail::IsContained>>::value; +}; /*! 
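IsSyclOp is built from a tiny compile-time TypeList with a recursive IsContained membership check on the cv/reference-stripped type. A condensed standard-C++ rendering of the same idea, checked against std:: functors instead of sycl:: ones:

#include <functional>
#include <type_traits>

template <typename... Ts> struct TypeList;

template <typename Head, typename... Tail> struct TypeList<Head, Tail...>
{
    using head = Head;
    using tail = TypeList<Tail...>;
};
template <> struct TypeList<>
{
};

// Primary template: reaching the empty list means "not found".
template <typename T, typename List> struct IsContained : std::false_type
{
};

template <typename T, typename Head, typename... Tail>
struct IsContained<T, TypeList<Head, Tail...>>
    : std::conditional_t<
          std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, Head>,
          std::true_type,
          IsContained<T, TypeList<Tail...>>>
{
};

using StdOps = TypeList<std::plus<int>, std::multiplies<int>, std::bit_and<int>>;

static_assert(IsContained<std::multiplies<int>, StdOps>::value);
static_assert(!IsContained<std::minus<int>, StdOps>::value);

int main() { return 0; }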
@brief Find the smallest multiple of supported sub-group size larger than * nelems */ @@ -66,6 +131,183 @@ size_t choose_workgroup_size(const size_t nelems, return wg; } +template +T custom_reduce_over_group(const GroupT &wg, + LocAccT local_mem_acc, + const T &local_val, + const OpT &op) +{ + size_t wgs = wg.get_local_linear_range(); + local_mem_acc[wg.get_local_linear_id()] = local_val; + + sycl::group_barrier(wg, sycl::memory_scope::work_group); + + T red_val_over_wg = local_mem_acc[0]; + if (wg.leader()) { + for (size_t i = 1; i < wgs; ++i) { + red_val_over_wg = op(red_val_over_wg, local_mem_acc[i]); + } + } + + sycl::group_barrier(wg, sycl::memory_scope::work_group); + + return sycl::group_broadcast(wg, red_val_over_wg); +} + +// Reduction functors + +// Maximum + +template struct Maximum +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::max_complex; + return max_complex(x, y); + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) { + return (std::isnan(x) || x > y) ? x : y; + } + else if constexpr (std::is_same_v) { + return x || y; + } + else { + return (x > y) ? x : y; + } + } +}; + +// Minimum + +template struct Minimum +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::min_complex; + return min_complex(x, y); + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) { + return (std::isnan(x) || x < y) ? x : y; + } + else if constexpr (std::is_same_v) { + return x && y; + } + else { + return (x < y) ? x : y; + } + } +}; + +// Define identities and operator checking structs + +template struct GetIdentity +{ +}; + +// Maximum + +template +using IsMaximum = std::bool_constant> || + std::is_same_v>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = + static_cast(std::numeric_limits::has_infinity + ? static_cast(-std::numeric_limits::infinity()) + : std::numeric_limits::lowest()); +}; + +template +struct GetIdentity::value>> +{ + static constexpr bool value = false; +}; + +template +struct GetIdentity, + std::enable_if_t, Op>::value>> +{ + static constexpr std::complex value{-std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; +}; + +// Minimum + +template +using IsMinimum = std::bool_constant> || + std::is_same_v>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = + static_cast(std::numeric_limits::has_infinity + ? 
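Unlike sycl::maximum/minimum, the Maximum/Minimum functors above make NaN sticky, and GetIdentity supplies -inf/+inf (or lowest()/max() for types without an infinity) as the neutral element. A host-side check of that behaviour with local stand-in names:

#include <cmath>
#include <iostream>
#include <limits>
#include <type_traits>

template <typename T> struct Maximum
{
    T operator()(const T &x, const T &y) const
    {
        if constexpr (std::is_floating_point_v<T>) {
            return (std::isnan(x) || x > y) ? x : y; // NaN wins
        }
        else {
            return (x > y) ? x : y;
        }
    }
};

// Identity of Maximum: -inf when available, otherwise the lowest finite value.
template <typename T> constexpr T max_identity()
{
    return std::numeric_limits<T>::has_infinity
               ? -std::numeric_limits<T>::infinity()
               : std::numeric_limits<T>::lowest();
}

int main()
{
    Maximum<double> op;
    double acc = max_identity<double>(); // -inf: neutral element of max
    for (double v : {1.5, std::nan(""), 0.5})
        acc = op(acc, v);
    std::cout << std::boolalpha << std::isnan(acc) << "\n"; // true: NaN propagated

    Maximum<int> iop;
    std::cout << iop(3, 7) << "\n"; // 7
}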
static_cast(std::numeric_limits::infinity()) + : std::numeric_limits::max()); +}; + +template +struct GetIdentity::value>> +{ + static constexpr bool value = true; +}; + +template +struct GetIdentity, + std::enable_if_t, Op>::value>> +{ + static constexpr std::complex value{std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; +}; + +// Plus + +template +using IsPlus = std::bool_constant> || + std::is_same_v>>; +// Multiplies + +template +using IsMultiplies = + std::bool_constant> || + std::is_same_v>>; + +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(1); +}; + +// Identity + +template struct Identity +{ +}; + +template +using UseBuiltInIdentity = + std::conjunction, sycl::has_known_identity>; + +template +struct Identity::value>> +{ + static constexpr T value = GetIdentity::value; +}; + +template +struct Identity::value>> +{ + static constexpr T value = sycl::known_identity::value; +}; + } // namespace sycl_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp new file mode 100644 index 0000000000..c67fcd5ba3 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp @@ -0,0 +1,514 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; +// Max +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_max_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Min +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_min_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + 
dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Sum +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_sum_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Product +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Argmax +namespace impl +{ + +using 
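Each populate_*_over_axis_dispatch_tables call fills a num_types x num_types table by instantiating a factory for every (src, dst) dtype pair, so run-time dispatch is just two array indexes. A toy version over a three-type universe; the hand-rolled fill helpers stand in for td_ns::DispatchTableBuilder:

#include <cstddef>
#include <cstdint>
#include <iostream>

constexpr int num_types = 3;
using fn_ptr_t = double (*)(const void *, std::size_t);

template <typename srcTy, typename dstTy>
double sum_kernel(const void *p, std::size_t n)
{
    const srcTy *src = static_cast<const srcTy *>(p);
    dstTy acc = 0;
    for (std::size_t i = 0; i < n; ++i)
        acc += static_cast<dstTy>(src[i]);
    return static_cast<double>(acc);
}

template <typename srcTy, typename dstTy> struct SumFactory
{
    fn_ptr_t get() const { return sum_kernel<srcTy, dstTy>; }
};

// type id 0 -> int32_t, 1 -> float, 2 -> double
template <int I> struct type_of;
template <> struct type_of<0> { using type = std::int32_t; };
template <> struct type_of<1> { using type = float; };
template <> struct type_of<2> { using type = double; };

fn_ptr_t table[num_types][num_types];

template <int Src, int Dst> void fill_cell()
{
    table[Src][Dst] =
        SumFactory<typename type_of<Src>::type, typename type_of<Dst>::type>{}.get();
}

template <int Src> void fill_row()
{
    fill_cell<Src, 0>();
    fill_cell<Src, 1>();
    fill_cell<Src, 2>();
}

int main()
{
    fill_row<0>();
    fill_row<1>();
    fill_row<2>();

    float data[3] = {1.f, 2.f, 3.f};
    int src_typeid = 1, dst_typeid = 2;  // float input, double output
    fn_ptr_t fn = table[src_typeid][dst_typeid];
    std::cout << fn(data, 3) << "\n";    // 6
}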
dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; +static search_reduction_strided_impl_fn_ptr + argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmax_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); +} + +} // namespace impl + +// Argmin +namespace impl +{ + +using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; +static search_reduction_strided_impl_fn_ptr + argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmin_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); +} + +} // namespace impl + +namespace py = pybind11; + +void init_reduction_functions(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + namespace impl = dpctl::tensor::py_internal::impl; + + using dpctl::tensor::py_internal::py_reduction_dtype_supported; + using dpctl::tensor::py_internal::py_reduction_over_axis; + + using dpctl::tensor::py_internal::check_atomic_support; + using dpctl::tensor::py_internal::fixed_decision; + + // MAX + { + using dpctl::tensor::py_internal::impl:: + populate_max_over_axis_dispatch_tables; + populate_max_over_axis_dispatch_tables(); + using impl::max_over_axis0_contig_atomic_dispatch_table; + using impl::max_over_axis1_contig_atomic_dispatch_table; + using impl::max_over_axis_strided_atomic_dispatch_table; + using impl::max_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + max_over_axis_strided_atomic_dispatch_table, + max_over_axis_strided_temps_dispatch_table, + max_over_axis0_contig_atomic_dispatch_table, + max_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_max_over_axis", max_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // MIN + { + using dpctl::tensor::py_internal::impl:: + populate_min_over_axis_dispatch_tables; + populate_min_over_axis_dispatch_tables(); + using impl::min_over_axis0_contig_atomic_dispatch_table; + using impl::min_over_axis1_contig_atomic_dispatch_table; + using impl::min_over_axis_strided_atomic_dispatch_table; + using impl::min_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, 
depends, + min_over_axis_strided_atomic_dispatch_table, + min_over_axis_strided_temps_dispatch_table, + min_over_axis0_contig_atomic_dispatch_table, + min_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_min_over_axis", min_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // SUM + { + using dpctl::tensor::py_internal::impl:: + populate_sum_over_axis_dispatch_tables; + populate_sum_over_axis_dispatch_tables(); + using impl::sum_over_axis0_contig_atomic_dispatch_table; + using impl::sum_over_axis1_contig_atomic_dispatch_table; + using impl::sum_over_axis_strided_atomic_dispatch_table; + using impl::sum_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_over_axis0_contig_atomic_dispatch_table, + sum_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sum_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } + + // PROD + { + using dpctl::tensor::py_internal::impl:: + populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + 
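init_reduction_functions exposes each reduction through m.def with a lambda that captures the dispatch tables, named py::arg bindings and a py::list() default for depends. A minimal pybind11 module using the same binding shape; it builds against pybind11, the module and function names are illustrative, and the body is a trivial host-side stand-in rather than a SYCL launch.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;

PYBIND11_MODULE(_reduction_demo, m)
{
    auto sum_pyapi = [](const std::vector<double> &src,
                        int trailing_dims_to_reduce, const py::list &depends) {
        (void)trailing_dims_to_reduce; // kept only to mirror the real signature
        (void)depends;
        double acc = 0.0;
        for (double v : src)
            acc += v;
        return acc;
    };
    m.def("_sum_over_axis_demo", sum_pyapi, "", py::arg("src"),
          py::arg("trailing_dims_to_reduce"), py::arg("depends") = py::list());
}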
input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } + + // ARGMAX + { + using dpctl::tensor::py_internal::impl:: + populate_argmax_over_axis_dispatch_tables; + populate_argmax_over_axis_dispatch_tables(); + using impl::argmax_over_axis_strided_temps_dispatch_table; + + auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmax_over_axis_strided_temps_dispatch_table); + }; + m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // ARGMIN + { + using dpctl::tensor::py_internal::impl:: + populate_argmin_over_axis_dispatch_tables; + populate_argmin_over_axis_dispatch_tables(); + using impl::argmin_over_axis_strided_temps_dispatch_table; + + auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmin_over_axis_strided_temps_dispatch_table); + }; + m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sum_reductions.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp similarity index 57% rename from dpctl/tensor/libtensor/source/sum_reductions.cpp rename to dpctl/tensor/libtensor/source/reduction_over_axis.hpp index 529096f5b6..1a9cb6f5e7 100644 --- a/dpctl/tensor/libtensor/source/sum_reductions.cpp +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp @@ -1,8 +1,8 @@ -//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,16 +16,19 @@ // See the License for the specific language governing permissions and // limitations under the License. // -//===--------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// /// /// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions -//===--------------------------------------------------------------------===// +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for reductions. 
+//===----------------------------------------------------------------------===// + +#pragma once #include #include -#include -#include +#include +#include #include #include @@ -35,8 +38,6 @@ #include #include "kernels/reductions.hpp" -#include "sum_reductions.hpp" - #include "simplify_iteration_space.hpp" #include "utils/memory_overlap.hpp" #include "utils/offset_utils.hpp" @@ -49,14 +50,15 @@ namespace tensor namespace py_internal { +template bool check_atomic_support(const sycl::queue &exec_q, - sycl::usm::alloc usm_alloc_type, - bool require_atomic64 = false) + sycl::usm::alloc usm_alloc_type) { bool supports_atomics = false; const sycl::device &dev = exec_q.get_device(); - if (require_atomic64) { + + if constexpr (require_atomic64) { if (!dev.has(sycl::aspect::atomic64)) return false; } @@ -78,28 +80,106 @@ bool check_atomic_support(const sycl::queue &exec_q, return supports_atomics; } -using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; -static sum_reduction_strided_impl_fn_ptr - sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static sum_reduction_strided_impl_fn_ptr - sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr; -static sum_reduction_contig_impl_fn_ptr - sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static sum_reduction_contig_impl_fn_ptr - sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -std::pair py_sum_over_axis( +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + +/* ====================== dtype supported ======================== */ + +template +bool py_reduction_dtype_supported( + const py::dtype &input_dtype, + const py::dtype &output_dtype, + const std::string &dst_usm_type, + sycl::queue &q, + const fnT &atomic_dispatch_table, + const fnT &temps_dispatch_table, + const CheckAtomicSupportFnT &check_atomic_support_size4, + const CheckAtomicSupportFnT &check_atomic_support_size8) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + // remove_all_extents gets underlying type of table + using fn_ptrT = typename std::remove_all_extents::type; + fn_ptrT fn = nullptr; + + sycl::usm::alloc kind = sycl::usm::alloc::unknown; + + if (dst_usm_type == "device") { + kind = sycl::usm::alloc::device; + } + else if (dst_usm_type == "shared") { + kind = sycl::usm::alloc::shared; + } + else if (dst_usm_type == "host") { + kind = sycl::usm::alloc::host; + } + else { + throw py::value_error("Unrecognized `dst_usm_type` argument."); + } + + bool supports_atomics = false; + + switch (output_dtype.itemsize()) { + case sizeof(float): + { + supports_atomics = check_atomic_support_size4(q, kind); + } break; + case sizeof(double): + { + supports_atomics = check_atomic_support_size8(q, kind); + } break; + } + + if 
(supports_atomics) { + fn = atomic_dispatch_table[arg_typeid][out_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[arg_typeid][out_typeid]; + } + + return (fn != nullptr); +} + +/* ==================== Generic reductions ====================== */ + +template +std::pair py_reduction_over_axis( const dpctl::tensor::usm_ndarray &src, - int trailing_dims_to_reduce, // sum over this many trailing indexes + int trailing_dims_to_reduce, // comp over this many trailing indexes const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends) + const std::vector &depends, + const strided_fnT &atomic_dispatch_table, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_dispatch_table, + const contig_fnT &axis1_dispatch_table, + const SupportAtomicFnT &check_atomic_support_size4, + const SupportAtomicFnT &check_atomic_support_size8) { int src_nd = src.get_ndim(); int iteration_nd = src_nd - trailing_dims_to_reduce; @@ -160,6 +240,7 @@ std::pair py_sum_over_axis( int src_typenum = src.get_typenum(); int dst_typenum = dst.get_typenum(); + namespace td_ns = dpctl::tensor::type_dispatch; const auto &array_types = td_ns::usm_ndarray_types(); int src_typeid = array_types.typenum_to_lookup_id(src_typenum); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); @@ -173,7 +254,7 @@ std::pair py_sum_over_axis( void *data_ptr = dst.get_data(); const auto &ctx = exec_q.get_context(); auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - supports_atomics = check_atomic_support(exec_q, usm_type); + supports_atomics = check_atomic_support_size4(exec_q, usm_type); } break; case sizeof(double): { @@ -181,9 +262,7 @@ std::pair py_sum_over_axis( const auto &ctx = exec_q.get_context(); auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - constexpr bool require_atomic64 = true; - supports_atomics = - check_atomic_support(exec_q, usm_type, require_atomic64); + supports_atomics = check_atomic_support_size8(exec_q, usm_type); } break; } @@ -197,14 +276,14 @@ std::pair py_sum_over_axis( if ((is_src_c_contig && is_dst_c_contig) || (is_src_f_contig && dst_nelems == 1)) { - auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid] - [dst_typeid]; + auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { size_t iter_nelems = dst_nelems; constexpr py::ssize_t zero_offset = 0; - sycl::event sum_over_axis_contig_ev = + sycl::event reduction_over_axis_contig_ev = fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), dst.get_data(), zero_offset, // iteration_src_offset @@ -213,22 +292,22 @@ std::pair py_sum_over_axis( depends); sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {sum_over_axis_contig_ev}); + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); } } else if (is_src_f_contig && ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) { - auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid] - [dst_typeid]; + auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { size_t iter_nelems = dst_nelems; constexpr py::ssize_t zero_offset = 0; - sycl::event sum_over_axis_contig_ev = + sycl::event reduction_over_axis_contig_ev = fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), dst.get_data(), zero_offset, // iteration_src_offset @@ -237,9 +316,10 @@ 
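Both py_reduction_dtype_supported and py_reduction_over_axis pick an implementation the same way: if the output itemsize has the required atomics (4 bytes via the basic check, 8 bytes additionally needing sycl::aspect::atomic64), try the atomic table; otherwise, or if that slot is empty, fall back to the temps table, and a nullptr in both means the dtype pair is unsupported. A host-side model of that selection:

#include <cstddef>
#include <iostream>

using fn_ptr_t = void (*)();
void atomic_kernel() {}
void temps_kernel() {}

fn_ptr_t select_impl(std::size_t out_itemsize,
                     bool device_has_32bit_atomics,
                     bool device_has_atomic64,
                     fn_ptr_t atomic_entry,
                     fn_ptr_t temps_entry)
{
    bool supports_atomics = false;
    switch (out_itemsize) {
    case sizeof(float):
        supports_atomics = device_has_32bit_atomics;
        break;
    case sizeof(double):
        supports_atomics = device_has_atomic64; // sycl::aspect::atomic64
        break;
    }

    fn_ptr_t fn = supports_atomics ? atomic_entry : nullptr;
    if (fn == nullptr) {
        fn = temps_entry; // slower path using temporaries
    }
    return fn; // nullptr here -> "Datatypes are not supported"
}

int main()
{
    // double output on a device without 64-bit atomics: the temps kernel wins
    fn_ptr_t fn =
        select_impl(sizeof(double), true, false, atomic_kernel, temps_kernel);
    std::cout << (fn == temps_kernel) << "\n"; // 1
}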
std::pair py_sum_over_axis( depends); sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {sum_over_axis_contig_ev}); + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - return std::make_pair(keep_args_event, sum_over_axis_contig_ev); + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); } } } @@ -320,50 +400,49 @@ std::pair py_sum_over_axis( } if (mat_reduce_over_axis1 || array_reduce_all_elems) { - auto fn = sum_over_axis1_contig_atomic_dispatch_table[src_typeid] - [dst_typeid]; + auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { - sycl::event sum_over_axis1_contig_ev = + sycl::event reduction_over_axis1_contig_ev = fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), dst.get_data(), iteration_src_offset, iteration_dst_offset, reduction_src_offset, depends); sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {sum_over_axis1_contig_ev}); + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); return std::make_pair(keep_args_event, - sum_over_axis1_contig_ev); + reduction_over_axis1_contig_ev); } } else if (mat_reduce_over_axis0) { - auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid] - [dst_typeid]; + auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; if (fn != nullptr) { - sycl::event sum_over_axis0_contig_ev = + sycl::event reduction_over_axis0_contig_ev = fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), dst.get_data(), iteration_src_offset, iteration_dst_offset, reduction_src_offset, depends); sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {sum_over_axis0_contig_ev}); + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); return std::make_pair(keep_args_event, - sum_over_axis0_contig_ev); + reduction_over_axis0_contig_ev); } } } - using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; - sum_reduction_strided_impl_fn_ptr fn = nullptr; + // remove_all_extents gets underlying type of table + using strided_fn_ptr_T = + typename std::remove_all_extents::type; + strided_fn_ptr_T fn = nullptr; if (supports_atomics) { - fn = - sum_over_axis_strided_atomic_dispatch_table[src_typeid][dst_typeid]; + fn = atomic_dispatch_table[src_typeid][dst_typeid]; } if (fn == nullptr) { // use slower reduction implementation using temporaries - fn = sum_over_axis_strided_temps_dispatch_table[src_typeid][dst_typeid]; + fn = temps_dispatch_table[src_typeid][dst_typeid]; if (fn == nullptr) { throw std::runtime_error("Datatypes are not supported"); } @@ -398,14 +477,15 @@ std::pair py_sum_over_axis( std::copy(depends.begin(), depends.end(), all_deps.begin()); all_deps.push_back(copy_metadata_ev); - auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_nd, iter_shape_and_strides, - iteration_src_offset, iteration_dst_offset, - reduction_nd, // number dimensions being reduced - reduction_shape_stride, reduction_src_offset, all_deps); + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(comp_ev); + cgh.depends_on(reduction_ev); const auto &ctx = exec_q.get_context(); cgh.host_task([ctx, temp_allocation_ptr] { sycl::free(temp_allocation_ptr, ctx); @@ 
-416,127 +496,194 @@ std::pair py_sum_over_axis( sycl::event keep_args_event = dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - return std::make_pair(keep_args_event, comp_ev); + return std::make_pair(keep_args_event, reduction_ev); } -bool py_sum_over_axis_dtype_supported(const py::dtype &input_dtype, - const py::dtype &output_dtype, - const std::string &dst_usm_type, - sycl::queue &q) -{ - int arg_tn = - input_dtype.num(); // NumPy type numbers are the same as in dpctl - int out_tn = - output_dtype.num(); // NumPy type numbers are the same as in dpctl - int arg_typeid = -1; - int out_typeid = -1; - - auto array_types = td_ns::usm_ndarray_types(); +/* ==================== Search reductions ====================== */ - try { - arg_typeid = array_types.typenum_to_lookup_id(arg_tn); - out_typeid = array_types.typenum_to_lookup_id(out_tn); - } catch (const std::exception &e) { - throw py::value_error(e.what()); +template +std::pair py_search_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const fn_tableT &dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); } - if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || - out_typeid >= td_ns::num_types) - { - throw std::runtime_error("Reduction type support check: lookup failed"); + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); } - using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; - sum_reduction_strided_impl_fn_ptr fn = nullptr; - - sycl::usm::alloc kind = sycl::usm::alloc::unknown; + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); - if (dst_usm_type == "device") { - kind = sycl::usm::alloc::device; + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); } - else if (dst_usm_type == "shared") { - kind = sycl::usm::alloc::shared; + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); } - else if (dst_usm_type == "host") { - kind = sycl::usm::alloc::host; + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); } - else { - throw py::value_error("Unrecognized `dst_usm_type` argument."); + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); } - bool supports_atomics = false; + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } - switch (output_dtype.itemsize()) { - case sizeof(float): - { - supports_atomics = check_atomic_support(q, kind); - } break; - case sizeof(double): + // destination must be ample enough to accommodate all elements { - constexpr bool 
require_atomic64 = true; - supports_atomics = check_atomic_support(q, kind, require_atomic64); - } break; + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } } - if (supports_atomics) { - fn = - sum_over_axis_strided_atomic_dispatch_table[arg_typeid][out_typeid]; + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT compact_reduction_shape; + shT compact_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + compact_iteration_space( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + compact_reduction_shape, compact_reduction_src_strides); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); } + auto fn = dispatch_table[src_typeid][dst_typeid]; if (fn == nullptr) { - // use slower reduction implementation using temporaries - fn = sum_over_axis_strided_temps_dispatch_table[arg_typeid][out_typeid]; + throw std::runtime_error("Datatypes are not supported"); } - return (fn != nullptr); -} + std::vector host_task_events{}; -void populate_sum_over_axis_dispatch_table(void) -{ - using dpctl::tensor::kernels::sum_reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::sum_reduction_strided_impl_fn_ptr; - using namespace td_ns; - - using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; 
- DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); -} + using dpctl::tensor::offset_utils::device_allocate_and_pack; -namespace py = pybind11; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + compact_reduction_shape, compact_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); -void init_reduction_functions(py::module_ m) -{ - populate_sum_over_axis_dispatch_table(); + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_nd, iter_shape_and_strides, + iteration_src_offset, iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); - m.def("_sum_over_axis", &py_sum_over_axis, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - m.def("_sum_over_axis_dtype_supported", &py_sum_over_axis_dtype_supported, - "", py::arg("arg_dtype"), py::arg("out_dtype"), - py::arg("dst_usm_type"), py::arg("sycl_queue")); + return std::make_pair(keep_args_event, comp_ev); } +extern void init_reduction_functions(py::module_ m); + } // namespace py_internal } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sum_reductions.hpp b/dpctl/tensor/libtensor/source/sum_reductions.hpp deleted file mode 100644 index ac612ec1f7..0000000000 --- a/dpctl/tensor/libtensor/source/sum_reductions.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2022 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// -//===--------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions -//===--------------------------------------------------------------------===// - -#pragma once -#include -#include - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -extern void init_reduction_functions(py::module_ m); - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 2ce7c72add..6bd0649c1f 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -46,9 +46,9 @@ #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" #include "linear_sequences.hpp" +#include "reduction_over_axis.hpp" #include "repeat.hpp" #include "simplify_iteration_space.hpp" -#include "sum_reductions.hpp" #include "triul_ctor.hpp" #include "utils/memory_overlap.hpp" #include "utils/strided_iters.hpp" diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 403a823324..dc647febf7 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import pytest import dpctl.tensor as dpt @@ -36,7 +35,6 @@ "c8", "c16", ] -_usm_types = ["device", "shared", "host"] @pytest.mark.parametrize("arg_dtype", _all_dtypes) @@ -56,11 +54,11 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype): assert r.dtype.kind == "f" elif m.dtype.kind == "c": assert r.dtype.kind == "c" - assert (dpt.asnumpy(r) == 100).all() + assert dpt.all(r == 100) m = dpt.ones(200, dtype=arg_dtype)[:1:-2] r = dpt.sum(m) - assert (dpt.asnumpy(r) == 99).all() + assert dpt.all(r == 99) @pytest.mark.parametrize("arg_dtype", _all_dtypes) @@ -75,7 +73,7 @@ def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) - assert (dpt.asnumpy(r) == 100).all() + assert dpt.all(r == 100) def test_sum_empty(): @@ -94,7 +92,7 @@ def test_sum_axis(): assert isinstance(s, dpt.usm_ndarray) assert s.shape == (3, 6) - assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() + assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype="i4")) def test_sum_keepdims(): @@ -105,7 +103,7 @@ def test_sum_keepdims(): assert isinstance(s, dpt.usm_ndarray) assert s.shape == (3, 1, 1, 6, 1) - assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all() + assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype=s.dtype)) def test_sum_scalar(): @@ -117,7 +115,7 @@ def test_sum_scalar(): assert isinstance(s, dpt.usm_ndarray) assert m.sycl_queue == s.sycl_queue assert s.shape == () - assert dpt.asnumpy(s) == np.full((), 1) + assert s == dpt.full((), 1) @pytest.mark.parametrize("arg_dtype", _all_dtypes) @@ -132,7 +130,7 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) - assert dpt.asnumpy(r) == 1 + assert r == 1 def test_sum_keepdims_zero_size(): @@ -187,3 +185,66 @@ def test_axis0_bug(): expected = dpt.asarray([[0, 3], [1, 4], [2, 5]]) assert dpt.all(s == expected) + + +@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:]) +def 
test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.prod(m) + + assert isinstance(r, dpt.usm_ndarray) + if m.dtype.kind == "i": + assert r.dtype.kind == "i" + elif m.dtype.kind == "u": + assert r.dtype.kind == "u" + elif m.dtype.kind == "f": + assert r.dtype.kind == "f" + elif m.dtype.kind == "c": + assert r.dtype.kind == "c" + assert dpt.all(r == 1) + + if dpt.isdtype(m.dtype, "unsigned integer"): + m = dpt.tile(dpt.arange(1, 3, dtype=arg_dtype), 10)[:1:-2] + r = dpt.prod(m) + assert dpt.all(r == dpt.asarray(512, dtype=r.dtype)) + else: + m = dpt.full(200, -1, dtype=arg_dtype)[:1:-2] + r = dpt.prod(m) + assert dpt.all(r == dpt.asarray(-1, dtype=r.dtype)) + + +def test_prod_empty(): + get_queue_or_skip() + x = dpt.empty((0,), dtype="u1") + y = dpt.prod(x) + assert y.shape == tuple() + assert int(y) == 1 + + +def test_prod_axis(): + get_queue_or_skip() + + m = dpt.ones((3, 4, 5, 6, 7), dtype="i4") + s = dpt.prod(m, axis=(1, 2, -1)) + + assert isinstance(s, dpt.usm_ndarray) + assert s.shape == (3, 6) + assert dpt.all(s == dpt.asarray(1, dtype="i4")) + + +@pytest.mark.parametrize("arg_dtype", _all_dtypes) +@pytest.mark.parametrize("out_dtype", _all_dtypes[1:]) +def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + skip_if_dtype_not_supported(out_dtype, q) + + m = dpt.ones(100, dtype=arg_dtype) + r = dpt.prod(m, dtype=out_dtype) + + assert isinstance(r, dpt.usm_ndarray) + assert r.dtype == dpt.dtype(out_dtype) + assert dpt.all(r == 1) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py new file mode 100644 index 0000000000..8d66f35d71 --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -0,0 +1,236 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from random import randrange + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + + +def test_max_min_axis(): + get_queue_or_skip() + + x = dpt.reshape( + dpt.arange((3 * 4 * 5 * 6 * 7), dtype="i4"), (3, 4, 5, 6, 7) + ) + + m = dpt.max(x, axis=(1, 2, -1)) + assert m.shape == (3, 6) + assert dpt.all(m == x[:, -1, -1, :, -1]) + + m = dpt.min(x, axis=(1, 2, -1)) + assert m.shape == (3, 6) + assert dpt.all(m == x[:, 0, 0, :, 0]) + + +def test_reduction_keepdims(): + get_queue_or_skip() + + n0, n1 = 3, 6 + x = dpt.ones((n0, 4, 5, n1, 7), dtype="i4") + m = dpt.max(x, axis=(1, 2, -1), keepdims=True) + + xx = dpt.reshape(dpt.permute_dims(x, (0, 3, 1, 2, -1)), (n0, n1, -1)) + p = dpt.argmax(xx, axis=-1, keepdims=True) + + assert m.shape == (n0, 1, 1, n1, 1) + assert dpt.all(m == dpt.reshape(x[:, 0, 0, :, 0], m.shape)) + assert dpt.all(p == 0) + + +def test_max_scalar(): + get_queue_or_skip() + + x = dpt.ones(()) + m = dpt.max(x) + + assert m.shape == () + assert x == m + + +@pytest.mark.parametrize("arg_dtype", ["i4", "f4", "c8"]) +def test_reduction_kernels(arg_dtype): + # i4 - always uses atomics w/ sycl group reduction + # f4 - always uses atomics w/ custom group reduction + # c8 - always uses temps w/ custom group reduction + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + x = dpt.ones((24, 1025), dtype=arg_dtype, sycl_queue=q) + x[x.shape[0] // 2, :] = 3 + x[:, x.shape[1] // 2] = 3 + + m = dpt.max(x) + assert m == 3 + m = dpt.max(x, axis=0) + assert dpt.all(m == 3) + m = dpt.max(x, axis=1) + assert dpt.all(m == 3) + + x = dpt.ones((24, 1025), dtype=arg_dtype, sycl_queue=q) + x[x.shape[0] // 2, :] = 0 + x[:, x.shape[1] // 2] = 0 + + m = dpt.min(x) + assert m == 0 + m = dpt.min(x, axis=0) + assert dpt.all(m == 0) + m = dpt.min(x, axis=1) + assert dpt.all(m == 0) + + +def test_max_min_nan_propagation(): + get_queue_or_skip() + + # float, finites + x = dpt.arange(4, dtype="f4") + x[0] = dpt.nan + assert dpt.isnan(dpt.max(x)) + assert dpt.isnan(dpt.min(x)) + + # float, infinities + x[1:] = dpt.inf + assert dpt.isnan(dpt.max(x)) + x[1:] = -dpt.inf + assert dpt.isnan(dpt.min(x)) + + # complex + x = dpt.arange(4, dtype="c8") + x[0] = complex(dpt.nan, 0) + assert dpt.isnan(dpt.max(x)) + assert dpt.isnan(dpt.min(x)) + + x[0] = complex(0, dpt.nan) + assert dpt.isnan(dpt.max(x)) + assert dpt.isnan(dpt.min(x)) + + +def test_argmax_scalar(): + get_queue_or_skip() + + x = dpt.ones(()) + m = dpt.argmax(x) + + assert m.shape == () + assert m == 0 + + +@pytest.mark.parametrize("arg_dtype", ["i4", "f4", "c8"]) +def test_search_reduction_kernels(arg_dtype): + # i4 - always uses atomics w/ sycl group reduction + # f4 - always uses atomics w/ custom group reduction + # c8 - always uses temps w/ custom group reduction + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + idx = randrange(x.size) + idx_tup = np.unravel_index(idx, (24, 1025)) + x[idx] = 2 + + m = dpt.argmax(x) + assert m == idx + + x = dpt.reshape(x, (24, 1025)) + + x[idx_tup[0], :] = 3 + m = dpt.argmax(x, axis=0) + assert dpt.all(m == idx_tup[0]) + x[:, idx_tup[1]] = 4 + m = dpt.argmax(x, axis=1) + assert dpt.all(m == idx_tup[1]) + + x = x[:, ::-2] + idx = randrange(x.shape[1]) + x[:, idx] = 5 + m = dpt.argmax(x, axis=1) + assert dpt.all(m == idx) + + x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + idx = randrange(x.size) + idx_tup = 
np.unravel_index(idx, (24, 1025)) + x[idx] = 0 + + m = dpt.argmin(x) + assert m == idx + + x = dpt.reshape(x, (24, 1025)) + + x[idx_tup[0], :] = -1 + m = dpt.argmin(x, axis=0) + assert dpt.all(m == idx_tup[0]) + x[:, idx_tup[1]] = -2 + m = dpt.argmin(x, axis=1) + assert dpt.all(m == idx_tup[1]) + + x = x[:, ::-2] + idx = randrange(x.shape[1]) + x[:, idx] = -3 + m = dpt.argmin(x, axis=1) + assert dpt.all(m == idx) + + +def test_argmax_argmin_nan_propagation(): + get_queue_or_skip() + + sz = 4 + idx = randrange(sz) + # floats + x = dpt.arange(sz, dtype="f4") + x[idx] = dpt.nan + assert dpt.argmax(x) == idx + assert dpt.argmin(x) == idx + + # complex + x = dpt.arange(sz, dtype="c8") + x[idx] = complex(dpt.nan, 0) + assert dpt.argmax(x) == idx + assert dpt.argmin(x) == idx + + x[idx] = complex(0, dpt.nan) + assert dpt.argmax(x) == idx + assert dpt.argmin(x) == idx + + +def test_argmax_argmin_identities(): + # make sure that identity arrays work as expected + get_queue_or_skip() + + x = dpt.full(3, dpt.iinfo(dpt.int32).min, dtype="i4") + assert dpt.argmax(x) == 0 + x = dpt.full(3, dpt.iinfo(dpt.int32).max, dtype="i4") + assert dpt.argmin(x) == 0 + + +def test_reduction_arg_validation(): + get_queue_or_skip() + + x = dict() + with pytest.raises(TypeError): + dpt.sum(x) + with pytest.raises(TypeError): + dpt.max(x) + with pytest.raises(TypeError): + dpt.argmax(x) + + x = dpt.zeros((0,), dtype="i4") + with pytest.raises(ValueError): + dpt.max(x) + with pytest.raises(ValueError): + dpt.argmax(x)
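
For reviewers, a minimal usage sketch of the Python-level API this change adds (prod, max, min, argmax, argmin alongside the existing sum). It is not part of the patch; it only exercises functions that appear in the diff above, and the shapes, dtypes, and variable names are illustrative rather than taken from the test suite:

import dpctl.tensor as dpt

# reductions resolve the output dtype from the input "kind" when dtype=None
x = dpt.ones((3, 4), dtype="i4")
p = dpt.prod(x)                              # 0-d usm_ndarray equal to 1

# max/min accept a tuple of axes and can keep the reduced dimensions
y = dpt.reshape(dpt.arange(2 * 3 * 4, dtype="f4"), (2, 3, 4))
m = dpt.max(y, axis=(1, 2), keepdims=True)   # shape (2, 1, 1)
n = dpt.min(y, axis=0)                       # shape (3, 4)

# search reductions return indices along the reduced axis
i = dpt.argmax(y, axis=-1)                   # shape (2, 3); every entry is 3
j = dpt.argmin(y)                            # 0-d index, equal to 0 here

As the tests above indicate, the strided temporaries-based kernels serve as the fallback when no atomic implementation exists for the resolved output type, and zero-size inputs raise ValueError for max and argmax (see test_reduction_arg_validation), since these reductions have no identity value.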