@@ -49,7 +49,7 @@ def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bo
49
49
50
50
51
51
def mock_int_dtype (n : int , dtype : DataType ) -> int :
52
- """Returns equivalent of `n` that mocks `dtype` behaviour"""
52
+ """Returns equivalent of `n` that mocks `dtype` behaviour. """
53
53
nbits = dh .dtype_nbits [dtype ]
54
54
mask = (1 << nbits ) - 1
55
55
n &= mask
@@ -76,6 +76,7 @@ def unary_assert_against_refimpl(
76
76
expr_template : Optional [str ] = None ,
77
77
res_stype : Optional [ScalarType ] = None ,
78
78
filter_ : Callable [[Scalar ], bool ] = default_filter ,
79
+ strict_check : bool = False ,
79
80
):
80
81
if in_ .shape != res .shape :
81
82
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -101,7 +102,7 @@ def unary_assert_against_refimpl(
101
102
f_i = sh .fmt_idx ("x" , idx )
102
103
f_o = sh .fmt_idx ("out" , idx )
103
104
expr = expr_template .format (f_i , expected )
104
- if dh .is_float_dtype (res .dtype ):
105
+ if not strict_check and dh .is_float_dtype (res .dtype ):
105
106
assert isclose (scalar_o , expected ), (
106
107
f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
107
108
f"{ f_i } ={ scalar_i } "
@@ -125,6 +126,7 @@ def binary_assert_against_refimpl(
125
126
right_sym : str = "x2" ,
126
127
res_name : str = "out" ,
127
128
filter_ : Callable [[Scalar ], bool ] = default_filter ,
129
+ strict_check : bool = False ,
128
130
):
129
131
if expr_template is None :
130
132
expr_template = func_name + "({}, {})={}"
@@ -150,7 +152,7 @@ def binary_assert_against_refimpl(
150
152
f_r = sh .fmt_idx (right_sym , r_idx )
151
153
f_o = sh .fmt_idx (res_name , o_idx )
152
154
expr = expr_template .format (f_l , f_r , expected )
153
- if dh .is_float_dtype (res .dtype ):
155
+ if not strict_check and dh .is_float_dtype (res .dtype ):
154
156
assert isclose (scalar_o , expected ), (
155
157
f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
156
158
f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -366,6 +368,7 @@ def binary_param_assert_against_refimpl(
366
368
refimpl : Callable [[Scalar , Scalar ], Scalar ],
367
369
res_stype : Optional [ScalarType ] = None ,
368
370
filter_ : Callable [[Scalar ], bool ] = default_filter ,
371
+ strict_check : bool = False ,
369
372
):
370
373
expr_template = "({} " + op_sym + " {})={}"
371
374
if ctx .right_is_scalar :
@@ -390,7 +393,7 @@ def binary_param_assert_against_refimpl(
390
393
f_l = sh .fmt_idx (ctx .left_sym , idx )
391
394
f_o = sh .fmt_idx (ctx .res_name , idx )
392
395
expr = expr_template .format (f_l , right , expected )
393
- if dh .is_float_dtype (left .dtype ):
396
+ if not strict_check and dh .is_float_dtype (left .dtype ):
394
397
assert isclose (scalar_o , expected ), (
395
398
f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
396
399
f"[{ ctx .func_name } ()]\n "
@@ -415,6 +418,7 @@ def binary_param_assert_against_refimpl(
415
418
refimpl = refimpl ,
416
419
expr_template = expr_template ,
417
420
filter_ = filter_ ,
421
+ strict_check = strict_check ,
418
422
)
419
423
420
424
@@ -670,14 +674,7 @@ def test_ceil(x):
670
674
out = xp .ceil (x )
671
675
ph .assert_dtype ("ceil" , x .dtype , out .dtype )
672
676
ph .assert_shape ("ceil" , out .shape , x .shape )
673
- finite = ah .isfinite (x )
674
- ah .assert_integral (out [finite ])
675
- assert ah .all (ah .less_equal (x [finite ], out [finite ]))
676
- assert ah .all (
677
- ah .less_equal (out [finite ] - x [finite ], ah .one (x [finite ].shape , x .dtype ))
678
- )
679
- integers = ah .isintegral (x )
680
- ah .assert_exactly_equal (out [integers ], x [integers ])
677
+ unary_assert_against_refimpl ("ceil" , x , out , math .ceil , strict_check = True )
681
678
682
679
683
680
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -759,18 +756,10 @@ def test_expm1(x):
759
756
760
757
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
761
758
def test_floor (x ):
762
- # This test is almost identical to test_ceil
763
759
out = xp .floor (x )
764
760
ph .assert_dtype ("floor" , x .dtype , out .dtype )
765
761
ph .assert_shape ("floor" , out .shape , x .shape )
766
- finite = ah .isfinite (x )
767
- ah .assert_integral (out [finite ])
768
- assert ah .all (ah .less_equal (out [finite ], x [finite ]))
769
- assert ah .all (
770
- ah .less_equal (x [finite ] - out [finite ], ah .one (x [finite ].shape , x .dtype ))
771
- )
772
- integers = ah .isintegral (x )
773
- ah .assert_exactly_equal (out [integers ], x [integers ])
762
+ unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
774
763
775
764
776
765
@pytest .mark .parametrize (
@@ -1122,29 +1111,9 @@ def test_remainder(ctx, data):
1122
1111
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
1123
1112
def test_round (x ):
1124
1113
out = xp .round (x )
1125
-
1126
1114
ph .assert_dtype ("round" , x .dtype , out .dtype )
1127
-
1128
1115
ph .assert_shape ("round" , out .shape , x .shape )
1129
-
1130
- # Test that the out is integral
1131
- finite = ah .isfinite (x )
1132
- ah .assert_integral (out [finite ])
1133
-
1134
- # round(x) should be the neaoutt integer to x. The case where there is a
1135
- # tie (round to even) is already handled by the special cases tests.
1136
-
1137
- # This is the same strategy used in the mask in the
1138
- # test_round_special_cases_one_arg_two_integers_equally_close special
1139
- # cases test.
1140
- floor = xp .floor (x )
1141
- ceil = xp .ceil (x )
1142
- over = xp .subtract (x , floor )
1143
- under = xp .subtract (ceil , x )
1144
- round_down = ah .less (over , under )
1145
- round_up = ah .less (under , over )
1146
- ah .assert_exactly_equal (out [round_down ], floor [round_down ])
1147
- ah .assert_exactly_equal (out [round_up ], ceil [round_up ])
1116
+ unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
1148
1117
1149
1118
1150
1119
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -1246,8 +1215,4 @@ def test_trunc(x):
1246
1215
out = xp .trunc (x )
1247
1216
ph .assert_dtype ("trunc" , x .dtype , out .dtype )
1248
1217
ph .assert_shape ("trunc" , out .shape , x .shape )
1249
- if dh .is_int_dtype (x .dtype ):
1250
- ah .assert_exactly_equal (out , x )
1251
- else :
1252
- finite = ah .isfinite (x )
1253
- ah .assert_integral (out [finite ])
1218
+ unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
0 commit comments