From ca4de5152467406b849897ffc17c8a6faf61b0dd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 13:45:51 -0600 Subject: [PATCH 1/2] Add a sanity check that signbit works The array-api-strict package disables signbit at runtime when the API version is set to 2022.12. This happens when the function is called, so the hasattr check would pass even though the function call fails. --- array_api_tests/pytest_helpers.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 0e1b4c8b..496f0324 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -421,6 +421,15 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _has_signbit() -> bool: + if not hasattr(_xp, "signbit"): + return False + try: + assert _xp.all(_xp.signbit(_xp.asarray(0.0)) == False) + except: + return False + return True + def _real_float_strict_equals(out: Array, expected: Array) -> bool: nan_mask = xp.isnan(out) if not xp.all(nan_mask == xp.isnan(expected)): @@ -429,7 +438,7 @@ def _real_float_strict_equals(out: Array, expected: Array) -> bool: # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's # not that big of a deal for the perf costs. - if hasattr(_xp, "signbit"): + if _has_signbit(): out_zero_mask = out == 0 out_sign_mask = _xp.signbit(out) out_pos_zero_mask = out_zero_mask & out_sign_mask From 21524f7c9ef5a1d751e46543f7cccb0a11740645 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:16:35 -0600 Subject: [PATCH 2/2] Apply suggestions from code review --- array_api_tests/pytest_helpers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 496f0324..9759822e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -421,7 +421,8 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg -def _has_signbit() -> bool: +def _has_functional_signbit() -> bool: + # signbit can be available but not implemented (e.g., in array-api-strict) if not hasattr(_xp, "signbit"): return False try: @@ -438,7 +439,7 @@ def _real_float_strict_equals(out: Array, expected: Array) -> bool: # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's # not that big of a deal for the perf costs. - if _has_signbit(): + if _has_functional_signbit(): out_zero_mask = out == 0 out_sign_mask = _xp.signbit(out) out_pos_zero_mask = out_zero_mask & out_sign_mask