Skip to content

Commit d5ac229

Browse files
committed
Adds tests for #1455 resolution
1 parent f293713 commit d5ac229

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,24 @@ def test_axis0_bug():
212212
assert dpt.all(s == expected)
213213

214214

215+
def test_sum_axis1_axis0():
216+
"""See gh-1455"""
217+
get_queue_or_skip()
218+
219+
# The atomic case is checked in `test_usm_ndarray_reductions`
220+
# This test checks the tree reduction path for correctness
221+
x = dpt.reshape(dpt.arange(4 * 5, dtype="f4"), (4, 5))
222+
223+
m = dpt.sum(x, axis=0)
224+
expected = dpt.asarray([30, 34, 38, 42, 46], dtype="f4")
225+
tol = dpt.finfo(m.dtype).resolution
226+
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
227+
228+
m = dpt.sum(x, axis=1)
229+
expected = dpt.asarray([10, 35, 60, 85], dtype="f4")
230+
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
231+
232+
215233
def _any_complex(dtypes):
216234
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)
217235

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ def test_max_min_axis():
6161
assert dpt.all(m == x[:, 0, 0, :, 0])
6262

6363

64+
def test_max_axis1_axis0():
65+
"""See gh-1455"""
66+
get_queue_or_skip()
67+
68+
x = dpt.reshape(dpt.arange(4 * 5), (4, 5))
69+
70+
m = dpt.max(x, axis=0)
71+
assert dpt.all(m == x[-1, :])
72+
73+
m = dpt.max(x, axis=1)
74+
assert dpt.all(m == x[:, -1])
75+
76+
6477
def test_reduction_keepdims():
6578
get_queue_or_skip()
6679

@@ -440,3 +453,27 @@ def test_hypot_complex():
440453
x = dpt.zeros(1, dtype="c8")
441454
with pytest.raises(TypeError):
442455
dpt.reduce_hypot(x)
456+
457+
458+
def test_tree_reduction_axis1_axis0():
459+
"""See gh-1455"""
460+
get_queue_or_skip()
461+
462+
x = dpt.reshape(dpt.arange(4 * 5, dtype="f4"), (4, 5))
463+
464+
m = dpt.logsumexp(x, axis=0)
465+
tol = dpt.finfo(m.dtype).resolution
466+
assert_allclose(
467+
dpt.asnumpy(m),
468+
np.logaddexp.reduce(dpt.asnumpy(x), axis=0, dtype=m.dtype),
469+
rtol=tol,
470+
atol=tol,
471+
)
472+
473+
m = dpt.logsumexp(x, axis=1)
474+
assert_allclose(
475+
dpt.asnumpy(m),
476+
np.logaddexp.reduce(dpt.asnumpy(x), axis=1, dtype=m.dtype),
477+
rtol=tol,
478+
atol=tol,
479+
)

0 commit comments

Comments
 (0)