Skip to content

Commit 0e5249b

Browse files
authored
GH-46395: [C++][Statistics] Use EqualOptions for min and max in arrow::ArrayStatistics::Equals() (#46422)
### Rationale for this change `arrow::ArrayStatistics::Equals` does not handle double values for `ArrayStatistics::ValueType` correctly. ### What changes are included in this PR? Add `arrow::EqualOptions` to `arrow::ArrayStatistics::Equals()` Add `arrow::ArrayStatisticsEquals()` Add `EqualOptions::use_atol_` Add `EqualOptions::use_atol()` Add `EqualOptions::use_atol(bool v)` ### Are these changes tested? Yes, I ran the relevant unit tests. ### Are there any user-facing changes? Yes. Add `arrow::ArrayStatisticsEquals()` Add `EqualOptions::use_atol()` Add `EqualOptions::use_atol(bool v)` * GitHub Issue: #46395 Authored-by: Arash Andishgar <arashandishgar1@gmail.com> Signed-off-by: Sutou Kouhei <kou@clear-code.com>
1 parent 1d169cc commit 0e5249b

File tree

4 files changed

+148
-10
lines changed

4 files changed

+148
-10
lines changed

cpp/src/arrow/array/statistics.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <string>
2323
#include <variant>
2424

25+
#include "arrow/compare.h"
2526
#include "arrow/type.h"
2627
#include "arrow/util/visibility.h"
2728

@@ -127,11 +128,17 @@ struct ARROW_EXPORT ArrayStatistics {
127128
/// \brief Whether the maximum value is exact or not
128129
bool is_max_exact = false;
129130

130-
/// \brief Check two statistics for equality
131-
bool Equals(const ArrayStatistics& other) const {
132-
return null_count == other.null_count && distinct_count == other.distinct_count &&
133-
min == other.min && is_min_exact == other.is_min_exact && max == other.max &&
134-
is_max_exact == other.is_max_exact;
131+
/// \brief Check two \ref arrow::ArrayStatistics for equality
132+
///
133+
/// \param other The \ref arrow::ArrayStatistics instance to compare against.
134+
///
135+
/// \param equal_options Options used to compare double values for equality.
136+
///
137+
/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise,
138+
/// false.
139+
bool Equals(const ArrayStatistics& other,
140+
const EqualOptions& equal_options = EqualOptions::Defaults()) const {
141+
return ArrayStatisticsEquals(*this, other, equal_options);
135142
}
136143

137144
/// \brief Check two statistics for equality

cpp/src/arrow/array/statistics_test.cc

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,33 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#include <limits>
19+
#include <variant>
20+
1821
#include <gtest/gtest.h>
1922

2023
#include "arrow/array/statistics.h"
24+
#include "arrow/compare.h"
2125

2226
namespace arrow {
2327

24-
TEST(ArrayStatisticsTest, TestNullCount) {
28+
TEST(TestArrayStatistics, NullCount) {
2529
ArrayStatistics statistics;
2630
ASSERT_FALSE(statistics.null_count.has_value());
2731
statistics.null_count = 29;
2832
ASSERT_TRUE(statistics.null_count.has_value());
2933
ASSERT_EQ(29, statistics.null_count.value());
3034
}
3135

32-
TEST(ArrayStatisticsTest, TestDistinctCount) {
36+
TEST(TestArrayStatistics, DistinctCount) {
3337
ArrayStatistics statistics;
3438
ASSERT_FALSE(statistics.distinct_count.has_value());
3539
statistics.distinct_count = 29;
3640
ASSERT_TRUE(statistics.distinct_count.has_value());
3741
ASSERT_EQ(29, statistics.distinct_count.value());
3842
}
3943

40-
TEST(ArrayStatisticsTest, TestMin) {
44+
TEST(TestArrayStatistics, Min) {
4145
ArrayStatistics statistics;
4246
ASSERT_FALSE(statistics.min.has_value());
4347
ASSERT_FALSE(statistics.is_min_exact);
@@ -49,7 +53,7 @@ TEST(ArrayStatisticsTest, TestMin) {
4953
ASSERT_TRUE(statistics.is_min_exact);
5054
}
5155

52-
TEST(ArrayStatisticsTest, TestMax) {
56+
TEST(TestArrayStatistics, Max) {
5357
ArrayStatistics statistics;
5458
ASSERT_FALSE(statistics.max.has_value());
5559
ASSERT_FALSE(statistics.is_max_exact);
@@ -61,7 +65,7 @@ TEST(ArrayStatisticsTest, TestMax) {
6165
ASSERT_FALSE(statistics.is_max_exact);
6266
}
6367

64-
TEST(ArrayStatisticsTest, TestEquality) {
68+
TEST(TestArrayStatistics, EqualityNonDoulbeValue) {
6569
ArrayStatistics statistics1;
6670
ArrayStatistics statistics2;
6771

@@ -96,6 +100,56 @@ TEST(ArrayStatisticsTest, TestEquality) {
96100
ASSERT_NE(statistics1, statistics2);
97101
statistics2.is_max_exact = true;
98102
ASSERT_EQ(statistics1, statistics2);
103+
104+
// Test different ArrayStatistics::ValueType
105+
statistics1.max = static_cast<uint64_t>(29);
106+
statistics1.max = static_cast<int64_t>(29);
107+
ASSERT_NE(statistics1, statistics2);
108+
}
109+
110+
class TestArrayStatisticsEqualityDoubleValue : public ::testing::Test {
111+
protected:
112+
ArrayStatistics statistics1_;
113+
ArrayStatistics statistics2_;
114+
EqualOptions options_ = EqualOptions::Defaults();
115+
};
116+
117+
TEST_F(TestArrayStatisticsEqualityDoubleValue, ExactValue) {
118+
statistics2_.min = 29.0;
119+
statistics1_.min = 29.0;
120+
ASSERT_EQ(statistics1_, statistics2_);
121+
statistics2_.min = 30.0;
122+
ASSERT_NE(statistics1_, statistics2_);
123+
}
124+
125+
TEST_F(TestArrayStatisticsEqualityDoubleValue, SignedZero) {
126+
statistics1_.min = +0.0;
127+
statistics2_.min = -0.0;
128+
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(true)));
129+
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(false)));
130+
}
131+
132+
TEST_F(TestArrayStatisticsEqualityDoubleValue, Infinity) {
133+
auto infinity = std::numeric_limits<double>::infinity();
134+
statistics1_.min = infinity;
135+
statistics2_.min = infinity;
136+
ASSERT_EQ(statistics1_, statistics2_);
137+
statistics1_.min = -infinity;
138+
ASSERT_NE(statistics1_, statistics2_);
139+
}
140+
141+
TEST_F(TestArrayStatisticsEqualityDoubleValue, NaN) {
142+
statistics1_.min = std::numeric_limits<double>::quiet_NaN();
143+
statistics2_.min = std::numeric_limits<double>::quiet_NaN();
144+
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.nans_equal(true)));
145+
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.nans_equal(false)));
146+
}
147+
148+
TEST_F(TestArrayStatisticsEqualityDoubleValue, ApproximateEquals) {
149+
statistics1_.max = 0.5001f;
150+
statistics2_.max = 0.5;
151+
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(false)));
152+
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3)));
99153
}
100154

101155
} // namespace arrow

cpp/src/arrow/compare.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@
2424
#include <cstdint>
2525
#include <cstring>
2626
#include <memory>
27+
#include <optional>
2728
#include <string>
2829
#include <type_traits>
2930
#include <utility>
31+
#include <variant>
3032
#include <vector>
3133

3234
#include "arrow/array.h"
3335
#include "arrow/array/diff.h"
36+
#include "arrow/array/statistics.h"
3437
#include "arrow/buffer.h"
3538
#include "arrow/scalar.h"
3639
#include "arrow/sparse_tensor.h"
@@ -1523,4 +1526,55 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata
15231526
}
15241527
}
15251528

1529+
namespace {
1530+
1531+
bool DoubleEquals(const double& left, const double& right, const EqualOptions& options) {
1532+
bool result;
1533+
auto visitor = [&](auto&& compare_func) { result = compare_func(left, right); };
1534+
VisitFloatingEquality<double>(options, options.use_atol(), std::move(visitor));
1535+
return result;
1536+
}
1537+
1538+
bool ArrayStatisticsValueTypeEquals(
1539+
const std::optional<ArrayStatistics::ValueType>& left,
1540+
const std::optional<ArrayStatistics::ValueType>& right, const EqualOptions& options) {
1541+
if (!left.has_value() || !right.has_value()) {
1542+
return left.has_value() == right.has_value();
1543+
} else if (left->index() != right->index()) {
1544+
return false;
1545+
} else {
1546+
auto EqualsVisitor = [&](const auto& v1, const auto& v2) {
1547+
using type_1 = std::decay_t<decltype(v1)>;
1548+
using type_2 = std::decay_t<decltype(v2)>;
1549+
if constexpr (std::conjunction_v<std::is_same<type_1, double>,
1550+
std::is_same<type_2, double>>) {
1551+
return DoubleEquals(v1, v2, options);
1552+
} else if constexpr (std::is_same_v<type_1, type_2>) {
1553+
return v1 == v2;
1554+
}
1555+
// It is unreachable
1556+
DCHECK(false);
1557+
return false;
1558+
};
1559+
return std::visit(EqualsVisitor, left.value(), right.value());
1560+
}
1561+
}
1562+
1563+
bool ArrayStatisticsEqualsImpl(const ArrayStatistics& left, const ArrayStatistics& right,
1564+
const EqualOptions& equal_options) {
1565+
return left.null_count == right.null_count &&
1566+
left.distinct_count == right.distinct_count &&
1567+
left.is_min_exact == right.is_min_exact &&
1568+
left.is_max_exact == right.is_max_exact &&
1569+
ArrayStatisticsValueTypeEquals(left.min, right.min, equal_options) &&
1570+
ArrayStatisticsValueTypeEquals(left.max, right.max, equal_options);
1571+
}
1572+
1573+
} // namespace
1574+
1575+
bool ArrayStatisticsEquals(const ArrayStatistics& left, const ArrayStatistics& right,
1576+
const EqualOptions& options) {
1577+
return ArrayStatisticsEqualsImpl(left, right, options);
1578+
}
1579+
15261580
} // namespace arrow

