|
15 | 15 |
|
16 | 16 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
17 | 17 | #include "mlir/Dialect/EmitC/IR/EmitC.h"
|
18 |
| -#include "mlir/Tools/PDLL/AST/Types.h" |
| 18 | +#include "mlir/IR/BuiltinAttributes.h" |
| 19 | +#include "mlir/IR/BuiltinTypes.h" |
| 20 | +#include "mlir/Support/LogicalResult.h" |
19 | 21 | #include "mlir/Transforms/DialectConversion.h"
|
20 | 22 |
|
21 | 23 | using namespace mlir;
|
@@ -59,6 +61,160 @@ Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
|
59 | 61 | return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
|
60 | 62 | }
|
61 | 63 |
|
| 64 | +class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> { |
| 65 | +public: |
| 66 | + using OpConversionPattern::OpConversionPattern; |
| 67 | + |
| 68 | + LogicalResult |
| 69 | + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| 70 | + ConversionPatternRewriter &rewriter) const override { |
| 71 | + |
| 72 | + if (!isa<FloatType>(adaptor.getRhs().getType())) { |
| 73 | + return rewriter.notifyMatchFailure(op.getLoc(), |
| 74 | + "cmpf currently only supported on " |
| 75 | + "floats, not tensors/vectors thereof"); |
| 76 | + } |
| 77 | + |
| 78 | + bool unordered = false; |
| 79 | + emitc::CmpPredicate predicate; |
| 80 | + switch (op.getPredicate()) { |
| 81 | + case arith::CmpFPredicate::AlwaysFalse: { |
| 82 | + auto constant = rewriter.create<emitc::ConstantOp>( |
| 83 | + op.getLoc(), rewriter.getI1Type(), |
| 84 | + rewriter.getBoolAttr(/*value=*/false)); |
| 85 | + rewriter.replaceOp(op, constant); |
| 86 | + return success(); |
| 87 | + } |
| 88 | + case arith::CmpFPredicate::OEQ: |
| 89 | + unordered = false; |
| 90 | + predicate = emitc::CmpPredicate::eq; |
| 91 | + break; |
| 92 | + case arith::CmpFPredicate::OGT: |
| 93 | + unordered = false; |
| 94 | + predicate = emitc::CmpPredicate::gt; |
| 95 | + break; |
| 96 | + case arith::CmpFPredicate::OGE: |
| 97 | + unordered = false; |
| 98 | + predicate = emitc::CmpPredicate::ge; |
| 99 | + break; |
| 100 | + case arith::CmpFPredicate::OLT: |
| 101 | + unordered = false; |
| 102 | + predicate = emitc::CmpPredicate::lt; |
| 103 | + break; |
| 104 | + case arith::CmpFPredicate::OLE: |
| 105 | + unordered = false; |
| 106 | + predicate = emitc::CmpPredicate::le; |
| 107 | + break; |
| 108 | + case arith::CmpFPredicate::ONE: |
| 109 | + unordered = false; |
| 110 | + predicate = emitc::CmpPredicate::ne; |
| 111 | + break; |
| 112 | + case arith::CmpFPredicate::ORD: { |
| 113 | + // ordered, i.e. none of the operands is NaN |
| 114 | + auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(), |
| 115 | + adaptor.getRhs()); |
| 116 | + rewriter.replaceOp(op, cmp); |
| 117 | + return success(); |
| 118 | + } |
| 119 | + case arith::CmpFPredicate::UEQ: |
| 120 | + unordered = true; |
| 121 | + predicate = emitc::CmpPredicate::eq; |
| 122 | + break; |
| 123 | + case arith::CmpFPredicate::UGT: |
| 124 | + unordered = true; |
| 125 | + predicate = emitc::CmpPredicate::gt; |
| 126 | + break; |
| 127 | + case arith::CmpFPredicate::UGE: |
| 128 | + unordered = true; |
| 129 | + predicate = emitc::CmpPredicate::ge; |
| 130 | + break; |
| 131 | + case arith::CmpFPredicate::ULT: |
| 132 | + unordered = true; |
| 133 | + predicate = emitc::CmpPredicate::lt; |
| 134 | + break; |
| 135 | + case arith::CmpFPredicate::ULE: |
| 136 | + unordered = true; |
| 137 | + predicate = emitc::CmpPredicate::le; |
| 138 | + break; |
| 139 | + case arith::CmpFPredicate::UNE: |
| 140 | + unordered = true; |
| 141 | + predicate = emitc::CmpPredicate::ne; |
| 142 | + break; |
| 143 | + case arith::CmpFPredicate::UNO: { |
| 144 | + // unordered, i.e. either operand is nan |
| 145 | + auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(), |
| 146 | + adaptor.getRhs()); |
| 147 | + rewriter.replaceOp(op, cmp); |
| 148 | + return success(); |
| 149 | + } |
| 150 | + case arith::CmpFPredicate::AlwaysTrue: { |
| 151 | + auto constant = rewriter.create<emitc::ConstantOp>( |
| 152 | + op.getLoc(), rewriter.getI1Type(), |
| 153 | + rewriter.getBoolAttr(/*value=*/true)); |
| 154 | + rewriter.replaceOp(op, constant); |
| 155 | + return success(); |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + // Compare the values naively |
| 160 | + auto cmpResult = |
| 161 | + rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate, |
| 162 | + adaptor.getLhs(), adaptor.getRhs()); |
| 163 | + |
| 164 | + // Adjust the results for unordered/ordered semantics |
| 165 | + if (unordered) { |
| 166 | + auto isUnordered = createCheckIsUnordered( |
| 167 | + rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); |
| 168 | + rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(), |
| 169 | + isUnordered, cmpResult); |
| 170 | + return success(); |
| 171 | + } |
| 172 | + |
| 173 | + auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(), |
| 174 | + adaptor.getLhs(), adaptor.getRhs()); |
| 175 | + rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(), |
| 176 | + isOrdered, cmpResult); |
| 177 | + return success(); |
| 178 | + } |
| 179 | + |
| 180 | +private: |
| 181 | + /// Return a value that is true if \p operand is NaN. |
| 182 | + Value isNaN(ConversionPatternRewriter &rewriter, Location loc, |
| 183 | + Value operand) const { |
| 184 | + // A value is NaN exactly when it compares unequal to itself. |
| 185 | + return rewriter.create<emitc::CmpOp>( |
| 186 | + loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); |
| 187 | + } |
| 188 | + |
| 189 | + /// Return a value that is true if \p operand is not NaN. |
| 190 | + Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, |
| 191 | + Value operand) const { |
| 192 | + // A value is not NaN exactly when it compares equal to itself. |
| 193 | + return rewriter.create<emitc::CmpOp>( |
| 194 | + loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); |
| 195 | + } |
| 196 | + |
| 197 | + /// Return a value that is true if the operands \p first and \p second are |
| 198 | + /// unordered (i.e., at least one of them is NaN). |
| 199 | + Value createCheckIsUnordered(ConversionPatternRewriter &rewriter, |
| 200 | + Location loc, Value first, Value second) const { |
| 201 | + auto firstIsNaN = isNaN(rewriter, loc, first); |
| 202 | + auto secondIsNaN = isNaN(rewriter, loc, second); |
| 203 | + return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(), |
| 204 | + firstIsNaN, secondIsNaN); |
| 205 | + } |
| 206 | + |
| 207 | + /// Return a value that is true if the operands \p first and \p second are |
| 208 | + /// both ordered (i.e., none one of them is NaN). |
| 209 | + Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, |
| 210 | + Value first, Value second) const { |
| 211 | + auto firstIsNotNaN = isNotNaN(rewriter, loc, first); |
| 212 | + auto secondIsNotNaN = isNotNaN(rewriter, loc, second); |
| 213 | + return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(), |
| 214 | + firstIsNotNaN, secondIsNotNaN); |
| 215 | + } |
| 216 | +}; |
| 217 | + |
62 | 218 | class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
|
63 | 219 | public:
|
64 | 220 | using OpConversionPattern::OpConversionPattern;
|
@@ -463,6 +619,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
|
463 | 619 | BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
|
464 | 620 | BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
|
465 | 621 | BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
|
| 622 | + CmpFOpConversion, |
466 | 623 | CmpIOpConversion,
|
467 | 624 | SelectOpConversion,
|
468 | 625 | // Truncation is guaranteed for unsigned types.
|
|
0 commit comments