Skip to content

Commit 9521f6b

Browse files
committed
Values testing for remaining tests for elwise funcs starting with a
1 parent 80d2909 commit 9521f6b

File tree

1 file changed

+50
-93
lines changed

1 file changed

+50
-93
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 50 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,22 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
6060
return n
6161

6262

63+
def default_filter(s: Scalar) -> bool:
64+
"""Returns False when s is a non-finite or a signed zero.
65+
66+
Used by default as these values are typically special-cased.
67+
"""
68+
return math.isfinite(s) and s is not -0.0 and s is not +0.0
69+
70+
6371
def unary_assert_against_refimpl(
6472
func_name: str,
6573
in_: Array,
6674
res: Array,
6775
refimpl: Callable[[Scalar], Scalar],
6876
expr_template: str,
6977
res_stype: Optional[ScalarType] = None,
70-
filter_: Callable[[Scalar], bool] = math.isfinite,
78+
filter_: Callable[[Scalar], bool] = default_filter,
7179
):
7280
if in_.shape != res.shape:
7381
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
@@ -114,7 +122,7 @@ def binary_assert_against_refimpl(
114122
left_sym: str = "x1",
115123
right_sym: str = "x2",
116124
res_name: str = "out",
117-
filter_: Callable[[Scalar], bool] = math.isfinite,
125+
filter_: Callable[[Scalar], bool] = default_filter,
118126
):
119127
in_stype = dh.get_scalar_type(left.dtype)
120128
if res_stype is None:
@@ -353,7 +361,7 @@ def binary_param_assert_against_refimpl(
353361
refimpl: Callable[[Scalar, Scalar], Scalar],
354362
expr_template: str,
355363
res_stype: Optional[ScalarType] = None,
356-
filter_: Callable[[Scalar], bool] = math.isfinite,
364+
filter_: Callable[[Scalar], bool] = default_filter,
357365
):
358366
if ctx.right_is_scalar:
359367
assert filter_(right) # sanity check
@@ -429,36 +437,30 @@ def test_abs(ctx, data):
429437
)
430438

431439

432-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
440+
@given(
441+
xps.arrays(
442+
dtype=xps.floating_dtypes(),
443+
shape=hh.shapes(),
444+
elements={"min_value": -1, "max_value": 1},
445+
)
446+
)
433447
def test_acos(x):
434-
res = xp.acos(x)
435-
ph.assert_dtype("acos", x.dtype, res.dtype)
436-
ph.assert_shape("acos", res.shape, x.shape)
437-
ONE = ah.one(x.shape, x.dtype)
438-
# Here (and elsewhere), should technically be res.dtype, but this is the
439-
# same as x.dtype, as tested by the type_promotion tests.
440-
PI = ah.π(x.shape, x.dtype)
441-
ZERO = ah.zero(x.shape, x.dtype)
442-
domain = ah.inrange(x, -ONE, ONE)
443-
codomain = ah.inrange(res, ZERO, PI)
444-
# acos maps [-1, 1] to [0, pi]. Values outside this domain are mapped to
445-
# nan, which is already tested in the special cases.
446-
ah.assert_exactly_equal(domain, codomain)
448+
out = xp.acos(x)
449+
ph.assert_dtype("acos", x.dtype, out.dtype)
450+
ph.assert_shape("acos", out.shape, x.shape)
451+
unary_assert_against_refimpl("acos", x, out, math.acos, "acos({})={}")
447452

448453

449-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
454+
@given(
455+
xps.arrays(
456+
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
457+
)
458+
)
450459
def test_acosh(x):
451-
res = xp.acosh(x)
452-
ph.assert_dtype("acosh", x.dtype, res.dtype)
453-
ph.assert_shape("acosh", res.shape, x.shape)
454-
ONE = ah.one(x.shape, x.dtype)
455-
INFINITY = ah.infinity(x.shape, x.dtype)
456-
ZERO = ah.zero(x.shape, x.dtype)
457-
domain = ah.inrange(x, ONE, INFINITY)
458-
codomain = ah.inrange(res, ZERO, INFINITY)
459-
# acosh maps [-1, inf] to [0, inf]. Values outside this domain are mapped
460-
# to nan, which is already tested in the special cases.
461-
ah.assert_exactly_equal(domain, codomain)
460+
out = xp.acosh(x)
461+
ph.assert_dtype("acosh", x.dtype, out.dtype)
462+
ph.assert_shape("acosh", out.shape, x.shape)
463+
unary_assert_against_refimpl("acosh", x, out, math.acosh, "acosh({})={}")
462464

463465

464466
@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
@@ -479,101 +481,56 @@ def test_add(ctx, data):
479481
)
480482

481483

482-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
484+
@given(
485+
xps.arrays(
486+
dtype=xps.floating_dtypes(),
487+
shape=hh.shapes(),
488+
elements={"min_value": -1, "max_value": 1},
489+
)
490+
)
483491
def test_asin(x):
484492
out = xp.asin(x)
485493
ph.assert_dtype("asin", x.dtype, out.dtype)
486494
ph.assert_shape("asin", out.shape, x.shape)
487-
ONE = ah.one(x.shape, x.dtype)
488-
PI = ah.π(x.shape, x.dtype)
489-
domain = ah.inrange(x, -ONE, ONE)
490-
codomain = ah.inrange(out, -PI / 2, PI / 2)
491-
# asin maps [-1, 1] to [-pi/2, pi/2]. Values outside this domain are
492-
# mapped to nan, which is already tested in the special cases.
493-
ah.assert_exactly_equal(domain, codomain)
495+
unary_assert_against_refimpl("asin", x, out, math.asin, "asin({})={}")
494496

