Skip to content

Commit 56aa06d

Browse files
committed
strict_check kwarg for refiml utils for testing integrals
1 parent 4a364a5 commit 56aa06d

File tree

1 file changed

+12
-47
lines changed

1 file changed

+12
-47
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bo
4949

5050

5151
def mock_int_dtype(n: int, dtype: DataType) -> int:
52-
"""Returns equivalent of `n` that mocks `dtype` behaviour"""
52+
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
5353
nbits = dh.dtype_nbits[dtype]
5454
mask = (1 << nbits) - 1
5555
n &= mask
@@ -76,6 +76,7 @@ def unary_assert_against_refimpl(
7676
expr_template: Optional[str] = None,
7777
res_stype: Optional[ScalarType] = None,
7878
filter_: Callable[[Scalar], bool] = default_filter,
79+
strict_check: bool = False,
7980
):
8081
if in_.shape != res.shape:
8182
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
@@ -101,7 +102,7 @@ def unary_assert_against_refimpl(
101102
f_i = sh.fmt_idx("x", idx)
102103
f_o = sh.fmt_idx("out", idx)
103104
expr = expr_template.format(f_i, expected)
104-
if dh.is_float_dtype(res.dtype):
105+
if not strict_check and dh.is_float_dtype(res.dtype):
105106
assert isclose(scalar_o, expected), (
106107
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
107108
f"{f_i}={scalar_i}"
@@ -125,6 +126,7 @@ def binary_assert_against_refimpl(
125126
right_sym: str = "x2",
126127
res_name: str = "out",
127128
filter_: Callable[[Scalar], bool] = default_filter,
129+
strict_check: bool = False,
128130
):
129131
if expr_template is None:
130132
expr_template = func_name + "({}, {})={}"
@@ -150,7 +152,7 @@ def binary_assert_against_refimpl(
150152
f_r = sh.fmt_idx(right_sym, r_idx)
151153
f_o = sh.fmt_idx(res_name, o_idx)
152154
expr = expr_template.format(f_l, f_r, expected)
153-
if dh.is_float_dtype(res.dtype):
155+
if not strict_check and dh.is_float_dtype(res.dtype):
154156
assert isclose(scalar_o, expected), (
155157
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
156158
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
@@ -366,6 +368,7 @@ def binary_param_assert_against_refimpl(
366368
refimpl: Callable[[Scalar, Scalar], Scalar],
367369
res_stype: Optional[ScalarType] = None,
368370
filter_: Callable[[Scalar], bool] = default_filter,
371+
strict_check: bool = False,
369372
):
370373
expr_template = "({} " + op_sym + " {})={}"
371374
if ctx.right_is_scalar:
@@ -390,7 +393,7 @@ def binary_param_assert_against_refimpl(
390393
f_l = sh.fmt_idx(ctx.left_sym, idx)
391394
f_o = sh.fmt_idx(ctx.res_name, idx)
392395
expr = expr_template.format(f_l, right, expected)
393-
if dh.is_float_dtype(left.dtype):
396+
if not strict_check and dh.is_float_dtype(left.dtype):
394397
assert isclose(scalar_o, expected), (
395398
f"{f_o}={scalar_o}, but should be roughly {expr} "
396399
f"[{ctx.func_name}()]\n"
@@ -415,6 +418,7 @@ def binary_param_assert_against_refimpl(
415418
refimpl=refimpl,
416419
expr_template=expr_template,
417420
filter_=filter_,
421+
strict_check=strict_check,
418422
)
419423

420424

@@ -670,14 +674,7 @@ def test_ceil(x):
670674
out = xp.ceil(x)
671675
ph.assert_dtype("ceil", x.dtype, out.dtype)
672676
ph.assert_shape("ceil", out.shape, x.shape)
673-
finite = ah.isfinite(x)
674-
ah.assert_integral(out[finite])
675-
assert ah.all(ah.less_equal(x[finite], out[finite]))
676-
assert ah.all(
677-
ah.less_equal(out[finite] - x[finite], ah.one(x[finite].shape, x.dtype))
678-
)
679-
integers = ah.isintegral(x)
680-
ah.assert_exactly_equal(out[integers], x[integers])
677+
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
681678

682679

683680
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -759,18 +756,10 @@ def test_expm1(x):
759756

760757
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
761758
def test_floor(x):
762-
# This test is almost identical to test_ceil
763759
out = xp.floor(x)
764760
ph.assert_dtype("floor", x.dtype, out.dtype)
765761
ph.assert_shape("floor", out.shape, x.shape)
766-
finite = ah.isfinite(x)
767-
ah.assert_integral(out[finite])
768-
assert ah.all(ah.less_equal(out[finite], x[finite]))
769-
assert ah.all(
770-
ah.less_equal(x[finite] - out[finite], ah.one(x[finite].shape, x.dtype))
771-
)
772-
integers = ah.isintegral(x)
773-
ah.assert_exactly_equal(out[integers], x[integers])
762+
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)
774763

775764

776765
@pytest.mark.parametrize(
@@ -1122,29 +1111,9 @@ def test_remainder(ctx, data):
11221111
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
11231112
def test_round(x):
11241113
out = xp.round(x)
1125-
11261114
ph.assert_dtype("round", x.dtype, out.dtype)
1127-
11281115
ph.assert_shape("round", out.shape, x.shape)
1129-
1130-
# Test that the out is integral
1131-
finite = ah.isfinite(x)
1132-
ah.assert_integral(out[finite])
1133-
1134-
# round(x) should be the neaoutt integer to x. The case where there is a
1135-
# tie (round to even) is already handled by the special cases tests.
1136-
1137-
# This is the same strategy used in the mask in the
1138-
# test_round_special_cases_one_arg_two_integers_equally_close special
1139-
# cases test.
1140-
floor = xp.floor(x)
1141-
ceil = xp.ceil(x)
1142-
over = xp.subtract(x, floor)
1143-
under = xp.subtract(ceil, x)
1144-
round_down = ah.less(over, under)
1145-
round_up = ah.less(under, over)
1146-
ah.assert_exactly_equal(out[round_down], floor[round_down])
1147-
ah.assert_exactly_equal(out[round_up], ceil[round_up])
1116+
unary_assert_against_refimpl("round", x, out, round, strict_check=True)
11481117

11491118

11501119
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -1246,8 +1215,4 @@ def test_trunc(x):
12461215
out = xp.trunc(x)
12471216
ph.assert_dtype("trunc", x.dtype, out.dtype)
12481217
ph.assert_shape("trunc", out.shape, x.shape)
1249-
if dh.is_int_dtype(x.dtype):
1250-
ah.assert_exactly_equal(out, x)
1251-
else:
1252-
finite = ah.isfinite(x)
1253-
ah.assert_integral(out[finite])
1218+
unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True)

0 commit comments

Comments
 (0)