Skip to content

Commit 4a364a5

Browse files
committed
Refactor majority of elwise tests with refimpl utils
1 parent e50fc1a commit 4a364a5

File tree

1 file changed

+35
-87
lines changed

1 file changed

+35
-87
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 35 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,6 @@ def test_bitwise_xor(ctx, data):
667667

668668
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
669669
def test_ceil(x):
670-
# This test is almost identical to test_floor()
671670
out = xp.ceil(x)
672671
ph.assert_dtype("ceil", x.dtype, out.dtype)
673672
ph.assert_shape("ceil", out.shape, x.shape)
@@ -686,26 +685,15 @@ def test_cos(x):
686685
out = xp.cos(x)
687686
ph.assert_dtype("cos", x.dtype, out.dtype)
688687
ph.assert_shape("cos", out.shape, x.shape)
689-
ONE = ah.one(x.shape, x.dtype)
690-
INFINITY = ah.infinity(x.shape, x.dtype)
691-
domain = ah.inrange(x, -INFINITY, INFINITY, open=True)
692-
codomain = ah.inrange(out, -ONE, ONE)
693-
# cos maps (-inf, inf) to [-1, 1]. Values outside this domain are mapped
694-
# to nan, which is already tested in the special cases.
695-
ah.assert_exactly_equal(domain, codomain)
688+
unary_assert_against_refimpl("cos", x, out, math.cos)
696689

697690

698691
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
699692
def test_cosh(x):
700693
out = xp.cosh(x)
701694
ph.assert_dtype("cosh", x.dtype, out.dtype)
702695
ph.assert_shape("cosh", out.shape, x.shape)
703-
INFINITY = ah.infinity(x.shape, x.dtype)
704-
domain = ah.inrange(x, -INFINITY, INFINITY)
705-
codomain = ah.inrange(out, -INFINITY, INFINITY)
706-
# cosh maps [-inf, inf] to [-inf, inf]. Values outside this domain are
707-
# mapped to nan, which is already tested in the special cases.
708-
ah.assert_exactly_equal(domain, codomain)
696+
unary_assert_against_refimpl("cosh", x, out, math.cosh)
709697

710698

711699
@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes()))
@@ -758,27 +746,15 @@ def test_exp(x):
758746
out = xp.exp(x)
759747
ph.assert_dtype("exp", x.dtype, out.dtype)
760748
ph.assert_shape("exp", out.shape, x.shape)
761-
INFINITY = ah.infinity(x.shape, x.dtype)
762-
ZERO = ah.zero(x.shape, x.dtype)
763-
domain = ah.inrange(x, -INFINITY, INFINITY)
764-
codomain = ah.inrange(out, ZERO, INFINITY)
765-
# exp maps [-inf, inf] to [0, inf]. Values outside this domain are
766-
# mapped to nan, which is already tested in the special cases.
767-
ah.assert_exactly_equal(domain, codomain)
749+
unary_assert_against_refimpl("exp", x, out, math.exp)
768750

769751

770752
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
771753
def test_expm1(x):
772754
out = xp.expm1(x)
773755
ph.assert_dtype("expm1", x.dtype, out.dtype)
774756
ph.assert_shape("expm1", out.shape, x.shape)
775-
INFINITY = ah.infinity(x.shape, x.dtype)
776-
NEGONE = -ah.one(x.shape, x.dtype)
777-
domain = ah.inrange(x, -INFINITY, INFINITY)
778-
codomain = ah.inrange(out, NEGONE, INFINITY)
779-
# expm1 maps [-inf, inf] to [1, inf]. Values outside this domain are
780-
# mapped to nan, which is already tested in the special cases.
781-
ah.assert_exactly_equal(domain, codomain)
757+
unary_assert_against_refimpl("expm1", x, out, math.expm1)
782758

783759

784760
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -881,39 +857,17 @@ def test_isfinite(x):
881857
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
882858
def test_isinf(x):
883859
out = xp.isinf(x)
884-
885860
ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool)
886861
ph.assert_shape("isinf", out.shape, x.shape)
887-
888-
if dh.is_int_dtype(x.dtype):
889-
ah.assert_exactly_equal(out, ah.false(x.shape))
890-
finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x))
891-
ah.assert_exactly_equal(out, ah.logical_not(finite_or_nan))
892-
893-
# Test the exact value by comparing to the math version
894-
if dh.is_float_dtype(x.dtype):
895-
for idx in sh.ndindex(x.shape):
896-
s = float(x[idx])
897-
assert bool(out[idx]) == math.isinf(s)
862+
unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool)
898863

899864

900865
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
901866
def test_isnan(x):
902867
out = ah.isnan(x)
903-
904868
ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool)
905869
ph.assert_shape("isnan", out.shape, x.shape)
906-
907-
if dh.is_int_dtype(x.dtype):
908-
ah.assert_exactly_equal(out, ah.false(x.shape))
909-
finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x))
910-
ah.assert_exactly_equal(out, ah.logical_not(finite_or_inf))
911-
912-
# Test the exact value by comparing to the math version
913-
if dh.is_float_dtype(x.dtype):
914-
for idx in sh.ndindex(x.shape):
915-
s = float(x[idx])
916-
assert bool(out[idx]) == math.isnan(s)
870+
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
917871

