Skip to content

[mlir][emitc] arith.cmpf to EmitC conversion #93671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 4, 2024

Conversation

TinaAMD
Copy link
Contributor

@TinaAMD TinaAMD commented May 29, 2024

Convert all arith.cmpf on floats (not vectors/tensors thereof) to EmitC.

Convert all arith.cmpf on floats (not vectors/tensors thereof) to EmitC.

Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Jose Lopes <jose.lopes@amd.com>
@llvmbot
Copy link
Member

llvmbot commented May 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-emitc

Author: Tina Jung (TinaAMD)

Changes

Convert all arith.cmpf on floats (not vectors/tensors thereof) to EmitC.


Full diff: https://github.com/llvm/llvm-project/pull/93671.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+160-1)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir (+16)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+228)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 388794ec122d2..71e3e48e8d185 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,7 +15,9 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Tools/PDLL/AST/Types.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -40,6 +42,162 @@ class ArithConstantOpConversionPattern
   }
 };
 
+class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!isa<FloatType>(adaptor.getRhs().getType())) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cmpf currently only supported on "
+                                         "floats, not tensors/vectors thereof");
+    }
+
+    bool unordered = false;
+    emitc::CmpPredicate predicate;
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/false));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    case arith::CmpFPredicate::OEQ:
+      unordered = false;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::OGT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::OGE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::OLT:
+      unordered = false;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::OLE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::ONE:
+      unordered = false;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::ORD: {
+      // ordered, i.e. none of the operands is NaN
+      auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                      adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::UEQ:
+      unordered = true;
+      predicate = emitc::CmpPredicate::eq;
+      break;
+    case arith::CmpFPredicate::UGT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::gt;
+      break;
+    case arith::CmpFPredicate::UGE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ge;
+      break;
+    case arith::CmpFPredicate::ULT:
+      unordered = true;
+      predicate = emitc::CmpPredicate::lt;
+      break;
+    case arith::CmpFPredicate::ULE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::le;
+      break;
+    case arith::CmpFPredicate::UNE:
+      unordered = true;
+      predicate = emitc::CmpPredicate::ne;
+      break;
+    case arith::CmpFPredicate::UNO: {
+      // unordered, i.e. either operand is nan
+      auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
+                                        adaptor.getRhs());
+      rewriter.replaceOp(op, cmp);
+      return success();
+    }
+    case arith::CmpFPredicate::AlwaysTrue: {
+      auto constant = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rewriter.getI1Type(),
+          rewriter.getBoolAttr(/*value=*/true));
+      rewriter.replaceOp(op, constant);
+      return success();
+    }
+    }
+
+    // Compare the values naively
+    auto cmpResult =
+        rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
+                                      adaptor.getLhs(), adaptor.getRhs());
+
+    // Adjust the results for unordered/ordered semantics
+    if (unordered) {
+      auto isUnordered = createCheckIsUnordered(
+          rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
+      rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(
+          op, op.getType(), isUnordered.getResult(), cmpResult);
+      return success();
+    }
+
+    auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
+                                          adaptor.getLhs(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
+                                                     isOrdered, cmpResult);
+    return success();
+  }
+
+private:
+  /// Return an operation that returns true (in i1) when \p operand is NaN.
+  emitc::CmpOp isNan(ConversionPatternRewriter &rewriter, Location loc,
+                     Value operand) const {
+    // A value is NaN exactly when it compares unequal to itself.
+    return rewriter.create<emitc::CmpOp>(
+        loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
+  }
+
+  /// Return an operation that returns true (in i1) when \p operand is not NaN.
+  emitc::CmpOp isNotNan(ConversionPatternRewriter &rewriter, Location loc,
+                        Value operand) const {
+    // A value is not NaN exactly when it compares equal to itself.
+    return rewriter.create<emitc::CmpOp>(
+        loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
+  }
+
+  /// Return an op that return true (in i1) if the operands \p first and
+  /// \p second are unordered (i.e., at least one of them is NaN).
+  emitc::LogicalOrOp createCheckIsUnordered(ConversionPatternRewriter &rewriter,
+                                            Location loc, Value first,
+                                            Value second) const {
+    auto firstIsNaN = isNan(rewriter, loc, first);
+    auto secondIsNaN = isNan(rewriter, loc, second);
+    return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
+                                               firstIsNaN, secondIsNaN);
+  }
+
+  /// Return an op that return true (in i1) if the operands \p first and
+  /// \p second are both ordered (i.e., none one of them is NaN).
+  emitc::LogicalAndOp createCheckIsOrdered(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value first,
+                                           Value second) const {
+    auto firstIsNaN = isNotNan(rewriter, loc, first);
+    auto secondIsNaN = isNotNan(rewriter, loc, second);
+    return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
+                                                firstIsNaN, secondIsNaN);
+  }
+};
+
 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -401,6 +559,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
