From 9efcad55ede7d3600b2a1ee0845a71c68285f57b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 16 Oct 2024 14:53:05 -0600 Subject: [PATCH 1/5] Don't ignore exceptions in elementwise reference implementations For some reason, "except OverflowError" was changed to "except Exception" in e72184e5. For now I have removed the except entirely, but it's possible we may need to keep the handling for OverflowError. There are several issues with tests that this was masking, which I have not fixed yet. Quite a few tests are not testing the complex implementations correctly because they are using math instead of cmath, for example. See also data-apis/array-api-compat#183. --- .../test_operators_and_elementwise_functions.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4c8333c9..42b3f17f 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -242,10 +242,7 @@ def unary_assert_against_refimpl( scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue - try: - expected = refimpl(scalar_i) - except Exception: - continue + expected = refimpl(scalar_i) if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: @@ -317,10 +314,7 @@ def binary_assert_against_refimpl( scalar_r = in_stype(right[r_idx]) if not (filter_(scalar_l) and filter_(scalar_r)): continue - try: - expected = refimpl(scalar_l, scalar_r) - except Exception: - continue + expected = refimpl(scalar_l, scalar_r) if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: @@ -392,10 +386,7 @@ def right_scalar_assert_against_refimpl( scalar_l = in_stype(left[idx]) if not (filter_(scalar_l) and filter_(right)): continue - try: - expected = refimpl(scalar_l, right) - except Exception: - continue + expected = refimpl(scalar_l, right) if left.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: From aba096e68184068b649fa1132a56c7dd1baef81c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 18 Oct 2024 15:44:50 -0600 Subject: [PATCH 2/5] Fix reference implementations for complex elementwise functions Some of these still have some issues that need to be addressed, like overflows. --- ...est_operators_and_elementwise_functions.py | 138 +++++++++++++----- 1 file changed, 104 insertions(+), 34 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 42b3f17f..2ada813c 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -703,9 +703,9 @@ def test_abs(ctx, data): abs, # type: ignore res_stype=float if x.dtype in dh.complex_dtypes else None, expr_template="abs({})={}", - filter_=lambda s: ( - s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s)) - ), + # filter_=lambda s: ( + # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) + # ), ) @@ -714,8 +714,10 @@ def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 unary_assert_against_refimpl( - "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + "acos", x, out, refimpl, filter_=filter_ ) @@ -724,8 +726,10 @@ def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 unary_assert_against_refimpl( - "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1 + "acosh", x, out, refimpl, filter_=filter_ ) @@ -748,8 +752,10 @@ def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 unary_assert_against_refimpl( - "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1 + "asin", x, out, refimpl, filter_=filter_ ) @@ -758,7 +764,8 @@ def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("asinh", x, out, math.asinh) + refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh + unary_assert_against_refimpl("asinh", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -766,7 +773,8 @@ def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("atan", x, out, math.atan) + refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan + unary_assert_against_refimpl("atan", x, out, refimpl) @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) @@ -774,7 +782,8 @@ def test_atan2(x1, x2): out = xp.atan2(x1, x2) ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) + refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2 + binary_assert_against_refimpl("atan2", x1, x2, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -782,12 +791,14 @@ def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 unary_assert_against_refimpl( "atanh", x, out, - math.atanh, - filter_=lambda s: default_filter(s) and -1 <= s <= 1, + refimpl, + filter_=filter_, ) @@ -1065,7 +1076,8 @@ def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("cos", x, out, math.cos) + refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos + unary_assert_against_refimpl("cos", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1073,7 +1085,8 @@ def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("cosh", x, out, math.cosh) + refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh + unary_assert_against_refimpl("cosh", x, out, refimpl) @pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @@ -1097,7 +1110,7 @@ def test_divide(ctx, data): res, "/", operator.truediv, - filter_=lambda s: math.isfinite(s) and s != 0, + filter_=lambda s: cmath.isfinite(s) and s != 0, ) @@ -1134,7 +1147,8 @@ def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("exp", x, out, math.exp) + refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp + unary_assert_against_refimpl("exp", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1142,7 +1156,23 @@ def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("expm1", x, out, math.expm1) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + # There's no cmath.expm1. Use + # + # exp(x+yi) - 1 + # = exp(x)exp(yi) - 1 + # = exp(x)(cos(y) + sin(y)i) - 1 + # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i + # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i + # + # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of + # significance near y = 0. + re, im = z.real, z.imag + return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) + else: + refimpl = math.expm1 + unary_assert_against_refimpl("expm1", x, out, refimpl) @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) @@ -1150,7 +1180,12 @@ def test_floor(x): out = xp.floor(x) ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + return complex(math.floor(z.real), math.floor(z.imag)) + else: + refimpl = math.floor + unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) @pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @@ -1236,7 +1271,8 @@ def test_isfinite(x): out = xp.isfinite(x) ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) + refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite + unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1244,7 +1280,8 @@ def test_isinf(x): out = xp.isinf(x) ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool) + refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf + unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1252,7 +1289,8 @@ def test_isnan(x): out = xp.isnan(x) ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) + refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan + unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) @pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @@ -1300,8 +1338,10 @@ def test_log(x): out = xp.log(x) ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log", out_shape=out.shape, expected=x.shape) + refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1 + "log", x, out, refimpl, filter_=filter_ ) @@ -1310,8 +1350,19 @@ def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) + # There isn't a cmath.log1p, and implementing one isn't straightforward + # (see + # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). + # For now, just use log(1+p) for complex inputs, which should hopefully be + # fine given the very loose numerical tolerances we use. If it isn't, we + # can try using something like a series expansion for small p. + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(1+z) + else: + refimpl = math.log1p + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 unary_assert_against_refimpl( - "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 + "log1p", x, out, refimpl, filter_=filter_ ) @@ -1320,8 +1371,13 @@ def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(2) + else: + refimpl = math.log2 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 + "log2", x, out, refimpl, filter_=filter_ ) @@ -1330,12 +1386,17 @@ def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(10) + else: + refimpl = math.log10 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 + "log10", x, out, refimpl, filter_=filter_ ) -def logaddexp(l: float, r: float) -> float: +def logaddexp_refimpl(l: float, r: float) -> float: return math.log(math.exp(l) + math.exp(r)) @@ -1344,7 +1405,7 @@ def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) + binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl) @given(*hh.two_mutual_arrays([xp.bool])) @@ -1521,7 +1582,11 @@ def test_round(x): out = xp.round(x) ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("round", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("round", x, out, round, strict_check=True) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: complex(round(z.real), round(z.imag)) + else: + refimpl = round + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) @pytest.mark.min_version("2023.12") @@ -1539,13 +1604,12 @@ def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: x / math.abs(x) if x != 0 else 0 + refimpl = lambda x: x / abs(x) if x != 0 else 0 unary_assert_against_refimpl( "sign", x, out, refimpl, - filter_=lambda s: s != 0, strict_check=True, ) @@ -1555,7 +1619,8 @@ def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("sin", x, out, math.sin) + refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin + unary_assert_against_refimpl("sin", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1563,7 +1628,8 @@ def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("sinh", x, out, math.sinh) + refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh + unary_assert_against_refimpl("sinh", x, out, refimpl) @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) @@ -1581,8 +1647,10 @@ def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 unary_assert_against_refimpl( - "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 + "sqrt", x, out, refimpl, filter_=filter_ ) @@ -1605,7 +1673,8 @@ def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("tan", x, out, math.tan) + refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan + unary_assert_against_refimpl("tan", x, out, refimpl) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1613,7 +1682,8 @@ def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("tanh", x, out, math.tanh) + refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh + unary_assert_against_refimpl("tanh", x, out, refimpl) @given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes())) From 9fce2989e27a53a9386fc263009c285fca834b13 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 18 Oct 2024 15:46:31 -0600 Subject: [PATCH 3/5] Reinstate the guard against OverflowError for elementwise reference implementations Other exceptions should still be raised, but this is a common occurance which just means that the math library raises OverflowError in many instances instead of giving inf or nan. --- .../test_operators_and_elementwise_functions.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 2ada813c..3d451e6c 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -242,7 +242,10 @@ def unary_assert_against_refimpl( scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue - expected = refimpl(scalar_i) + try: + expected = refimpl(scalar_i) + except OverflowError: + continue if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: @@ -314,7 +317,10 @@ def binary_assert_against_refimpl( scalar_r = in_stype(right[r_idx]) if not (filter_(scalar_l) and filter_(scalar_r)): continue - expected = refimpl(scalar_l, scalar_r) + try: + expected = refimpl(scalar_l, scalar_r) + except OverflowError: + continue if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: @@ -386,7 +392,10 @@ def right_scalar_assert_against_refimpl( scalar_l = in_stype(left[idx]) if not (filter_(scalar_l) and filter_(right)): continue - expected = refimpl(scalar_l, right) + try: + expected = refimpl(scalar_l, right) + except OverflowError: + continue if left.dtype != xp.bool: if res.dtype in dh.complex_dtypes: if expected.real <= m or expected.real >= M: From 440094cf9b02d3cb7372a6dc834cebfb5c04840d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 21 Oct 2024 14:25:19 -0600 Subject: [PATCH 4/5] Fix underflow exception in test_logaddexp --- array_api_tests/test_operators_and_elementwise_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 3d451e6c..b9fff5fb 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1406,7 +1406,10 @@ def test_log10(x): def logaddexp_refimpl(l: float, r: float) -> float: - return math.log(math.exp(l) + math.exp(r)) + try: + return math.log(math.exp(l) + math.exp(r)) + except ValueError: # raised for log(0.) + raise OverflowError @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) From abc7ebfec496f447c2f896477683462dc6137ec9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 21 Oct 2024 14:31:46 -0600 Subject: [PATCH 5/5] Fix test_bitwise_left_shift --- array_api_tests/test_operators_and_elementwise_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index b9fff5fb..785d3665 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -846,7 +846,7 @@ def test_bitwise_left_shift(ctx, data): binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) - nbits = res.dtype + nbits = dh.dtype_nbits[res.dtype] binary_param_assert_against_refimpl( ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 )