diff --git a/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp b/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp index 027431a80e..f73c7e766b 100644 --- a/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp +++ b/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp @@ -126,22 +126,42 @@ py_argsort(const dpctl::tensor::usm_ndarray &src, bool is_dst_c_contig = dst.is_c_contiguous(); if (is_src_c_contig && is_dst_c_contig) { - static constexpr py::ssize_t zero_offset = py::ssize_t(0); + if (sort_nelems > 1) { + static constexpr py::ssize_t zero_offset = py::ssize_t(0); - auto fn = sort_contig_fns[src_typeid][dst_typeid]; + auto fn = sort_contig_fns[src_typeid][dst_typeid]; - if (fn == nullptr) { - throw py::value_error("Not implemented for dtypes of input arrays"); + if (fn == nullptr) { + throw py::value_error( + "Not implemented for dtypes of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); } + else { + int dst_elemsize = dst.get_elemsize(); + static constexpr int memset_val(0); - sycl::event comp_ev = - fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(), - zero_offset, zero_offset, zero_offset, zero_offset, depends); + sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event keep_args_alive_ev = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + cgh.memset(reinterpret_cast(dst.get_data()), memset_val, + iter_nelems * dst_elemsize); + }); - return std::make_pair(keep_args_alive_ev, comp_ev); + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {fill_ev}); + + return std::make_pair(keep_args_alive_ev, fill_ev); + } } throw py::value_error( diff --git a/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp b/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp index 2a727c4bd9..28762c4a5d 100644 --- a/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp +++ b/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp @@ -127,23 +127,37 @@ py_sort(const dpctl::tensor::usm_ndarray &src, bool is_dst_c_contig = dst.is_c_contiguous(); if (is_src_c_contig && is_dst_c_contig) { - static constexpr py::ssize_t zero_offset = py::ssize_t(0); + if (sort_nelems > 1) { + static constexpr py::ssize_t zero_offset = py::ssize_t(0); - auto fn = sort_contig_fns[src_typeid]; + auto fn = sort_contig_fns[src_typeid]; - if (nullptr == fn) { - throw py::value_error( - "Not implemented for the dtype of input arrays"); - } + if (nullptr == fn) { + throw py::value_error( + "Not implemented for the dtype of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); - sycl::event comp_ev = - fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(), - zero_offset, zero_offset, zero_offset, zero_offset, depends); + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); - sycl::event keep_args_alive_ev = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + return std::make_pair(keep_args_alive_ev, comp_ev); + } + else { + int src_elemsize = src.get_elemsize(); - return std::make_pair(keep_args_alive_ev, comp_ev); + sycl::event copy_ev = + exec_q.copy(src.get_data(), dst.get_data(), + src_elemsize * iter_nelems, depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {copy_ev}), + copy_ev); + } } throw py::value_error( diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index 5ecae344c5..c8d3701db2 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -338,3 +338,43 @@ def test_sort_complex_fp_nan(dtype): assert np.array_equal( r1.view(np.int64), r2.view(np.int64) ), f"Failed for {i} and {j}" + + +def test_radix_sort_size_1_axis(): + get_queue_or_skip() + + x1 = dpt.ones((), dtype="i1") + r1 = dpt.sort(x1, kind="radixsort") + assert r1 == x1 + + x2 = dpt.ones([1], dtype="i1") + r2 = dpt.sort(x2, kind="radixsort") + assert r2 == x2 + + x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1)) + r3 = dpt.sort(x3, kind="radixsort") + assert dpt.all(r3 == x3) + + x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10)) + r4 = dpt.sort(x4, axis=0, kind="radixsort") + assert dpt.all(r4 == x4) + + +def test_radix_argsort_size_1_axis(): + get_queue_or_skip() + + x1 = dpt.ones((), dtype="i1") + r1 = dpt.argsort(x1, kind="radixsort") + assert r1 == 0 + + x2 = dpt.ones([1], dtype="i1") + r2 = dpt.argsort(x2, kind="radixsort") + assert r2 == 0 + + x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1)) + r3 = dpt.argsort(x3, kind="radixsort") + assert dpt.all(r3 == 0) + + x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10)) + r4 = dpt.argsort(x4, axis=0, kind="radixsort") + assert dpt.all(r4 == 0)