Skip to content

Commit 1838afa

Browse files
committed
fix sort where the number of elements to be sorted is 1
1 parent 35a8c26 commit 1838afa

File tree

3 files changed

+96
-22
lines changed

3 files changed

+96
-22
lines changed

dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp

Lines changed: 30 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -126,22 +126,42 @@ py_argsort(const dpctl::tensor::usm_ndarray &src,
126126
bool is_dst_c_contig = dst.is_c_contiguous();
127127

128128
if (is_src_c_contig && is_dst_c_contig) {
129-
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
129+
if (sort_nelems > 1) {
130+
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
130131

131-
auto fn = sort_contig_fns[src_typeid][dst_typeid];
132+
auto fn = sort_contig_fns[src_typeid][dst_typeid];
132133

133-
if (fn == nullptr) {
134-
throw py::value_error("Not implemented for dtypes of input arrays");
134+
if (fn == nullptr) {
135+
throw py::value_error(
136+
"Not implemented for dtypes of input arrays");
137+
}
138+
139+
sycl::event comp_ev =
140+
fn(exec_q, iter_nelems, sort_nelems, src.get_data(),
141+
dst.get_data(), zero_offset, zero_offset, zero_offset,
142+
zero_offset, depends);
143+
144+
sycl::event keep_args_alive_ev =
145+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
146+
147+
return std::make_pair(keep_args_alive_ev, comp_ev);
135148
}
149+
else {
150+
int dst_elemsize = dst.get_elemsize();
151+
static constexpr int memset_val(0);
136152

137-
sycl::event comp_ev =
138-
fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(),
139-
zero_offset, zero_offset, zero_offset, zero_offset, depends);
153+
sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) {
154+
cgh.depends_on(depends);
140155

141-
sycl::event keep_args_alive_ev =
142-
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
156+
cgh.memset(reinterpret_cast<void *>(dst.get_data()), memset_val,
157+
iter_nelems * dst_elemsize);
158+
});
143159

144-
return std::make_pair(keep_args_alive_ev, comp_ev);
160+
sycl::event keep_args_alive_ev =
161+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {fill_ev});
162+
163+
return std::make_pair(keep_args_alive_ev, fill_ev);
164+
}
145165
}
146166

147167
throw py::value_error(

dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -127,23 +127,37 @@ py_sort(const dpctl::tensor::usm_ndarray &src,
127127
bool is_dst_c_contig = dst.is_c_contiguous();
128128

129129
if (is_src_c_contig && is_dst_c_contig) {
130-
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
130+
if (sort_nelems > 1) {
131+
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
131132

132-
auto fn = sort_contig_fns[src_typeid];
133+
auto fn = sort_contig_fns[src_typeid];
133134

134-
if (nullptr == fn) {
135-
throw py::value_error(
136-
"Not implemented for the dtype of input arrays");
137-
}
135+
if (nullptr == fn) {
136+
throw py::value_error(
137+
"Not implemented for the dtype of input arrays");
138+
}
139+
140+
sycl::event comp_ev =
141+
fn(exec_q, iter_nelems, sort_nelems, src.get_data(),
142+
dst.get_data(), zero_offset, zero_offset, zero_offset,
143+
zero_offset, depends);
138144

139-
sycl::event comp_ev =
140-
fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(),
141-
zero_offset, zero_offset, zero_offset, zero_offset, depends);
145+
sycl::event keep_args_alive_ev =
146+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
142147

143-
sycl::event keep_args_alive_ev =
144-
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
148+
return std::make_pair(keep_args_alive_ev, comp_ev);
149+
}
150+
else {
151+
int src_elemsize = src.get_elemsize();
145152

146-
return std::make_pair(keep_args_alive_ev, comp_ev);
153+
sycl::event copy_ev =
154+
exec_q.copy<char>(src.get_data(), dst.get_data(),
155+
src_elemsize * iter_nelems, depends);
156+
157+
return std::make_pair(
158+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {copy_ev}),
159+
copy_ev);
160+
}
147161
}
148162

149163
throw py::value_error(

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -338,3 +338,43 @@ def test_sort_complex_fp_nan(dtype):
338338
assert np.array_equal(
339339
r1.view(np.int64), r2.view(np.int64)
340340
), f"Failed for {i} and {j}"
341+
342+
343+
def test_radix_sort_size_1_axis():
344+
get_queue_or_skip()
345+
346+
x1 = dpt.ones((), dtype="i1")
347+
r1 = dpt.sort(x1, kind="radixsort")
348+
assert r1 == x1
349+
350+
x2 = dpt.ones([1], dtype="i1")
351+
r2 = dpt.sort(x2, kind="radixsort")
352+
assert r2 == x2
353+
354+
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
355+
r3 = dpt.sort(x3, kind="radixsort")
356+
assert dpt.all(r3 == x3)
357+
358+
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
359+
r4 = dpt.sort(x4, axis=0, kind="radixsort")
360+
assert dpt.all(r4 == x4)
361+
362+
363+
def test_radix_argsort_size_1_axis():
364+
get_queue_or_skip()
365+
366+
x1 = dpt.ones((), dtype="i1")
367+
r1 = dpt.argsort(x1, kind="radixsort")
368+
assert r1 == 0
369+
370+
x2 = dpt.ones([1], dtype="i1")
371+
r2 = dpt.argsort(x2, kind="radixsort")
372+
assert r2 == 0
373+
374+
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
375+
r3 = dpt.argsort(x3, kind="radixsort")
376+
assert dpt.all(r3 == 0)
377+
378+
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
379+
r4 = dpt.argsort(x4, axis=0, kind="radixsort")
380+
assert dpt.all(r4 == 0)

0 commit comments

Comments
 (0)