diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 1750171b81a10..fceafcff8490c 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -108,7 +108,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { // Compute sign(x) = cast(x < 0) * (-2) + 1 Value sign = rewriter.create(loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); - sign = rewriter.create(loc, floatType, sign); + sign = rewriter.create(loc, floatType, sign); sign = rewriter.create(loc, sign, negTwo); sign = rewriter.create(loc, sign, one); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 86ee5c8620472..6326d3a71874b 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -9,7 +9,7 @@ func.func @tanh(%arg: f32) -> f32 { // CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32 // CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32 -// CHECK: %[[VAL1:.+]] = arith.sitofp %[[VAL0]] : i1 to f32 +// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32 // CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32 // CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32 // CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32 diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir index 541a201c94c58..e2229a392bbf7 100644 --- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir @@ -683,6 +683,24 @@ func.func @cosh() { return } +// -------------------------------------------------------------------------- // +// Tanh. +// -------------------------------------------------------------------------- // + +func.func @tanh_8xf32(%a : vector<8xf32>) { + %r = math.tanh %a : vector<8xf32> + vector.print %r : vector<8xf32> + return +} + +func.func @tanh() { + // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1 + %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32> + call @tanh_8xf32(%v3) : (vector<8xf32>) -> () + + return +} + func.func @main() { call @exp2f() : () -> () call @roundf() : () -> () @@ -690,5 +708,6 @@ func.func @main() { call @roundeven() : () -> () call @sinh() : () -> () call @cosh() : () -> () + call @tanh() : () -> () return }