495497

496498
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
497499
def test_asinh(x):
498500
out = xp.asinh(x)
499501
ph.assert_dtype("asinh", x.dtype, out.dtype)
500502
ph.assert_shape("asinh", out.shape, x.shape)
501-
INFINITY = ah.infinity(x.shape, x.dtype)
502-
domain = ah.inrange(x, -INFINITY, INFINITY)
503-
codomain = ah.inrange(out, -INFINITY, INFINITY)
504-
# asinh maps [-inf, inf] to [-inf, inf]. Values outside this domain are
505-
# mapped to nan, which is already tested in the special cases.
506-
ah.assert_exactly_equal(domain, codomain)
503+
unary_assert_against_refimpl("asinh", x, out, math.asinh, "asinh({})={}")
507504

508505

509506
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
510507
def test_atan(x):
511508
out = xp.atan(x)
512509
ph.assert_dtype("atan", x.dtype, out.dtype)
513510
ph.assert_shape("atan", out.shape, x.shape)
514-
INFINITY = ah.infinity(x.shape, x.dtype)
515-
PI = ah.π(x.shape, x.dtype)
516-
domain = ah.inrange(x, -INFINITY, INFINITY)
517-
codomain = ah.inrange(out, -PI / 2, PI / 2)
518-
# atan maps [-inf, inf] to [-pi/2, pi/2]. Values outside this domain are
519-
# mapped to nan, which is already tested in the special cases.
520-
ah.assert_exactly_equal(domain, codomain)
511+
unary_assert_against_refimpl("atan", x, out, math.atan, "atan({})={}")
521512

522513

523514
@given(*hh.two_mutual_arrays(dh.float_dtypes))
524515
def test_atan2(x1, x2):
525516
out = xp.atan2(x1, x2)
526517
ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype)
527518
ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape)
528-
INFINITY1 = ah.infinity(x1.shape, x1.dtype)
529-
INFINITY2 = ah.infinity(x2.shape, x2.dtype)
530-
PI = ah.π(out.shape, out.dtype)
531-
domainx1 = ah.inrange(x1, -INFINITY1, INFINITY1)
532-
domainx2 = ah.inrange(x2, -INFINITY2, INFINITY2)
533-
# codomain = ah.inrange(out, -PI, PI, 1e-5)
534-
codomain = ah.inrange(out, -PI, PI)
535-
# atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside
536-
# this domain are mapped to nan, which is already tested in the special
537-
# cases.
538-
ah.assert_exactly_equal(ah.logical_and(domainx1, domainx2), codomain)
539-
# From the spec:
540-
#
541-
# The mathematical signs of `x1_i` and `x2_i` determine the quadrant of
542-
# each element-wise out. The quadrant (i.e., branch) is chosen such
543-
# that each element-wise out is the signed angle in radians between the
544-
# ray ending at the origin and passing through the point `(1,0)` and the
545-
# ray ending at the origin and passing through the point `(x2_i, x1_i)`.
546-
547-
# This is equivalent to atan2(x1, x2) has the same sign as x1 when x2 is
548-
# finite.
549-
pos_x1 = ah.positive_mathematical_sign(x1)
550-
neg_x1 = ah.negative_mathematical_sign(x1)
551-
pos_x2 = ah.positive_mathematical_sign(x2)
552-
neg_x2 = ah.negative_mathematical_sign(x2)
553-
pos_out = ah.positive_mathematical_sign(out)
554-
neg_out = ah.negative_mathematical_sign(out)
555-
ah.assert_exactly_equal(
556-
ah.logical_or(ah.logical_and(pos_x1, pos_x2), ah.logical_and(pos_x1, neg_x2)),
557-
pos_out,
558-
)
559-
ah.assert_exactly_equal(
560-
ah.logical_or(ah.logical_and(neg_x1, pos_x2), ah.logical_and(neg_x1, neg_x2)),
561-
neg_out,
562-
)
519+
binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2, "atan2({})={}")
563520

564521

565-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
522+
@given(
523+
xps.arrays(
524+
dtype=xps.floating_dtypes(),
525+
shape=hh.shapes(),
526+
elements={"min_value": -1, "max_value": 1},
527+
)
528+
)
566529
def test_atanh(x):
567530
out = xp.atanh(x)
568531
ph.assert_dtype("atanh", x.dtype, out.dtype)
569532
ph.assert_shape("atanh", out.shape, x.shape)
570-
ONE = ah.one(x.shape, x.dtype)
571-
INFINITY = ah.infinity(x.shape, x.dtype)
572-
domain = ah.inrange(x, -ONE, ONE)
573-
codomain = ah.inrange(out, -INFINITY, INFINITY)
574-
# atanh maps [-1, 1] to [-inf, inf]. Values outside this domain are
575-
# mapped to nan, which is already tested in the special cases.
576-
ah.assert_exactly_equal(domain, codomain)
533+
unary_assert_against_refimpl("atanh", x, out, math.atanh, "atanh({})={}")
577534

578535

579536
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)