918872

919873
@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes()))
@@ -956,62 +910,56 @@ def test_less_equal(ctx, data):
956910
)
957911

958912

959-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
913+
@given(
914+
xps.arrays(
915+
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
916+
)
917+
)
960918
def test_log(x):
961919
out = xp.log(x)
962-
963920
ph.assert_dtype("log", x.dtype, out.dtype)
964921
ph.assert_shape("log", out.shape, x.shape)
965-
966-
INFINITY = ah.infinity(x.shape, x.dtype)
967-
ZERO = ah.zero(x.shape, x.dtype)
968-
domain = ah.inrange(x, ZERO, INFINITY)
969-
codomain = ah.inrange(out, -INFINITY, INFINITY)
970-
# log maps [0, inf] to [-inf, inf]. Values outside this domain are
971-
# mapped to nan, which is already tested in the special cases.
972-
ah.assert_exactly_equal(domain, codomain)
922+
unary_assert_against_refimpl("log", x, out, math.log)
973923

974924

975-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
925+
@given(
926+
xps.arrays(
927+
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
928+
)
929+
)
976930
def test_log1p(x):
977931
out = xp.log1p(x)
978932
ph.assert_dtype("log1p", x.dtype, out.dtype)
979933
ph.assert_shape("log1p", out.shape, x.shape)
980-
INFINITY = ah.infinity(x.shape, x.dtype)
981-
NEGONE = -ah.one(x.shape, x.dtype)
982-
codomain = ah.inrange(x, NEGONE, INFINITY)
983-
domain = ah.inrange(out, -INFINITY, INFINITY)
984-
# log1p maps [1, inf] to [-inf, inf]. Values outside this domain are
985-
# mapped to nan, which is already tested in the special cases.
986-
ah.assert_exactly_equal(domain, codomain)
934+
unary_assert_against_refimpl("log1p", x, out, math.log1p)
987935

988936

989-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
937+
@given(
938+
xps.arrays(
939+
dtype=xps.floating_dtypes(),
940+
shape=hh.shapes(),
941+
elements={"min_value": 0, "exclude_min": True},
942+
)
943+
)
990944
def test_log2(x):
991945
out = xp.log2(x)
992946
ph.assert_dtype("log2", x.dtype, out.dtype)
993947
ph.assert_shape("log2", out.shape, x.shape)
994-
INFINITY = ah.infinity(x.shape, x.dtype)
995-
ZERO = ah.zero(x.shape, x.dtype)
996-
domain = ah.inrange(x, ZERO, INFINITY)
997-
codomain = ah.inrange(out, -INFINITY, INFINITY)
998-
# log2 maps [0, inf] to [-inf, inf]. Values outside this domain are
999-
# mapped to nan, which is already tested in the special cases.
1000-
ah.assert_exactly_equal(domain, codomain)
948+
unary_assert_against_refimpl("log2", x, out, math.log2)
1001949

1002950

1003-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
951+
@given(
952+
xps.arrays(
953+
dtype=xps.floating_dtypes(),
954+
shape=hh.shapes(),
955+
elements={"min_value": 0, "exclude_min": True},
956+
)
957+
)
1004958
def test_log10(x):
1005959
out = xp.log10(x)
1006960
ph.assert_dtype("log10", x.dtype, out.dtype)
1007961
ph.assert_shape("log10", out.shape, x.shape)
1008-
INFINITY = ah.infinity(x.shape, x.dtype)
1009-
ZERO = ah.zero(x.shape, x.dtype)
1010-
domain = ah.inrange(x, ZERO, INFINITY)
1011-
codomain = ah.inrange(out, -INFINITY, INFINITY)
1012-
# log10 maps [0, inf] to [-inf, inf]. Values outside this domain are
1013-
# mapped to nan, which is already tested in the special cases.
1014-
ah.assert_exactly_equal(domain, codomain)
962+
unary_assert_against_refimpl("log10", x, out, math.log10)
1015963

1016964

1017965
@given(*hh.two_mutual_arrays(dh.float_dtypes))
@@ -1204,7 +1152,7 @@ def test_sign(x):
12041152
out = xp.sign(x)
12051153
ph.assert_dtype("sign", x.dtype, out.dtype)
12061154
ph.assert_shape("sign", out.shape, x.shape)
1207-
scalar_type = dh.get_scalar_type(x.dtype)
1155+
scalar_type = dh.get_scalar_type(out.dtype)
12081156
for idx in sh.ndindex(x.shape):
12091157
scalar_x = scalar_type(x[idx])
12101158
f_x = sh.fmt_idx("x", idx)

0 commit comments

Comments
 (0)