cpp/src/arrow/compare.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
namespace arrow {
2929

30+
struct ArrayStatistics;
3031
class Array;
3132
class DataType;
3233
class Tensor;
@@ -58,7 +59,18 @@ class EqualOptions {
5859
return res;
5960
}
6061

62+
/// Whether the "atol" property is used in the comparison.
63+
bool use_atol() const { return use_atol_; }
64+
65+
/// Return a new EqualOptions object with the "use_atol" property changed.
66+
EqualOptions use_atol(bool v) const {
67+
auto res = EqualOptions(*this);
68+
res.use_atol_ = v;
69+
return res;
70+
}
71+
6172
/// The absolute tolerance for approximate comparisons of floating-point values.
73+
/// Note that this option is ignored if "use_atol" is set to false.
6274
double atol() const { return atol_; }
6375

6476
/// Return a new EqualOptions object with the "atol" property changed.
@@ -87,6 +99,7 @@ class EqualOptions {
8799
double atol_ = kDefaultAbsoluteTolerance;
88100
bool nans_equal_ = false;
89101
bool signed_zeros_equal_ = true;
102+
bool use_atol_ = true;
90103

91104
std::ostream* diff_sink_ = NULLPTR;
92105
};
@@ -135,6 +148,16 @@ ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTenso
135148
ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right,
136149
bool check_metadata = true);
137150

151+
/// \brief Check two \ref arrow::ArrayStatistics for equality
152+
/// \param[in] left an \ref arrow::ArrayStatistics
153+
/// \param[in] right an \ref arrow::ArrayStatistics
154+
/// \param[in] options Options used to compare double values for equality.
155+
/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise,
156+
/// false.
157+
ARROW_EXPORT bool ArrayStatisticsEquals(
158+
const ArrayStatistics& left, const ArrayStatistics& right,
159+
const EqualOptions& options = EqualOptions::Defaults());
160+
138161
/// Returns true if scalars are equal
139162
/// \param[in] left a Scalar
140163
/// \param[in] right a Scalar

0 commit comments

Comments
 (0)