@@ -60,14 +60,22 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
60
60
return n
61
61
62
62
63
+ def default_filter (s : Scalar ) -> bool :
64
+ """Returns False when s is a non-finite or a signed zero.
65
+
66
+ Used by default as these values are typically special-cased.
67
+ """
68
+ return math .isfinite (s ) and s is not - 0.0 and s is not + 0.0
69
+
70
+
63
71
def unary_assert_against_refimpl (
64
72
func_name : str ,
65
73
in_ : Array ,
66
74
res : Array ,
67
75
refimpl : Callable [[Scalar ], Scalar ],
68
76
expr_template : str ,
69
77
res_stype : Optional [ScalarType ] = None ,
70
- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
78
+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
71
79
):
72
80
if in_ .shape != res .shape :
73
81
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -114,7 +122,7 @@ def binary_assert_against_refimpl(
114
122
left_sym : str = "x1" ,
115
123
right_sym : str = "x2" ,
116
124
res_name : str = "out" ,
117
- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
125
+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
118
126
):
119
127
in_stype = dh .get_scalar_type (left .dtype )
120
128
if res_stype is None :
@@ -353,7 +361,7 @@ def binary_param_assert_against_refimpl(
353
361
refimpl : Callable [[Scalar , Scalar ], Scalar ],
354
362
expr_template : str ,
355
363
res_stype : Optional [ScalarType ] = None ,
356
- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
364
+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
357
365
):
358
366
if ctx .right_is_scalar :
359
367
assert filter_ (right ) # sanity check
@@ -429,36 +437,30 @@ def test_abs(ctx, data):
429
437
)
430
438
431
439
432
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
440
+ @given (
441
+ xps .arrays (
442
+ dtype = xps .floating_dtypes (),
443
+ shape = hh .shapes (),
444
+ elements = {"min_value" : - 1 , "max_value" : 1 },
445
+ )
446
+ )
433
447
def test_acos (x ):
434
- res = xp .acos (x )
435
- ph .assert_dtype ("acos" , x .dtype , res .dtype )
436
- ph .assert_shape ("acos" , res .shape , x .shape )
437
- ONE = ah .one (x .shape , x .dtype )
438
- # Here (and elsewhere), should technically be res.dtype, but this is the
439
- # same as x.dtype, as tested by the type_promotion tests.
440
- PI = ah .π (x .shape , x .dtype )
441
- ZERO = ah .zero (x .shape , x .dtype )
442
- domain = ah .inrange (x , - ONE , ONE )
443
- codomain = ah .inrange (res , ZERO , PI )
444
- # acos maps [-1, 1] to [0, pi]. Values outside this domain are mapped to
445
- # nan, which is already tested in the special cases.
446
- ah .assert_exactly_equal (domain , codomain )
448
+ out = xp .acos (x )
449
+ ph .assert_dtype ("acos" , x .dtype , out .dtype )
450
+ ph .assert_shape ("acos" , out .shape , x .shape )
451
+ unary_assert_against_refimpl ("acos" , x , out , math .acos , "acos({})={}" )
447
452
448
453
449
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
454
+ @given (
455
+ xps .arrays (
456
+ dtype = xps .floating_dtypes (), shape = hh .shapes (), elements = {"min_value" : 1 }
457
+ )
458
+ )
450
459
def test_acosh (x ):
451
- res = xp .acosh (x )
452
- ph .assert_dtype ("acosh" , x .dtype , res .dtype )
453
- ph .assert_shape ("acosh" , res .shape , x .shape )
454
- ONE = ah .one (x .shape , x .dtype )
455
- INFINITY = ah .infinity (x .shape , x .dtype )
456
- ZERO = ah .zero (x .shape , x .dtype )
457
- domain = ah .inrange (x , ONE , INFINITY )
458
- codomain = ah .inrange (res , ZERO , INFINITY )
459
- # acosh maps [-1, inf] to [0, inf]. Values outside this domain are mapped
460
- # to nan, which is already tested in the special cases.
461
- ah .assert_exactly_equal (domain , codomain )
460
+ out = xp .acosh (x )
461
+ ph .assert_dtype ("acosh" , x .dtype , out .dtype )
462
+ ph .assert_shape ("acosh" , out .shape , x .shape )
463
+ unary_assert_against_refimpl ("acosh" , x , out , math .acosh , "acosh({})={}" )
462
464
463
465
464
466
@pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes ()))
@@ -479,101 +481,56 @@ def test_add(ctx, data):
479
481
)
480
482
481
483
482
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
484
+ @given (
485
+ xps .arrays (
486
+ dtype = xps .floating_dtypes (),
487
+ shape = hh .shapes (),
488
+ elements = {"min_value" : - 1 , "max_value" : 1 },
489
+ )
490
+ )
483
491
def test_asin (x ):
484
492
out = xp .asin (x )
485
493
ph .assert_dtype ("asin" , x .dtype , out .dtype )
486
494
ph .assert_shape ("asin" , out .shape , x .shape )
487
- ONE = ah .one (x .shape , x .dtype )
488
- PI = ah .π (x .shape , x .dtype )
489
- domain = ah .inrange (x , - ONE , ONE )
490
- codomain = ah .inrange (out , - PI / 2 , PI / 2 )
491
- # asin maps [-1, 1] to [-pi/2, pi/2]. Values outside this domain are
492
- # mapped to nan, which is already tested in the special cases.
493
- ah .assert_exactly_equal (domain , codomain )
495
+ unary_assert_against_refimpl ("asin" , x , out , math .asin , "asin({})={}" )
494
496
495
497
496
498
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
497
499
def test_asinh (x ):
498
500
out = xp .asinh (x )
499
501
ph .assert_dtype ("asinh" , x .dtype , out .dtype )
500
502
ph .assert_shape ("asinh" , out .shape , x .shape )
501
- INFINITY = ah .infinity (x .shape , x .dtype )
502
- domain = ah .inrange (x , - INFINITY , INFINITY )
503
- codomain = ah .inrange (out , - INFINITY , INFINITY )
504
- # asinh maps [-inf, inf] to [-inf, inf]. Values outside this domain are
505
- # mapped to nan, which is already tested in the special cases.
506
- ah .assert_exactly_equal (domain , codomain )
503
+ unary_assert_against_refimpl ("asinh" , x , out , math .asinh , "asinh({})={}" )
507
504
508
505
509
506
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
510
507
def test_atan (x ):
511
508
out = xp .atan (x )
512
509
ph .assert_dtype ("atan" , x .dtype , out .dtype )
513
510
ph .assert_shape ("atan" , out .shape , x .shape )
514
- INFINITY = ah .infinity (x .shape , x .dtype )
515
- PI = ah .π (x .shape , x .dtype )
516
- domain = ah .inrange (x , - INFINITY , INFINITY )
517
- codomain = ah .inrange (out , - PI / 2 , PI / 2 )
518
- # atan maps [-inf, inf] to [-pi/2, pi/2]. Values outside this domain are
519
- # mapped to nan, which is already tested in the special cases.
520
- ah .assert_exactly_equal (domain , codomain )
511
+ unary_assert_against_refimpl ("atan" , x , out , math .atan , "atan({})={}" )
521
512
522
513
523
514
@given (* hh .two_mutual_arrays (dh .float_dtypes ))
524
515
def test_atan2 (x1 , x2 ):
525
516
out = xp .atan2 (x1 , x2 )
526
517
ph .assert_dtype ("atan2" , [x1 .dtype , x2 .dtype ], out .dtype )
527
518
ph .assert_result_shape ("atan2" , [x1 .shape , x2 .shape ], out .shape )
528
- INFINITY1 = ah .infinity (x1 .shape , x1 .dtype )
529
- INFINITY2 = ah .infinity (x2 .shape , x2 .dtype )
530
- PI = ah .π (out .shape , out .dtype )
531
- domainx1 = ah .inrange (x1 , - INFINITY1 , INFINITY1 )
532
- domainx2 = ah .inrange (x2 , - INFINITY2 , INFINITY2 )
533
- # codomain = ah.inrange(out, -PI, PI, 1e-5)
534
- codomain = ah .inrange (out , - PI , PI )
535
- # atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside
536
- # this domain are mapped to nan, which is already tested in the special
537
- # cases.
538
- ah .assert_exactly_equal (ah .logical_and (domainx1 , domainx2 ), codomain )
539
- # From the spec:
540
- #
541
- # The mathematical signs of `x1_i` and `x2_i` determine the quadrant of
542
- # each element-wise out. The quadrant (i.e., branch) is chosen such
543
- # that each element-wise out is the signed angle in radians between the
544
- # ray ending at the origin and passing through the point `(1,0)` and the
545
- # ray ending at the origin and passing through the point `(x2_i, x1_i)`.
546
-
547
- # This is equivalent to atan2(x1, x2) has the same sign as x1 when x2 is
548
- # finite.
549
- pos_x1 = ah .positive_mathematical_sign (x1 )
550
- neg_x1 = ah .negative_mathematical_sign (x1 )
551
- pos_x2 = ah .positive_mathematical_sign (x2 )
552
- neg_x2 = ah .negative_mathematical_sign (x2 )
553
- pos_out = ah .positive_mathematical_sign (out )
554
- neg_out = ah .negative_mathematical_sign (out )
555
- ah .assert_exactly_equal (
556
- ah .logical_or (ah .logical_and (pos_x1 , pos_x2 ), ah .logical_and (pos_x1 , neg_x2 )),
557
- pos_out ,
558
- )
559
- ah .assert_exactly_equal (
560
- ah .logical_or (ah .logical_and (neg_x1 , pos_x2 ), ah .logical_and (neg_x1 , neg_x2 )),
561
- neg_out ,
562
- )
519
+ binary_assert_against_refimpl ("atan2" , x1 , x2 , out , math .atan2 , "atan2({})={}" )
563
520
564
521
565
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
522
+ @given (
523
+ xps .arrays (
524
+ dtype = xps .floating_dtypes (),
525
+ shape = hh .shapes (),
526
+ elements = {"min_value" : - 1 , "max_value" : 1 },
527
+ )
528
+ )
566
529
def test_atanh (x ):
567
530
out = xp .atanh (x )
568
531
ph .assert_dtype ("atanh" , x .dtype , out .dtype )
569
532
ph .assert_shape ("atanh" , out .shape , x .shape )
570
- ONE = ah .one (x .shape , x .dtype )
571
- INFINITY = ah .infinity (x .shape , x .dtype )
572
- domain = ah .inrange (x , - ONE , ONE )
573
- codomain = ah .inrange (out , - INFINITY , INFINITY )
574
- # atanh maps [-1, 1] to [-inf, inf]. Values outside this domain are
575
- # mapped to nan, which is already tested in the special cases.
576
- ah .assert_exactly_equal (domain , codomain )
533
+ unary_assert_against_refimpl ("atanh" , x , out , math .atanh , "atanh({})={}" )
577
534
578
535
579
536
@pytest .mark .parametrize (
0 commit comments