+    CmpFOpConversion,
     CmpIOpConversion,
     SelectOpConversion,
     // Truncation is guaranteed for unsigned types.
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 97e4593f97b90..c07289109e6dd 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -65,6 +65,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
 
 // -----
 
+func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vector<5xi1> {
+  // expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
+  %t = arith.cmpf uno, %arg0, %arg1 : vector<5xf32>
+  return %t: vector<5xi1>
+}
+
+// -----
+
+func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xi1> {
+  // expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
+  %t = arith.cmpf uno, %arg0, %arg1 : tensor<5xf32>
+  return %t: tensor<5xi1>
+}
+
+// -----
+
 func.func @arith_extsi_i1_to_i32(%arg0: i1) {
   // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
   %idx = arith.extsi %arg0 : i1 to i32
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index dac3fd99b607c..cf6ba2f3febde 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -107,6 +107,234 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
 
 // -----
 
+func.func @arith_cmpf_false(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_false
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[False:[^ ]*]] = "emitc.constant"() <{value = false}> : () -> i1
+  %false = arith.cmpf false, %arg0, %arg1 : f32
+  // CHECK: return [[False]]
+  return %false: i1
+}
+
+// -----
+
+func.func @arith_cmpf_oeq(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_oeq
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[OEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[EQ]] : i1, i1
+  %oeq = arith.cmpf oeq, %arg0, %arg1 : f32
+  // CHECK: return [[OEQ]]
+  return %oeq: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ogt
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1
+  // CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1
+  %ogt = arith.cmpf ogt, %arg0, %arg1 : f32
+  // CHECK: return [[OGT]]
+  return %ogt: i1
+}
+
+// -----
+
+func.func @arith_cmpf_oge(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_oge
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[GE:[^ ]*]] = emitc.cmp ge, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[OGE:[^ ]*]] = emitc.logical_and [[Ordered]], [[GE]] : i1, i1
+  %oge = arith.cmpf oge, %arg0, %arg1 : f32
+  // CHECK: return [[OGE]]
+  return %oge: i1
+}
+
+// -----
+
+func.func @arith_cmpf_olt(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_olt
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp lt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[OLT:[^ ]*]] = emitc.logical_and [[Ordered]], [[LT]] : i1, i1
+  %olt = arith.cmpf olt, %arg0, %arg1 : f32
+  // CHECK: return [[OLT]]
+  return %olt: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ole(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ole
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp le, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[OLE:[^ ]*]] = emitc.logical_and [[Ordered]], [[LT]] : i1, i1
+  %ole = arith.cmpf ole, %arg0, %arg1 : f32
+  // CHECK: return [[OLE]]
+  return %ole: i1
+}
+
+// -----
+
+func.func @arith_cmpf_one(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_one
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[ONE:[^ ]*]] = emitc.logical_and [[Ordered]], [[NEQ]] : i1, i1
+  %one = arith.cmpf one, %arg0, %arg1 : f32
+  // CHECK: return [[ONE]]
+  return %one: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ord(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ord
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
+  %ord = arith.cmpf ord, %arg0, %arg1 : f32
+  // CHECK: return [[Ordered]]
+  return %ord: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ueq(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ueq
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[UEQ:[^ ]*]] = emitc.logical_or [[Unordered]], [[EQ]] : i1, i1
+  %ueq = arith.cmpf ueq, %arg0, %arg1 : f32
+  // CHECK: return [[UEQ]]
+  return %ueq: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ugt(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ugt
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[UGT:[^ ]*]] = emitc.logical_or [[Unordered]], [[GT]] : i1, i1
+  %ugt = arith.cmpf ugt, %arg0, %arg1 : f32
+  // CHECK: return [[UGT]]
+  return %ugt: i1
+}
+
+// -----
+
+func.func @arith_cmpf_uge(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_uge
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[GE:[^ ]*]] = emitc.cmp ge, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[UGE:[^ ]*]] = emitc.logical_or [[Unordered]], [[GE]] : i1, i1
+  %uge = arith.cmpf uge, %arg0, %arg1 : f32
+  // CHECK: return [[UGE]]
+  return %uge: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ult(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ult
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp lt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[ULT:[^ ]*]] = emitc.logical_or [[Unordered]], [[LT]] : i1, i1
+  %ult = arith.cmpf ult, %arg0, %arg1 : f32
+  // CHECK: return [[ULT]]
+  return %ult: i1
+}
+
+// -----
+
+func.func @arith_cmpf_ule(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_ule
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[LE:[^ ]*]] = emitc.cmp le, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[ULE:[^ ]*]] = emitc.logical_or [[Unordered]], [[LE]] : i1, i1
+  %ule = arith.cmpf ule, %arg0, %arg1 : f32
+  // CHECK: return [[ULE]]
+  return %ule: i1
+}
+
+// -----
+
+func.func @arith_cmpf_une(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_une
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  // CHECK-DAG: [[UNE:[^ ]*]] = emitc.logical_or [[Unordered]], [[NEQ]] : i1, i1
+  %une = arith.cmpf une, %arg0, %arg1 : f32
+  // CHECK: return [[UNE]]
+  return %une: i1
+}
+
+// -----
+
+func.func @arith_cmpf_uno(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_uno
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
+  // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
+  // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
+  %uno = arith.cmpf uno, %arg0, %arg1 : f32
+  // CHECK: return [[Unordered]]
+  return %uno: i1
+}
+
+// -----
+
+func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 {
+  // CHECK-LABEL: arith_cmpf_true
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
+  // CHECK-DAG: [[True:[^ ]*]] = "emitc.constant"() <{value = true}> : () -> i1
+  %ueq = arith.cmpf true, %arg0, %arg1 : f32
+  // CHECK: return [[True]]
+  return %ueq: i1
+}
+
+// -----
+
 func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
   // CHECK-LABEL: arith_cmpi_eq
   // CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)

Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks good, just a few nits.

TinaAMD added 2 commits May 29, 2024 15:56
* Rename functions from Nan to NaN
* Add missing `Not` in variable name
* Change return type of private function to Value + adapt documentation to reflect this
@TinaAMD TinaAMD requested a review from simon-camp May 29, 2024 15:02
@TinaAMD
Copy link
Contributor Author

TinaAMD commented May 29, 2024

Thanks for the very quick review @simon-camp! I addressed the comments, please resolve the conversations if you are happy with the changes 🙂

Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks good!

Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor NITs.

@TinaAMD
Copy link
Contributor Author

TinaAMD commented Jun 3, 2024

Some minor NITs.

Resolved. Was this a preliminary review or is the PR good to go then?🙂

@marbre
Copy link
Member

marbre commented Jun 3, 2024

Some minor NITs.

Resolved. Was this a preliminary review or is the PR good to go then?🙂

I only had a very quick look but will be all good if Simon approved :)

@simon-camp
Copy link
Contributor

Yes, this can be landed from my side :)

@TinaAMD TinaAMD merged commit 46672c1 into llvm:main Jun 4, 2024
7 checks passed
@TinaAMD TinaAMD deleted the tina.arith-to-emitc-cmpf branch June 4, 2024 06:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants