Skip to content

Commit 46672c1

Browse files
TinaAMDmgehre-amdjosel-amd
authored
[mlir][emitc] arith.cmpf to EmitC conversion (#93671)
Convert all arith.cmpf on floats (not vectors/tensors thereof) to EmitC. --------- Co-authored-by: Matthias Gehre <matthias.gehre@amd.com> Co-authored-by: Jose Lopes <jose.lopes@amd.com>
1 parent 4ab7354 commit 46672c1

File tree

3 files changed

+402
-1
lines changed

3 files changed

+402
-1
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18-
#include "mlir/Tools/PDLL/AST/Types.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/Support/LogicalResult.h"
1921
#include "mlir/Transforms/DialectConversion.h"
2022

2123
using namespace mlir;
@@ -59,6 +61,160 @@ Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
5961
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
6062
}
6163

64+
class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
65+
public:
66+
using OpConversionPattern::OpConversionPattern;
67+
68+
LogicalResult
69+
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
70+
ConversionPatternRewriter &rewriter) const override {
71+
72+
if (!isa<FloatType>(adaptor.getRhs().getType())) {
73+
return rewriter.notifyMatchFailure(op.getLoc(),
74+
"cmpf currently only supported on "
75+
"floats, not tensors/vectors thereof");
76+
}
77+
78+
bool unordered = false;
79+
emitc::CmpPredicate predicate;
80+
switch (op.getPredicate()) {
81+
case arith::CmpFPredicate::AlwaysFalse: {
82+
auto constant = rewriter.create<emitc::ConstantOp>(
83+
op.getLoc(), rewriter.getI1Type(),
84+
rewriter.getBoolAttr(/*value=*/false));
85+
rewriter.replaceOp(op, constant);
86+
return success();
87+
}
88+
case arith::CmpFPredicate::OEQ:
89+
unordered = false;
90+
predicate = emitc::CmpPredicate::eq;
91+
break;
92+
case arith::CmpFPredicate::OGT:
93+
unordered = false;
94+
predicate = emitc::CmpPredicate::gt;
95+
break;
96+
case arith::CmpFPredicate::OGE:
97+
unordered = false;
98+
predicate = emitc::CmpPredicate::ge;
99+
break;
100+
case arith::CmpFPredicate::OLT:
101+
unordered = false;
102+
predicate = emitc::CmpPredicate::lt;
103+
break;
104+
case arith::CmpFPredicate::OLE:
105+
unordered = false;
106+
predicate = emitc::CmpPredicate::le;
107+
break;
108+
case arith::CmpFPredicate::ONE:
109+
unordered = false;
110+
predicate = emitc::CmpPredicate::ne;
111+
break;
112+
case arith::CmpFPredicate::ORD: {
113+
// ordered, i.e. none of the operands is NaN
114+
auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
115+
adaptor.getRhs());
116+
rewriter.replaceOp(op, cmp);
117+
return success();
118+
}
119+
case arith::CmpFPredicate::UEQ:
120+
unordered = true;
121+
predicate = emitc::CmpPredicate::eq;
122+
break;
123+
case arith::CmpFPredicate::UGT:
124+
unordered = true;
125+
predicate = emitc::CmpPredicate::gt;
126+
break;
127+
case arith::CmpFPredicate::UGE:
128+
unordered = true;
129+
predicate = emitc::CmpPredicate::ge;
130+
break;
131+
case arith::CmpFPredicate::ULT:
132+
unordered = true;
133+
predicate = emitc::CmpPredicate::lt;
134+
break;
135+
case arith::CmpFPredicate::ULE:
136+
unordered = true;
137+
predicate = emitc::CmpPredicate::le;
138+
break;
139+
case arith::CmpFPredicate::UNE:
140+
unordered = true;
141+
predicate = emitc::CmpPredicate::ne;
142+
break;
143+
case arith::CmpFPredicate::UNO: {
144+
// unordered, i.e. either operand is nan
145+
auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
146+
adaptor.getRhs());
147+
rewriter.replaceOp(op, cmp);
148+
return success();
149+
}
150+
case arith::CmpFPredicate::AlwaysTrue: {
151+
auto constant = rewriter.create<emitc::ConstantOp>(
152+
op.getLoc(), rewriter.getI1Type(),
153+
rewriter.getBoolAttr(/*value=*/true));
154+
rewriter.replaceOp(op, constant);
155+
return success();
156+
}
157+
}
158+
159+
// Compare the values naively
160+
auto cmpResult =
161+
rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
162+
adaptor.getLhs(), adaptor.getRhs());
163+
164+
// Adjust the results for unordered/ordered semantics
165+
if (unordered) {
166+
auto isUnordered = createCheckIsUnordered(
167+
rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
168+
rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
169+
isUnordered, cmpResult);
170+
return success();
171+
}
172+
173+
auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
174+
adaptor.getLhs(), adaptor.getRhs());
175+
rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
176+
isOrdered, cmpResult);
177+
return success();
178+
}
179+
180+
private:
181+
/// Return a value that is true if \p operand is NaN.
182+
Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
183+
Value operand) const {
184+
// A value is NaN exactly when it compares unequal to itself.
185+
return rewriter.create<emitc::CmpOp>(
186+
loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
187+
}
188+
189+
/// Return a value that is true if \p operand is not NaN.
190+
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
191+
Value operand) const {
192+
// A value is not NaN exactly when it compares equal to itself.
193+
return rewriter.create<emitc::CmpOp>(
194+
loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
195+
}
196+
197+
/// Return a value that is true if the operands \p first and \p second are
198+
/// unordered (i.e., at least one of them is NaN).
199+
Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
200+
Location loc, Value first, Value second) const {
201+
auto firstIsNaN = isNaN(rewriter, loc, first);
202+
auto secondIsNaN = isNaN(rewriter, loc, second);
203+
return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
204+
firstIsNaN, secondIsNaN);
205+
}
206+
207+
/// Return a value that is true if the operands \p first and \p second are
208+
/// both ordered (i.e., none one of them is NaN).
209+
Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
210+
Value first, Value second) const {
211+
auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
212+
auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
213+
return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
214+
firstIsNotNaN, secondIsNotNaN);
215+
}
216+
};
217+
62218
class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
63219
public:
64220
using OpConversionPattern::OpConversionPattern;
@@ -463,6 +619,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
463619
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
464620
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
465621
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
622+
CmpFOpConversion,
466623
CmpIOpConversion,
467624
SelectOpConversion,
468625
// Truncation is guaranteed for unsigned types.

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
6565

6666
// -----
6767

68+
func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vector<5xi1> {
69+
// expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
70+
%t = arith.cmpf uno, %arg0, %arg1 : vector<5xf32>
71+
return %t: vector<5xi1>
72+
}
73+
74+
// -----
75+
76+
func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xi1> {
77+
// expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
78+
%t = arith.cmpf uno, %arg0, %arg1 : tensor<5xf32>
79+
return %t: tensor<5xi1>
80+
}
81+
82+
// -----
83+
6884
func.func @arith_extsi_i1_to_i32(%arg0: i1) {
6985
// expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
7086
%idx = arith.extsi %arg0 : i1 to i32

0 commit comments

Comments
 (0)