From c4dd5ad49f64f58aa46cd1d241fab0ffa5f3b553 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Thu, 9 May 2024 14:36:51 +0800 Subject: [PATCH 01/25] add canonicalize-f32-promotion pass --- .../mlir/Dialect/Math/Transforms/Passes.h | 1 + .../mlir/Dialect/Math/Transforms/Passes.td | 43 +++++++++++ .../Dialect/Math/Transforms/CMakeLists.txt | 1 + .../Transforms/CanonicalizeF32Promotion.cpp | 73 +++++++++++++++++++ .../Math/canonicalize-f32-promotion.mlir | 56 ++++++++++++++ 5 files changed, 174 insertions(+) create mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp create mode 100644 mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index e2c513047c77a..f150ff6f944d2 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -17,6 +17,7 @@ namespace math { #include "mlir/Dialect/Math/Transforms/Passes.h.inc" #define GEN_PASS_DECL_MATHUPLIFTTOFMA #define GEN_PASS_DECL_MATHLEGALIZETOF32 +#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION #include "mlir/Dialect/Math/Transforms/Passes.h.inc" #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Math/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index e870e714bfda5..538dcbfbe7f77 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -36,4 +36,47 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> { let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; } +def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> { + let summary = "Eliminate redundant truncf/extf pairs"; + let description = [{ + `legalize-to-f32` pass does f32 promotion for every op belonging to the + illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32` + will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal + ops. + + This pass is to eliminate the redundant truncf/extf pairs. + + Example: + + ```mlir + // the initial func + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = math.absf %arg0 : vector<32xbf16> + %1 = math.sin %0 : vector<32xbf16> + return %1 : vector<32xbf16> + } + // after legalize-to-f32 + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> + %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> + %4 = math.sin %3 : vector<32xf32> + %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> + return %5 : vector<32xbf16> + } + // after canonicalize-f32-promotion + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = math.sin %1 : vector<32xf32> + %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> + return %3 : vector<32xbf16> + } + ``` + + }]; + let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; +} + #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index 2a5b4fbcb5271..0d39d14925d23 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp ExpandPatterns.cpp LegalizeToF32.cpp + CanonicalizeF32Promotion.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp new file mode 100644 index 0000000000000..bfff17df8d7d4 --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp @@ -0,0 +1,73 @@ +//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs +//----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements removing redundant extf/truncf pairs inserted from +// LegalizeToF32. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::math { +#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +using namespace mlir; + +namespace { + +struct CanonicalizeF32PromotionRewritePattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const final { + if (auto innertruncop = op.getOperand().getDefiningOp()) { + if (auto truncinput = innertruncop.getOperand()) { + auto outter_type = op.getType(); + auto intermediate_type = innertruncop.getType(); + auto inner_type = truncinput.getType(); + if (outter_type.isa()) { + outter_type = op.getType().cast().getElementType(); + intermediate_type = + innertruncop.getType().cast().getElementType(); + inner_type = truncinput.getType().cast().getElementType(); + } + if (outter_type.isF32() && + (intermediate_type.isF16() || intermediate_type.isBF16()) && + inner_type.isF32()) { + rewriter.replaceOp(op, {truncinput}); + } + } else + return failure(); + } else + return failure(); + return success(); + } +}; + +struct MathCanonicalizeF32Promotion final + : math::impl::MathCanonicalizeF32PromotionBase< + MathCanonicalizeF32Promotion> { + using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase; + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir new file mode 100644 index 0000000000000..7aad7889e2bf5 --- /dev/null +++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s + +// CHECK-LABEL: @sequences +// CHECK-SAME: ([[ARG0:%.+]]: bf16) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : bf16 +func.func @sequences(%arg0: bf16) -> bf16 { + %0 = math.absf %arg0 : bf16 + %1 = math.sin %0 : bf16 + return %1 : bf16 +} + +// CHECK-LABEL: @eliminatecastoncastf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 : f32 to f16 + %1 = arith.extf %0 : f16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @eliminatecastoncastbf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 : f32 to bf16 + %1 = arith.extf %0 : bf16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @bf16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> +func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = math.absf %arg0 : vector<32x32x32xbf16> + %1 = math.sin %0 : vector<32x32x32xbf16> + return %1 : vector<32x32x32xbf16> +} + +// CHECK-LABEL: @f16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> +func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { + %0 = math.absf %arg0 : vector<32x32x32xf16> + %1 = math.sin %0 : vector<32x32x32xf16> + return %1 : vector<32x32x32xf16> +} From 02be4d6dedc81e9e5ace44829f388e36e52e0278 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Fri, 10 May 2024 11:09:31 +0800 Subject: [PATCH 02/25] add branch case --- .../mlir/Dialect/Math/Transforms/Passes.td | 6 +++++- .../Transforms/CanonicalizeF32Promotion.cpp | 3 +-- .../Math/canonicalize-f32-promotion.mlir | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index 538dcbfbe7f77..5bf5eb45f921a 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -44,7 +44,11 @@ def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> { will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal ops. - This pass is to eliminate the redundant truncf/extf pairs. + This pass is to eliminate the redundant truncf/extf pairs to improve + performance. + + However, this pass may introduce numerical difference as the `f32->bf16` rounding + is eliminated. Example: diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp index bfff17df8d7d4..b9b43a0887f14 100644 --- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp +++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp @@ -1,5 +1,4 @@ -//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -//----------===// +//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir index 7aad7889e2bf5..127eece98cf79 100644 --- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir +++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir @@ -54,3 +54,21 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { %1 = math.sin %0 : vector<32x32x32xf16> return %1 : vector<32x32x32xf16> } + +// CHECK-LABEL: @bf16_branch_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] +// CHECK: [[COS:%.+]] = math.cos [[ABSF]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]] +// CHECK: [[ADDF:%.+]] = arith.addf +// CHECK: return [[ADDF]] : vector<32x32x32xbf16> +func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = math.absf %arg0 : vector<32x32x32xbf16> + %1 = math.sin %0 : vector<32x32x32xbf16> + %2 = math.cos %0 : vector<32x32x32xbf16> + %3 = arith.addf %1, %2 : vector<32x32x32xbf16> + return %3 : vector<32x32x32xbf16> +} From 07ca29dbe48d010a36fdab154687547f26a6ead5 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Fri, 17 May 2024 14:21:38 +0800 Subject: [PATCH 03/25] use single walk rather than greedy rewrite --- .../Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp index b9b43a0887f14..8257ddb5c2efc 100644 --- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp +++ b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp @@ -64,7 +64,12 @@ struct MathCanonicalizeF32Promotion final RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + SmallVector ops; + getOperation()->walk([&](Operation *op) { + if (isa(op)) + ops.push_back(op); + }); + if (failed(applyOpPatternsAndFold(ops, patternSet))) signalPassFailure(); } }; From 5152f89609f50c3fea755391aec33a2a2bc891da Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 10:32:13 +0800 Subject: [PATCH 04/25] adjust test case --- .../Dialect/Math/canonicalize-f32-promotion.mlir | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir index 127eece98cf79..5ed189b0033b3 100644 --- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir +++ b/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 -math-canonicalize-f32-promotion | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s // CHECK-LABEL: @sequences // CHECK-SAME: ([[ARG0:%.+]]: bf16) @@ -59,12 +59,11 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { // CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] // CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] -// CHECK: [[COS:%.+]] = math.cos [[ABSF]] -// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]] -// CHECK: [[ADDF:%.+]] = arith.addf -// CHECK: return [[ADDF]] : vector<32x32x32xbf16> +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]] +// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { %0 = math.absf %arg0 : vector<32x32x32xbf16> %1 = math.sin %0 : vector<32x32x32xbf16> From f6e310cda6e131843f519363323e60b7bbd18347 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 14:42:14 +0800 Subject: [PATCH 05/25] do cast elimination in transforms with eliminatable attr --- .../mlir/Dialect/Math/Transforms/Passes.h | 1 - .../mlir/Dialect/Math/Transforms/Passes.td | 47 ---------- mlir/include/mlir/Transforms/Passes.h | 5 ++ mlir/include/mlir/Transforms/Passes.td | 48 +++++++++++ .../Transforms/EmulateUnsupportedFloats.cpp | 11 ++- .../Dialect/Math/Transforms/CMakeLists.txt | 1 - .../Transforms/CanonicalizeF32Promotion.cpp | 77 ----------------- .../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 ++- mlir/lib/Transforms/CMakeLists.txt | 3 + .../Transforms/EliminateExplicitRounding.cpp | 85 +++++++++++++++++++ .../eliminate-explicit-rounding.mlir} | 2 +- 11 files changed, 158 insertions(+), 133 deletions(-) delete mode 100644 mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp create mode 100644 mlir/lib/Transforms/EliminateExplicitRounding.cpp rename mlir/test/{Dialect/Math/canonicalize-f32-promotion.mlir => Transforms/eliminate-explicit-rounding.mlir} (98%) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index f150ff6f944d2..e2c513047c77a 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -17,7 +17,6 @@ namespace math { #include "mlir/Dialect/Math/Transforms/Passes.h.inc" #define GEN_PASS_DECL_MATHUPLIFTTOFMA #define GEN_PASS_DECL_MATHLEGALIZETOF32 -#define GEN_PASS_DECL_MATHCANONICALIZEF32PROMOTION #include "mlir/Dialect/Math/Transforms/Passes.h.inc" #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Math/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index 5bf5eb45f921a..e870e714bfda5 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -36,51 +36,4 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> { let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; } -def MathCanonicalizeF32Promotion : Pass<"math-canonicalize-f32-promotion"> { - let summary = "Eliminate redundant truncf/extf pairs"; - let description = [{ - `legalize-to-f32` pass does f32 promotion for every op belonging to the - illegal op list. Once there are some consecutive illegal ops, `legalize-to-f32` - will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal - ops. - - This pass is to eliminate the redundant truncf/extf pairs to improve - performance. - - However, this pass may introduce numerical difference as the `f32->bf16` rounding - is eliminated. - - Example: - - ```mlir - // the initial func - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = math.absf %arg0 : vector<32xbf16> - %1 = math.sin %0 : vector<32xbf16> - return %1 : vector<32xbf16> - } - // after legalize-to-f32 - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> - %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> - %4 = math.sin %3 : vector<32xf32> - %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> - return %5 : vector<32xbf16> - } - // after canonicalize-f32-promotion - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = math.sin %1 : vector<32xf32> - %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> - return %3 : vector<32xbf16> - } - ``` - - }]; - let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; -} - #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 58bd61b2ae8b8..c618fff9a8040 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -44,6 +44,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS +#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default @@ -137,6 +138,10 @@ std::unique_ptr createCompositeFixedPointPass( std::string name, llvm::function_ref populateFunc, int maxIterations = 10); +/// Create eliminate-explicit-rounding pass, which eliminates the redundant +/// truncf/extf pairs to improve performance. +std::unique_ptr createEliminateExplicitRoundingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 1b40a87c63f27..1539bda02ac60 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -569,4 +569,52 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { ]; } +def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { + let summary = "Eliminate redundant truncf/extf pairs"; + let description = [{ + `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the + illegal op list. Once there are some consecutive illegal ops, these passes + will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal + ops. + + This pass is to eliminate the redundant truncf/extf pairs to improve + performance. + + However, this pass may introduce numerical difference as the `f32->bf16` rounding + is eliminated. + + Example: + + ```mlir + // the initial func + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = math.absf %arg0 : vector<32xbf16> + %1 = math.sin %0 : vector<32xbf16> + return %1 : vector<32xbf16> + } + // after legalize-to-f32 + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> + %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> + %4 = math.sin %3 : vector<32xf32> + %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> + return %5 : vector<32xbf16> + } + // after canonicalize-f32-promotion + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = math.sin %1 : vector<32xf32> + %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> + return %3 : vector<32xbf16> + } + ``` + + }]; + let constructor = "mlir::createEliminateExplicitRoundingPass()"; +} + + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 4a50da3513f99..9cbb3884659ee 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, SmallVector newResults(expandedOp->getResults()); for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { - if (oldType != newType) - res = rewriter.create(loc, oldType, res); + if (oldType != newType) { + auto truncFOp = rewriter.create(loc, oldType, res); + truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true)); + res = truncFOp->getResults().front(); + } } rewriter.replaceOp(op, newResults); } @@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - return b.create(loc, target, input); + auto extFOp = b.create(loc, target, input); + extFOp->setAttr("eliminatable", b.getBoolAttr(true)); + return extFOp; }); } diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index 0d39d14925d23..2a5b4fbcb5271 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp ExpandPatterns.cpp LegalizeToF32.cpp - CanonicalizeF32Promotion.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp b/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp deleted file mode 100644 index 8257ddb5c2efc..0000000000000 --- a/mlir/lib/Dialect/Math/Transforms/CanonicalizeF32Promotion.cpp +++ /dev/null @@ -1,77 +0,0 @@ -//===- CanonicalizeF32Promotion.cpp - Remove redundant extf/truncf pairs -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements removing redundant extf/truncf pairs inserted from -// LegalizeToF32. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::math { -#define GEN_PASS_DEF_MATHCANONICALIZEF32PROMOTION -#include "mlir/Dialect/Math/Transforms/Passes.h.inc" -} // namespace mlir::math - -using namespace mlir; - -namespace { - -struct CanonicalizeF32PromotionRewritePattern final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(arith::ExtFOp op, - PatternRewriter &rewriter) const final { - if (auto innertruncop = op.getOperand().getDefiningOp()) { - if (auto truncinput = innertruncop.getOperand()) { - auto outter_type = op.getType(); - auto intermediate_type = innertruncop.getType(); - auto inner_type = truncinput.getType(); - if (outter_type.isa()) { - outter_type = op.getType().cast().getElementType(); - intermediate_type = - innertruncop.getType().cast().getElementType(); - inner_type = truncinput.getType().cast().getElementType(); - } - if (outter_type.isF32() && - (intermediate_type.isF16() || intermediate_type.isBF16()) && - inner_type.isF32()) { - rewriter.replaceOp(op, {truncinput}); - } - } else - return failure(); - } else - return failure(); - return success(); - } -}; - -struct MathCanonicalizeF32Promotion final - : math::impl::MathCanonicalizeF32PromotionBase< - MathCanonicalizeF32Promotion> { - using MathCanonicalizeF32PromotionBase::MathCanonicalizeF32PromotionBase; - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - SmallVector ops; - getOperation()->walk([&](Operation *op) { - if (isa(op)) - ops.push_back(op); - }); - if (failed(applyOpPatternsAndFold(ops, patternSet))) - signalPassFailure(); - } -}; - -} // namespace diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp index 5998133b7eab8..da049602bc909 100644 --- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp +++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp @@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter( }); typeConverter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - return b.create(loc, target, input); + auto extFOp = b.create(loc, target, input); + extFOp->setAttr("eliminatable", b.getBoolAttr(true)); + return extFOp; }); } @@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( SmallVector results = (*legalized)->getResults(); for (auto [result, newType, origType] : llvm::zip_equal( results, (*legalized)->getResultTypes(), op->getResultTypes())) { - if (newType != origType) - result = rewriter.create(loc, origType, result); + if (newType != origType) { + auto truncFOp = rewriter.create(loc, origType, result); + truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true)); + result = truncFOp->getResults().front(); + } } rewriter.replaceOp(op, results); return success(); diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 90c0298fb5e46..131ee00fd7235 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_library(MLIRTransforms SymbolPrivatize.cpp TopologicalSort.cpp ViewOpGraph.cpp + EliminateExplicitRounding.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms @@ -38,4 +39,6 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRArithDialect + MLIRMathDialect ) diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp new file mode 100644 index 0000000000000..ae91a1ba0f24a --- /dev/null +++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp @@ -0,0 +1,85 @@ +//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements removing redundant extf/truncf pairs inserted from +// LegalizeToF32 and EmulateUnsupportedFloats. +// +//===----------------------------------------------------------------------===// +#include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +// #include "mlir/IR/Types.h" +// #include "mlir/IR/Builders.h" +// #include "mlir/IR/BuiltinOps.h" +// #include "mlir/IR/Region.h" +// #include "mlir/Pass/Pass.h" +// #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +struct EliminateExplicitRoundingRewritePattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp extfop, + PatternRewriter &rewriter) const final { + // check whether the extfop is eliminatable + auto extfAttr = extfop->getAttrOfType("eliminatable"); + if (!extfAttr || (extfAttr && !extfAttr.getValue())) return failure(); + + // check whether match `eliminatable truncf->extf` pair + auto truncfop = extfop.getOperand().getDefiningOp(); + if (!truncfop) return failure(); + auto truncfAttr = truncfop->getAttrOfType("eliminatable"); + if (!truncfAttr || (truncfAttr && !truncfAttr.getValue())) return failure(); + + // check whether the the rounding pair's input and output data type are the same + if (auto input = truncfop.getOperand()) { + auto inTy = input.getType(); + auto outTy = extfop.getType(); + if (inTy == outTy && getElementTypeOrSelf(inTy).isF32()) { + rewriter.replaceOp(extfop, {input}); + } + } + return success(); + } +}; + +struct EliminateExplicitRounding final + : impl::EliminateExplicitRoundingBase< + EliminateExplicitRounding> { + using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase; + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + SmallVector ops; + getOperation()->walk([&](Operation *op) { + if (isa(op)) + ops.push_back(op); + }); + if (failed(applyOpPatternsAndFold(ops, patternSet))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createEliminateExplicitRoundingPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir b/mlir/test/Transforms/eliminate-explicit-rounding.mlir similarity index 98% rename from mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir rename to mlir/test/Transforms/eliminate-explicit-rounding.mlir index 5ed189b0033b3..2f7765a8fe270 100644 --- a/mlir/test/Dialect/Math/canonicalize-f32-promotion.mlir +++ b/mlir/test/Transforms/eliminate-explicit-rounding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -math-canonicalize-f32-promotion | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s // CHECK-LABEL: @sequences // CHECK-SAME: ([[ARG0:%.+]]: bf16) From cbc176acfbe9b27661b6031609cc39f7392e52ab Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 15:24:32 +0800 Subject: [PATCH 06/25] fix wording --- mlir/include/mlir/Transforms/Passes.td | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 1539bda02ac60..a99eca2a993cb 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -570,14 +570,14 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { } def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { - let summary = "Eliminate redundant truncf/extf pairs"; + let summary = "Eliminate the intermidiate truncf/extf pairs"; let description = [{ - `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion for every op belonging to the - illegal op list. Once there are some consecutive illegal ops, these passes - will insert redundant `arith.truncf` and `arith.extf` pairs between the illegal - ops. + `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion + for every op belonging to the illegal op list. Once there are some consecutive + illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs + between the illegal ops. - This pass is to eliminate the redundant truncf/extf pairs to improve + This pass is to eliminate the intermidiate truncf/extf pairs to improve performance. However, this pass may introduce numerical difference as the `f32->bf16` rounding @@ -602,7 +602,7 @@ def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> return %5 : vector<32xbf16> } - // after canonicalize-f32-promotion + // after eliminate-explicit-rounding func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> %1 = math.absf %0 : vector<32xf32> From 2dcb687d5f95e88fe2380340ce63de225f21e175 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 15:38:09 +0800 Subject: [PATCH 07/25] fix test --- .../Transforms/EliminateExplicitRounding.cpp | 16 ++++------ .../Arith/emulate-unsupported-floats.mlir | 32 +++++++++---------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Transforms/EliminateExplicitRounding.cpp index ae91a1ba0f24a..4731b5a15f415 100644 --- a/mlir/lib/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Transforms/EliminateExplicitRounding.cpp @@ -16,12 +16,6 @@ #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" -// #include "mlir/IR/Types.h" -// #include "mlir/IR/Builders.h" -// #include "mlir/IR/BuiltinOps.h" -// #include "mlir/IR/Region.h" -// #include "mlir/Pass/Pass.h" -// #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -48,12 +42,16 @@ struct EliminateExplicitRoundingRewritePattern final auto truncfAttr = truncfop->getAttrOfType("eliminatable"); if (!truncfAttr || (truncfAttr && !truncfAttr.getValue())) return failure(); - // check whether the the rounding pair's input and output data type are the same + // check whether the the rounding pair's input and output data type are the + // same Currently only consider to eliminate rounding pairs for (bf16 / f16 + // <-> f32) if (auto input = truncfop.getOperand()) { auto inTy = input.getType(); auto outTy = extfop.getType(); - if (inTy == outTy && getElementTypeOrSelf(inTy).isF32()) { - rewriter.replaceOp(extfop, {input}); + auto shortTy = getElementTypeOrSelf(truncfop.getType()); + if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && + (shortTy.isF16() || shortTy.isBF16())) { + rewriter.replaceOp(extfop, {input}); } } return success(); diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index a69ef131d8d47..76952297a5452 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 { // CHECK-LABEL: @basic_expansion // CHECK-SAME: [[X:%.+]]: bf16 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 -// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32 +// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {eliminatable = true} : bf16 to f32 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32 -// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16 +// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {eliminatable = true} : f32 to bf16 // CHECK: return [[Y]] %c = arith.constant 1.0 : bf16 %y = arith.addf %x, %c : bf16 @@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 { func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { // CHECK-LABEL: @chained // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 -// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32 -// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32 +// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] {eliminatable = true} : bf16 to f32 +// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] {eliminatable = true} : bf16 to f32 // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 -// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16 -// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32 +// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] {eliminatable = true} : f32 to bf16 +// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] {eliminatable = true} : bf16 to f32 // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] -// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16 -// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32 +// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] {eliminatable = true} : f32 to bf16 +// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] {eliminatable = true} : bf16 to f32 // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 // CHECK: return [[RES]] %p = arith.addf %x, %y : bf16 @@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { // CHECK-LABEL: @memops // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> -// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32 +// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] {eliminatable = true} : f8E4M3FNUZ to f32 // CHECK: memref.store [[V]] // CHECK: [[W:%.+]] = memref.load -// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32 +// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] {eliminatable = true} : f8E4M3FNUZ to f32 // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 -// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ +// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] {eliminatable = true} : f32 to f8E4M3FNUZ // CHECK: memref.store [[X]] %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -63,10 +63,10 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { // CHECK-LABEL: @vectors // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> -// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32> // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> -// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ> -// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] {eliminatable = true} : vector<4xf32> to vector<4xf8E4M3FNUZ> +// CHECK: [[RET:%.+]] = arith.extf [[B]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32> // CHECK: return [[RET]] %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> From 336e0eba4a671a1a51e7738f6031764023f73d55 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 16:53:27 +0800 Subject: [PATCH 08/25] move to arith dialect and add optional filter func --- .../mlir/Dialect/Arith/Transforms/Passes.td | 46 ++++++++++++++ mlir/include/mlir/Transforms/Passes.h | 5 -- mlir/include/mlir/Transforms/Passes.td | 48 --------------- .../Dialect/Arith/Transforms/CMakeLists.txt | 1 + .../Transforms/EliminateExplicitRounding.cpp | 61 ++++++++++--------- .../Transforms/EmulateUnsupportedFloats.cpp | 11 +--- .../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 +--- mlir/lib/Transforms/CMakeLists.txt | 3 - .../Arith/emulate-unsupported-floats.mlir | 32 +++++----- 9 files changed, 101 insertions(+), 117 deletions(-) rename mlir/lib/{ => Dialect/Arith}/Transforms/EliminateExplicitRounding.cpp (55%) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 4096e309199e9..d0d614078619e 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -117,4 +117,50 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> { ]; } + def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { + let summary = "Eliminate the intermidiate truncf/extf pairs"; + let description = [{ + `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion + for every op belonging to the illegal op list. Once there are some consecutive + illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs + between the illegal ops. + + This pass is to eliminate the intermidiate truncf/extf pairs to improve + performance. + + However, this pass may introduce numerical difference as the `f32->bf16` rounding + is eliminated. + + Example: + + ```mlir + // the initial func + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = math.absf %arg0 : vector<32xbf16> + %1 = math.sin %0 : vector<32xbf16> + return %1 : vector<32xbf16> + } + // after legalize-to-f32 + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> + %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> + %4 = math.sin %3 : vector<32xf32> + %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> + return %5 : vector<32xbf16> + } + // after eliminate-explicit-rounding + func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { + %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> + %1 = math.absf %0 : vector<32xf32> + %2 = math.sin %1 : vector<32xf32> + %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> + return %3 : vector<32xbf16> + } + ``` + + }]; +} + #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index c618fff9a8040..58bd61b2ae8b8 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -44,7 +44,6 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS -#define GEN_PASS_DECL_ELIMINATEEXPLICITROUNDING #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default @@ -138,10 +137,6 @@ std::unique_ptr createCompositeFixedPointPass( std::string name, llvm::function_ref populateFunc, int maxIterations = 10); -/// Create eliminate-explicit-rounding pass, which eliminates the redundant -/// truncf/extf pairs to improve performance. -std::unique_ptr createEliminateExplicitRoundingPass(); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a99eca2a993cb..1b40a87c63f27 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -569,52 +569,4 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { ]; } -def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { - let summary = "Eliminate the intermidiate truncf/extf pairs"; - let description = [{ - `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion - for every op belonging to the illegal op list. Once there are some consecutive - illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs - between the illegal ops. - - This pass is to eliminate the intermidiate truncf/extf pairs to improve - performance. - - However, this pass may introduce numerical difference as the `f32->bf16` rounding - is eliminated. - - Example: - - ```mlir - // the initial func - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = math.absf %arg0 : vector<32xbf16> - %1 = math.sin %0 : vector<32xbf16> - return %1 : vector<32xbf16> - } - // after legalize-to-f32 - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> - %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> - %4 = math.sin %3 : vector<32xf32> - %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> - return %5 : vector<32xbf16> - } - // after eliminate-explicit-rounding - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = math.sin %1 : vector<32xf32> - %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> - return %3 : vector<32xbf16> - } - ``` - - }]; - let constructor = "mlir::createEliminateExplicitRoundingPass()"; -} - - #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 12659eaba1fa5..a12da70bae9af 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRArithTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp BufferViewFlowOpInterfaceImpl.cpp + EliminateExplicitRounding.cpp EmulateUnsupportedFloats.cpp EmulateWideInt.cpp EmulateNarrowType.cpp diff --git a/mlir/lib/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp similarity index 55% rename from mlir/lib/Transforms/EliminateExplicitRounding.cpp rename to mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 4731b5a15f415..922531e976252 100644 --- a/mlir/lib/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -1,4 +1,5 @@ -//===- EliminateExplicitRounding.cpp - Remove redundant extf/truncf pairs -===// +//===- EliminateExplicitRounding.cpp - Remove intermediate extf/truncf pairs +//-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,21 +7,22 @@ // //===----------------------------------------------------------------------===// // -// This file implements removing redundant extf/truncf pairs inserted from -// LegalizeToF32 and EmulateUnsupportedFloats. +// This file implements removing intermediate extf/truncf pairs inserted from +// type conversion. // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" + #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { +namespace arith { #define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING -#include "mlir/Transforms/Passes.h.inc" +#include "mlir/Dialect/Arith/Transforms/Passes.h.inc" +} // namespace arith } // namespace mlir using namespace mlir; @@ -30,37 +32,42 @@ namespace { struct EliminateExplicitRoundingRewritePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; + using FilterFunction = std::function; + + EliminateExplicitRoundingRewritePattern(MLIRContext *context, + FilterFunction filterFunc = nullptr) + : OpRewritePattern(context), filterFunc(filterFunc) {} + LogicalResult matchAndRewrite(arith::ExtFOp extfop, PatternRewriter &rewriter) const final { - // check whether the extfop is eliminatable - auto extfAttr = extfop->getAttrOfType("eliminatable"); - if (!extfAttr || (extfAttr && !extfAttr.getValue())) return failure(); - - // check whether match `eliminatable truncf->extf` pair + if (filterFunc && filterFunc(extfop)) + return failure(); + // check whether match `truncf->extf` pair auto truncfop = extfop.getOperand().getDefiningOp(); - if (!truncfop) return failure(); - auto truncfAttr = truncfop->getAttrOfType("eliminatable"); - if (!truncfAttr || (truncfAttr && !truncfAttr.getValue())) return failure(); + if (!truncfop) + return failure(); // check whether the the rounding pair's input and output data type are the - // same Currently only consider to eliminate rounding pairs for (bf16 / f16 + // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 // <-> f32) if (auto input = truncfop.getOperand()) { - auto inTy = input.getType(); - auto outTy = extfop.getType(); - auto shortTy = getElementTypeOrSelf(truncfop.getType()); - if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && - (shortTy.isF16() || shortTy.isBF16())) { - rewriter.replaceOp(extfop, {input}); - } + auto inTy = input.getType(); + auto outTy = extfop.getType(); + auto shortTy = getElementTypeOrSelf(truncfop.getType()); + if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && + (shortTy.isF16() || shortTy.isBF16())) { + rewriter.replaceOp(extfop, {input}); + } } return success(); } + +private: + FilterFunction filterFunc; }; struct EliminateExplicitRounding final - : impl::EliminateExplicitRoundingBase< - EliminateExplicitRounding> { + : arith::impl::EliminateExplicitRoundingBase { using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase; void runOnOperation() override { RewritePatternSet patterns(&getContext()); @@ -77,7 +84,3 @@ struct EliminateExplicitRounding final }; } // namespace - -std::unique_ptr mlir::createEliminateExplicitRoundingPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 9cbb3884659ee..4a50da3513f99 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, SmallVector newResults(expandedOp->getResults()); for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { - if (oldType != newType) { - auto truncFOp = rewriter.create(loc, oldType, res); - truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true)); - res = truncFOp->getResults().front(); - } + if (oldType != newType) + res = rewriter.create(loc, oldType, res); } rewriter.replaceOp(op, newResults); } @@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); - extFOp->setAttr("eliminatable", b.getBoolAttr(true)); - return extFOp; + return b.create(loc, target, input); }); } diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp index da049602bc909..5998133b7eab8 100644 --- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp +++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp @@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter( }); typeConverter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); - extFOp->setAttr("eliminatable", b.getBoolAttr(true)); - return extFOp; + return b.create(loc, target, input); }); } @@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( SmallVector results = (*legalized)->getResults(); for (auto [result, newType, origType] : llvm::zip_equal( results, (*legalized)->getResultTypes(), op->getResultTypes())) { - if (newType != origType) { - auto truncFOp = rewriter.create(loc, origType, result); - truncFOp->setAttr("eliminatable", rewriter.getBoolAttr(true)); - result = truncFOp->getResults().front(); - } + if (newType != origType) + result = rewriter.create(loc, origType, result); } rewriter.replaceOp(op, results); return success(); diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 131ee00fd7235..90c0298fb5e46 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -20,7 +20,6 @@ add_mlir_library(MLIRTransforms SymbolPrivatize.cpp TopologicalSort.cpp ViewOpGraph.cpp - EliminateExplicitRounding.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms @@ -39,6 +38,4 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils - MLIRArithDialect - MLIRMathDialect ) diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index 76952297a5452..a69ef131d8d47 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 { // CHECK-LABEL: @basic_expansion // CHECK-SAME: [[X:%.+]]: bf16 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32 -// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {eliminatable = true} : bf16 to f32 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 +// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32 -// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {eliminatable = true} : f32 to bf16 +// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16 // CHECK: return [[Y]] %c = arith.constant 1.0 : bf16 %y = arith.addf %x, %c : bf16 @@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 { func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { // CHECK-LABEL: @chained // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {eliminatable = true} : bf16 to f32 -// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] {eliminatable = true} : bf16 to f32 -// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] {eliminatable = true} : bf16 to f32 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 +// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32 +// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32 // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 -// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] {eliminatable = true} : f32 to bf16 -// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] {eliminatable = true} : bf16 to f32 +// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16 +// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32 // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] -// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] {eliminatable = true} : f32 to bf16 -// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] {eliminatable = true} : bf16 to f32 +// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16 +// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32 // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 // CHECK: return [[RES]] %p = arith.addf %x, %y : bf16 @@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { // CHECK-LABEL: @memops // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> -// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] {eliminatable = true} : f8E4M3FNUZ to f32 +// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32 // CHECK: memref.store [[V]] // CHECK: [[W:%.+]] = memref.load -// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] {eliminatable = true} : f8E4M3FNUZ to f32 +// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32 // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 -// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] {eliminatable = true} : f32 to f8E4M3FNUZ +// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ // CHECK: memref.store [[X]] %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -63,10 +63,10 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { // CHECK-LABEL: @vectors // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> -// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32> // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> -// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] {eliminatable = true} : vector<4xf32> to vector<4xf8E4M3FNUZ> -// CHECK: [[RET:%.+]] = arith.extf [[B]] {eliminatable = true} : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ> +// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> // CHECK: return [[RET]] %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> From 92e809b0c2c648a8cb796111b2bc5dbe979f0d1a Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 19:05:36 +0800 Subject: [PATCH 09/25] fix comment --- .../Transforms/EliminateExplicitRounding.cpp | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 922531e976252..908d358857013 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -38,28 +38,32 @@ struct EliminateExplicitRoundingRewritePattern final FilterFunction filterFunc = nullptr) : OpRewritePattern(context), filterFunc(filterFunc) {} - LogicalResult matchAndRewrite(arith::ExtFOp extfop, + LogicalResult matchAndRewrite(arith::ExtFOp extFOp, PatternRewriter &rewriter) const final { - if (filterFunc && filterFunc(extfop)) + // check whether match `truncF->extF` pair + auto truncFOp = extFOp.getOperand().getDefiningOp(); + if (!truncFOp) return failure(); - // check whether match `truncf->extf` pair - auto truncfop = extfop.getOperand().getDefiningOp(); - if (!truncfop) + + // check whether need to filter out + if (filterFunc && filterFunc(extFOp)) return failure(); - // check whether the the rounding pair's input and output data type are the + // check whether the rounding pair's input and output data type are the // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 // <-> f32) - if (auto input = truncfop.getOperand()) { + if (auto input = truncFOp.getOperand()) { auto inTy = input.getType(); - auto outTy = extfop.getType(); - auto shortTy = getElementTypeOrSelf(truncfop.getType()); + auto outTy = extFOp.getType(); + auto shortTy = getElementTypeOrSelf(truncFOp.getType()); if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && (shortTy.isF16() || shortTy.isBF16())) { - rewriter.replaceOp(extfop, {input}); + rewriter.replaceOp(extFOp, {input}); + return success(); } } - return success(); + + return failure(); } private: From 5583436e89a852a5141b403ea1f1ee19dbc88d8e Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 27 May 2024 20:12:41 +0800 Subject: [PATCH 10/25] remove unnecessary if --- .../Transforms/EliminateExplicitRounding.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 908d358857013..bf510f9671c01 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -52,15 +52,14 @@ struct EliminateExplicitRoundingRewritePattern final // check whether the rounding pair's input and output data type are the // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 // <-> f32) - if (auto input = truncFOp.getOperand()) { - auto inTy = input.getType(); - auto outTy = extFOp.getType(); - auto shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && - (shortTy.isF16() || shortTy.isBF16())) { - rewriter.replaceOp(extFOp, {input}); - return success(); - } + auto input = truncFOp.getOperand(); + auto inTy = input.getType(); + auto outTy = extFOp.getType(); + auto shortTy = getElementTypeOrSelf(truncFOp.getType()); + if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && + (shortTy.isF16() || shortTy.isBF16())) { + rewriter.replaceOp(extFOp, {input}); + return success(); } return failure(); From c9e0e8bd3fd63a849865ca8517cfb63e4c8f9a81 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Wed, 29 May 2024 13:21:30 +0800 Subject: [PATCH 11/25] add test case --- .../Arith}/eliminate-explicit-rounding.mlir | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) rename mlir/test/{Transforms => Dialect/Arith}/eliminate-explicit-rounding.mlir (74%) diff --git a/mlir/test/Transforms/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir similarity index 74% rename from mlir/test/Transforms/eliminate-explicit-rounding.mlir rename to mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir index 2f7765a8fe270..70f9570235b56 100644 --- a/mlir/test/Transforms/eliminate-explicit-rounding.mlir +++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir @@ -71,3 +71,22 @@ func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xb %3 = arith.addf %1, %2 : vector<32x32x32xbf16> return %3 : vector<32x32x32xbf16> } + +// CHECK-LABEL: @bf16_fma +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] +// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]] +// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]] +// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> +func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = math.absf %arg0 : vector<32x32x32xbf16> + %1 = math.sin %0 : vector<32x32x32xbf16> + %2 = math.fma %1, %arg1, %arg2 : vector<32x32x32xbf16> + %3 = arith.addf %2, %1 : vector<32x32x32xbf16> + return %3 : vector<32x32x32xbf16> +} From 923b4513c13c44cb739a2335f665d8b7fa3ec902 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Wed, 29 May 2024 21:13:10 +0800 Subject: [PATCH 12/25] fix comments --- .../Transforms/EliminateExplicitRounding.cpp | 13 ++-- .../Arith/eliminate-explicit-rounding.mlir | 67 +++++++++++++------ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index bf510f9671c01..5b2d243eac29d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -40,18 +40,18 @@ struct EliminateExplicitRoundingRewritePattern final LogicalResult matchAndRewrite(arith::ExtFOp extFOp, PatternRewriter &rewriter) const final { - // check whether match `truncF->extF` pair + // Check whether match `truncF->extF` pair. auto truncFOp = extFOp.getOperand().getDefiningOp(); if (!truncFOp) return failure(); - // check whether need to filter out + // Check whether need to filter out. if (filterFunc && filterFunc(extFOp)) return failure(); - // check whether the rounding pair's input and output data type are the + // Check whether the rounding pair's input and output data type are the // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 - // <-> f32) + // <-> f32). auto input = truncFOp.getOperand(); auto inTy = input.getType(); auto outTy = extFOp.getType(); @@ -77,10 +77,7 @@ struct EliminateExplicitRounding final patterns.insert(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); SmallVector ops; - getOperation()->walk([&](Operation *op) { - if (isa(op)) - ops.push_back(op); - }); + getOperation()->walk([&](arith::ExtFOp op) { ops.push_back(op); }); if (failed(applyOpPatternsAndFold(ops, patternSet))) signalPassFailure(); } diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir index 70f9570235b56..55cf4fdadd922 100644 --- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir +++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 --arith-emulate-unsupported-floats="source-types=bf16 target-type=f32" -eliminate-explicit-rounding | FileCheck %s +// RUN: mlir-opt %s --split-input-file --eliminate-explicit-rounding | FileCheck %s // CHECK-LABEL: @sequences // CHECK-SAME: ([[ARG0:%.+]]: bf16) @@ -8,9 +8,13 @@ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] // CHECK: return [[TRUNCF]] : bf16 func.func @sequences(%arg0: bf16) -> bf16 { - %0 = math.absf %arg0 : bf16 - %1 = math.sin %0 : bf16 - return %1 : bf16 + %0 = arith.extf %arg0 : bf16 to f32 + %1 = math.absf %0 : f32 + %2 = arith.truncf %1 : f32 to bf16 + %3 = arith.extf %2 : bf16 to f32 + %4 = math.sin %3 : f32 + %5 = arith.truncf %4 : f32 to bf16 + return %5 : bf16 } // CHECK-LABEL: @eliminatecastoncastf16 @@ -37,9 +41,13 @@ func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] // CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = math.absf %arg0 : vector<32x32x32xbf16> - %1 = math.sin %0 : vector<32x32x32xbf16> - return %1 : vector<32x32x32xbf16> + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %5 : vector<32x32x32xbf16> } // CHECK-LABEL: @f16_sin_vector @@ -50,9 +58,13 @@ func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16 // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] // CHECK: return [[TRUNCF]] : vector<32x32x32xf16> func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { - %0 = math.absf %arg0 : vector<32x32x32xf16> - %1 = math.sin %0 : vector<32x32x32xf16> - return %1 : vector<32x32x32xf16> + %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16> + %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16> + return %5 : vector<32x32x32xf16> } // CHECK-LABEL: @bf16_branch_vector @@ -65,11 +77,19 @@ func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { // CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] // CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = math.absf %arg0 : vector<32x32x32xbf16> - %1 = math.sin %0 : vector<32x32x32xbf16> - %2 = math.cos %0 : vector<32x32x32xbf16> - %3 = arith.addf %1, %2 : vector<32x32x32xbf16> - return %3 : vector<32x32x32xbf16> + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.cos %3 : vector<32x32x32xf32> + %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %10 = arith.addf %6, %9 : vector<32x32x32xf32> + %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %11 : vector<32x32x32xbf16> } // CHECK-LABEL: @bf16_fma @@ -84,9 +104,16 @@ func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xb // CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] // CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = math.absf %arg0 : vector<32x32x32xbf16> - %1 = math.sin %0 : vector<32x32x32xbf16> - %2 = math.fma %1, %arg1, %arg2 : vector<32x32x32xbf16> - %3 = arith.addf %2, %1 : vector<32x32x32xbf16> - return %3 : vector<32x32x32xbf16> + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16> + %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %9 = arith.addf %8, %6 : vector<32x32x32xf32> + %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %10 : vector<32x32x32xbf16> } From a2c2e012f35bb49be358bd665691b4ddac1bc183 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Thu, 30 May 2024 10:10:36 +0800 Subject: [PATCH 13/25] fix typo --- mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index d0d614078619e..6fc89cc91b740 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -125,7 +125,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> { illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs between the illegal ops. - This pass is to eliminate the intermidiate truncf/extf pairs to improve + This pass is to eliminate the intermediate truncf/extf pairs to improve performance. However, this pass may introduce numerical difference as the `f32->bf16` rounding From 345bd9cb8b121e88d2c49f2f22b4fa7da2f312ce Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Thu, 30 May 2024 10:40:59 +0800 Subject: [PATCH 14/25] fix comment --- .../Transforms/EliminateExplicitRounding.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 5b2d243eac29d..8a5f10a6cbbb0 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -46,18 +46,20 @@ struct EliminateExplicitRoundingRewritePattern final return failure(); // Check whether need to filter out. - if (filterFunc && filterFunc(extFOp)) + if (filterFunc && filterFunc(extFOp)) { + extFOp.emitError("Operation filtered out by filterFunc"); return failure(); + } // Check whether the rounding pair's input and output data type are the // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 // <-> f32). - auto input = truncFOp.getOperand(); - auto inTy = input.getType(); - auto outTy = extFOp.getType(); - auto shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (inTy == outTy && getElementTypeOrSelf(inTy).isF32() && - (shortTy.isF16() || shortTy.isBF16())) { + Value input = truncFOp.getOperand(); + Type inTy = getElementTypeOrSelf(input.getType()); + Type outTy = getElementTypeOrSelf(extFOp.getType()); + Type shortTy = getElementTypeOrSelf(truncFOp.getType()); + if (isa(inTy) && isa(outTy) && + (isa(shortTy) || isa(shortTy))) { rewriter.replaceOp(extFOp, {input}); return success(); } From e6fd571eb8f7dde4eee9127b681010a66282f27a Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Thu, 30 May 2024 11:55:43 +0800 Subject: [PATCH 15/25] fix --- .../mlir/Dialect/Arith/Transforms/Passes.td | 2 +- .../Transforms/EliminateExplicitRounding.cpp | 18 ++++++++---------- .../Arith/eliminate-explicit-rounding.mlir | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 6fc89cc91b740..7afec9f752cfa 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -117,7 +117,7 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> { ]; } - def EliminateExplicitRounding : Pass<"eliminate-explicit-rounding"> { + def ArithEliminateExplicitRounding : Pass<"arith-eliminate-explicit-rounding"> { let summary = "Eliminate the intermidiate truncf/extf pairs"; let description = [{ `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 8a5f10a6cbbb0..5ab540fa0e9fb 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -20,7 +20,7 @@ namespace mlir { namespace arith { -#define GEN_PASS_DEF_ELIMINATEEXPLICITROUNDING +#define GEN_PASS_DEF_ARITHELIMINATEEXPLICITROUNDING #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace arith } // namespace mlir @@ -34,10 +34,6 @@ struct EliminateExplicitRoundingRewritePattern final using OpRewritePattern::OpRewritePattern; using FilterFunction = std::function; - EliminateExplicitRoundingRewritePattern(MLIRContext *context, - FilterFunction filterFunc = nullptr) - : OpRewritePattern(context), filterFunc(filterFunc) {} - LogicalResult matchAndRewrite(arith::ExtFOp extFOp, PatternRewriter &rewriter) const final { // Check whether match `truncF->extF` pair. @@ -47,8 +43,9 @@ struct EliminateExplicitRoundingRewritePattern final // Check whether need to filter out. if (filterFunc && filterFunc(extFOp)) { - extFOp.emitError("Operation filtered out by filterFunc"); - return failure(); + return rewriter.notifyMatchFailure(extFOp, [](Diagnostic &diag) { + diag << "Operation filtered out by filterFunc"; + }); } // Check whether the rounding pair's input and output data type are the @@ -59,7 +56,7 @@ struct EliminateExplicitRoundingRewritePattern final Type outTy = getElementTypeOrSelf(extFOp.getType()); Type shortTy = getElementTypeOrSelf(truncFOp.getType()); if (isa(inTy) && isa(outTy) && - (isa(shortTy) || isa(shortTy))) { + (isa(shortTy))) { rewriter.replaceOp(extFOp, {input}); return success(); } @@ -72,8 +69,9 @@ struct EliminateExplicitRoundingRewritePattern final }; struct EliminateExplicitRounding final - : arith::impl::EliminateExplicitRoundingBase { - using EliminateExplicitRoundingBase::EliminateExplicitRoundingBase; + : arith::impl::ArithEliminateExplicitRoundingBase< + EliminateExplicitRounding> { + using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase; void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir index 55cf4fdadd922..f2ba276a4f7bb 100644 --- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir +++ b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --split-input-file --eliminate-explicit-rounding | FileCheck %s +// RUN: mlir-opt %s --split-input-file --arith-eliminate-explicit-rounding | FileCheck %s // CHECK-LABEL: @sequences // CHECK-SAME: ([[ARG0:%.+]]: bf16) From ab80bafa7d9936e59d7050beeca3679555da7d4d Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Thu, 30 May 2024 13:40:49 +0800 Subject: [PATCH 16/25] remove filter func --- .../Arith/Transforms/EliminateExplicitRounding.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 5ab540fa0e9fb..6b2bdd1404bd6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -32,7 +32,6 @@ namespace { struct EliminateExplicitRoundingRewritePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - using FilterFunction = std::function; LogicalResult matchAndRewrite(arith::ExtFOp extFOp, PatternRewriter &rewriter) const final { @@ -41,13 +40,6 @@ struct EliminateExplicitRoundingRewritePattern final if (!truncFOp) return failure(); - // Check whether need to filter out. - if (filterFunc && filterFunc(extFOp)) { - return rewriter.notifyMatchFailure(extFOp, [](Diagnostic &diag) { - diag << "Operation filtered out by filterFunc"; - }); - } - // Check whether the rounding pair's input and output data type are the // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 // <-> f32). @@ -63,9 +55,6 @@ struct EliminateExplicitRoundingRewritePattern final return failure(); } - -private: - FilterFunction filterFunc; }; struct EliminateExplicitRounding final From 961f6f8798b4bddde5ef83547b941ca50e95b8b1 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 3 Jun 2024 10:26:18 +0800 Subject: [PATCH 17/25] do not use pattern --- .../Transforms/EliminateExplicitRounding.cpp | 57 +++++++------------ 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index 6b2bdd1404bd6..b341dfd40ed4f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -29,46 +29,33 @@ using namespace mlir; namespace { -struct EliminateExplicitRoundingRewritePattern final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(arith::ExtFOp extFOp, - PatternRewriter &rewriter) const final { - // Check whether match `truncF->extF` pair. - auto truncFOp = extFOp.getOperand().getDefiningOp(); - if (!truncFOp) - return failure(); - - // Check whether the rounding pair's input and output data type are the - // same. Currently only consider to eliminate rounding pairs for (bf16 / f16 - // <-> f32). - Value input = truncFOp.getOperand(); - Type inTy = getElementTypeOrSelf(input.getType()); - Type outTy = getElementTypeOrSelf(extFOp.getType()); - Type shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (isa(inTy) && isa(outTy) && - (isa(shortTy))) { - rewriter.replaceOp(extFOp, {input}); - return success(); - } - - return failure(); - } -}; - struct EliminateExplicitRounding final : arith::impl::ArithEliminateExplicitRoundingBase< EliminateExplicitRounding> { using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase; void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - SmallVector ops; - getOperation()->walk([&](arith::ExtFOp op) { ops.push_back(op); }); - if (failed(applyOpPatternsAndFold(ops, patternSet))) - signalPassFailure(); + getOperation()->walk([&](arith::ExtFOp extFOp) { + // Check whether match `truncF->extF` pair. + auto truncFOp = extFOp.getOperand().getDefiningOp(); + if (truncFOp) { + // Check whether the rounding pair's input and output data type are the + // same. Currently only consider to eliminate rounding pairs for (bf16 / + // f16 + // <-> f32). + Value input = truncFOp.getOperand(); + Type inTy = getElementTypeOrSelf(input.getType()); + Type outTy = getElementTypeOrSelf(extFOp.getType()); + Type shortTy = getElementTypeOrSelf(truncFOp.getType()); + if (isa(inTy) && isa(outTy) && + (isa(shortTy))) { + for (auto &use : + llvm::make_early_inc_range(extFOp.getResult().getUses())) { + use.set(input); + } + extFOp.erase(); + } + } + }); } }; From 8e6de0d46b9393fb01ea82be959ea7cdb80bcfa1 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 3 Jun 2024 14:02:26 +0800 Subject: [PATCH 18/25] fix comment --- .../Transforms/EliminateExplicitRounding.cpp | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp index b341dfd40ed4f..7df77629127fa 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp @@ -37,23 +37,21 @@ struct EliminateExplicitRounding final getOperation()->walk([&](arith::ExtFOp extFOp) { // Check whether match `truncF->extF` pair. auto truncFOp = extFOp.getOperand().getDefiningOp(); - if (truncFOp) { - // Check whether the rounding pair's input and output data type are the - // same. Currently only consider to eliminate rounding pairs for (bf16 / - // f16 - // <-> f32). - Value input = truncFOp.getOperand(); - Type inTy = getElementTypeOrSelf(input.getType()); - Type outTy = getElementTypeOrSelf(extFOp.getType()); - Type shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (isa(inTy) && isa(outTy) && - (isa(shortTy))) { - for (auto &use : - llvm::make_early_inc_range(extFOp.getResult().getUses())) { - use.set(input); - } - extFOp.erase(); - } + if (!truncFOp) + return; + // Check whether the rounding pair's input and output data type are the + // same. Currently only consider to eliminate rounding pairs for (bf16 / + // f16 <-> f32). + Value input = truncFOp.getOperand(); + Type inTy = getElementTypeOrSelf(input.getType()); + Type outTy = getElementTypeOrSelf(extFOp.getType()); + Type shortTy = getElementTypeOrSelf(truncFOp.getType()); + if (isa(inTy) && isa(outTy) && + (isa(shortTy))) { + extFOp.replaceAllUsesWith(input); + extFOp.erase(); + if (truncFOp.getResult().getUses().empty()) + truncFOp.erase(); } }); } From 66fef95ac3d07eebfae4b126bb6dca4359329e16 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Fri, 7 Jun 2024 00:48:47 +0800 Subject: [PATCH 19/25] add fastmath flag attrs and use canonicalizer --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 13 +- .../mlir/Dialect/Arith/Transforms/Passes.td | 46 ------- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 33 +++++ .../Dialect/Arith/Transforms/CMakeLists.txt | 1 - .../Transforms/EliminateExplicitRounding.cpp | 60 --------- .../Arith/eliminate-explicit-rounding.mlir | 119 ------------------ mlir/test/Transforms/canonicalize.mlir | 118 +++++++++++++++++ 7 files changed, 163 insertions(+), 227 deletions(-) delete mode 100644 mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp delete mode 100644 mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index ead52332e8eec..6fff83dc3df7f 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1195,6 +1195,14 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> { }]; let hasVerifier = 1; let hasFolder = 1; + let hasCanonicalizer = 1; + + let arguments = (ins FloatLike:$in, DefaultValuedAttr< + Arith_FastMathAttr, + "::mlir::arith::FastMathFlags::contract">:$fastmath); + let results = (outs FloatLike:$out); + + let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)"; } //===----------------------------------------------------------------------===// @@ -1235,7 +1243,10 @@ def Arith_TruncFOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, Arguments<(ins FloatLike:$in, - OptionalAttr:$roundingmode)>, + OptionalAttr:$roundingmode, + DefaultValuedAttr< + Arith_FastMathAttr, + "::mlir::arith::FastMathFlags::contract">:$fastmath)>, Results<(outs FloatLike:$out)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 7afec9f752cfa..4096e309199e9 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -117,50 +117,4 @@ def ArithIntNarrowing : Pass<"arith-int-narrowing"> { ]; } - def ArithEliminateExplicitRounding : Pass<"arith-eliminate-explicit-rounding"> { - let summary = "Eliminate the intermidiate truncf/extf pairs"; - let description = [{ - `legalize-to-f32` and `arith-emulate-unsupported-floats` pass does f32 promotion - for every op belonging to the illegal op list. Once there are some consecutive - illegal ops, these passes will insert `arith.truncf` and `arith.extf` pairs - between the illegal ops. - - This pass is to eliminate the intermediate truncf/extf pairs to improve - performance. - - However, this pass may introduce numerical difference as the `f32->bf16` rounding - is eliminated. - - Example: - - ```mlir - // the initial func - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = math.absf %arg0 : vector<32xbf16> - %1 = math.sin %0 : vector<32xbf16> - return %1 : vector<32xbf16> - } - // after legalize-to-f32 - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = arith.truncf %1 : vector<32xf32> to vector<32xbf16> - %3 = arith.extf %2 : vector<32xbf16> to vector<32xf32> - %4 = math.sin %3 : vector<32xf32> - %5 = arith.truncf %4 : vector<32xf32> to vector<32xbf16> - return %5 : vector<32xbf16> - } - // after eliminate-explicit-rounding - func.func @bf16_sin_vector(%arg0: vector<32xbf16>) -> vector<32xbf16> { - %0 = arith.extf %arg0 : vector<32xbf16> to vector<32xf32> - %1 = math.absf %0 : vector<32xf32> - %2 = math.sin %1 : vector<32xf32> - %3 = arith.truncf %2 : vector<32xf32> to vector<32xbf16> - return %3 : vector<32xbf16> - } - ``` - - }]; -} - #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index a0b50251c6b67..1a135668a23e6 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1410,6 +1410,39 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } +struct SimplifyExtFTruncFOpPair : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtFOp extFOp, + PatternRewriter &rewriter) const override { + if (auto truncFOp = extFOp.getOperand().getDefiningOp()) { + Value input = truncFOp.getOperand(); + Type inTy = getElementTypeOrSelf(input.getType()); + Type outTy = getElementTypeOrSelf(extFOp.getType()); + Type shortTy = getElementTypeOrSelf(truncFOp.getType()); + if (isa(inTy) && isa(outTy) && + (isa(shortTy))) { + arith::FastMathFlags truncFMF = truncFOp.getFastmathAttr().getValue(); + bool isTruncContract = + bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); + arith::FastMathFlags extFMF = extFOp.getFastmathAttr().getValue(); + bool isExtContract = + bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); + if (isTruncContract && isExtContract) { + rewriter.replaceOp(extFOp, truncFOp.getOperand()); + return success(); + } + } + } + return failure(); + } +}; + +void arith::ExtFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index a12da70bae9af..12659eaba1fa5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRArithTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp BufferViewFlowOpInterfaceImpl.cpp - EliminateExplicitRounding.cpp EmulateUnsupportedFloats.cpp EmulateWideInt.cpp EmulateNarrowType.cpp diff --git a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp b/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp deleted file mode 100644 index 7df77629127fa..0000000000000 --- a/mlir/lib/Dialect/Arith/Transforms/EliminateExplicitRounding.cpp +++ /dev/null @@ -1,60 +0,0 @@ -//===- EliminateExplicitRounding.cpp - Remove intermediate extf/truncf pairs -//-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements removing intermediate extf/truncf pairs inserted from -// type conversion. -// -//===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/Transforms/Passes.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace arith { -#define GEN_PASS_DEF_ARITHELIMINATEEXPLICITROUNDING -#include "mlir/Dialect/Arith/Transforms/Passes.h.inc" -} // namespace arith -} // namespace mlir - -using namespace mlir; - -namespace { - -struct EliminateExplicitRounding final - : arith::impl::ArithEliminateExplicitRoundingBase< - EliminateExplicitRounding> { - using ArithEliminateExplicitRoundingBase::ArithEliminateExplicitRoundingBase; - void runOnOperation() override { - getOperation()->walk([&](arith::ExtFOp extFOp) { - // Check whether match `truncF->extF` pair. - auto truncFOp = extFOp.getOperand().getDefiningOp(); - if (!truncFOp) - return; - // Check whether the rounding pair's input and output data type are the - // same. Currently only consider to eliminate rounding pairs for (bf16 / - // f16 <-> f32). - Value input = truncFOp.getOperand(); - Type inTy = getElementTypeOrSelf(input.getType()); - Type outTy = getElementTypeOrSelf(extFOp.getType()); - Type shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (isa(inTy) && isa(outTy) && - (isa(shortTy))) { - extFOp.replaceAllUsesWith(input); - extFOp.erase(); - if (truncFOp.getResult().getUses().empty()) - truncFOp.erase(); - } - }); - } -}; - -} // namespace diff --git a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir b/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir deleted file mode 100644 index f2ba276a4f7bb..0000000000000 --- a/mlir/test/Dialect/Arith/eliminate-explicit-rounding.mlir +++ /dev/null @@ -1,119 +0,0 @@ -// RUN: mlir-opt %s --split-input-file --arith-eliminate-explicit-rounding | FileCheck %s - -// CHECK-LABEL: @sequences -// CHECK-SAME: ([[ARG0:%.+]]: bf16) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : bf16 -func.func @sequences(%arg0: bf16) -> bf16 { - %0 = arith.extf %arg0 : bf16 to f32 - %1 = math.absf %0 : f32 - %2 = arith.truncf %1 : f32 to bf16 - %3 = arith.extf %2 : bf16 to f32 - %4 = math.sin %3 : f32 - %5 = arith.truncf %4 : f32 to bf16 - return %5 : bf16 -} - -// CHECK-LABEL: @eliminatecastoncastf16 -// CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { - %0 = arith.truncf %arg0 : f32 to f16 - %1 = arith.extf %0 : f16 to f32 - return %1 : f32 -} - -// CHECK-LABEL: @eliminatecastoncastbf16 -// CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { - %0 = arith.truncf %arg0 : f32 to bf16 - %1 = arith.extf %0 : bf16 to f32 - return %1 : f32 -} - -// CHECK-LABEL: @bf16_sin_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> -func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %5 : vector<32x32x32xbf16> -} - -// CHECK-LABEL: @f16_sin_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> -func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { - %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16> - %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16> - return %5 : vector<32x32x32xf16> -} - -// CHECK-LABEL: @bf16_branch_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]] -// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> -func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %7 = math.cos %3 : vector<32x32x32xf32> - %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %10 = arith.addf %6, %9 : vector<32x32x32xf32> - %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %11 : vector<32x32x32xbf16> -} - -// CHECK-LABEL: @bf16_fma -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] -// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] -// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]] -// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]] -// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]] -// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] -// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> -func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16> - %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %9 = arith.addf %8, %6 : vector<32x32x32xf32> - %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %10 : vector<32x32x32xbf16> -} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index d2c2c12d32389..cd06cca33c926 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1243,3 +1243,121 @@ func.func @test_materialize_failure() -> i64 { %u = index.castu %const : index to i64 return %u: i64 } + +// CHECK-LABEL: @sequences +// CHECK-SAME: ([[ARG0:%.+]]: bf16) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : bf16 +func.func @sequences(%arg0: bf16) -> bf16 { + %0 = arith.extf %arg0 : bf16 to f32 + %1 = math.absf %0 : f32 + %2 = arith.truncf %1 : f32 to bf16 + %3 = arith.extf %2 : bf16 to f32 + %4 = math.sin %3 : f32 + %5 = arith.truncf %4 : f32 to bf16 + return %5 : bf16 +} + +// CHECK-LABEL: @eliminatecastoncastf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 : f32 to f16 + %1 = arith.extf %0 : f16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @eliminatecastoncastbf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 : f32 to bf16 + %1 = arith.extf %0 : bf16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @bf16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> +func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %5 : vector<32x32x32xbf16> +} + +// CHECK-LABEL: @f16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> +func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { + %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16> + %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16> + return %5 : vector<32x32x32xf16> +} + +// CHECK-LABEL: @bf16_branch_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]] +// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> +func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.cos %3 : vector<32x32x32xf32> + %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %10 = arith.addf %6, %9 : vector<32x32x32xf32> + %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %11 : vector<32x32x32xbf16> +} + +// CHECK-LABEL: @bf16_fma +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] +// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]] +// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]] +// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> +func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16> + %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32> + %9 = arith.addf %8, %6 : vector<32x32x32xf32> + %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %10 : vector<32x32x32xbf16> +} From 5d73e9dfe4cb82137e9302526f164dbe0fe9e32d Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Fri, 7 Jun 2024 09:57:46 +0800 Subject: [PATCH 20/25] remove fastmathflags on truncf and extf --- .../include/mlir/Conversion/LLVMCommon/VectorPattern.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 964281592cc65..a7be4ff0fba7a 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -10,6 +10,7 @@ #define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -98,6 +99,15 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); + + // Check if the operation is remove the fastMathAttr on ExtFOp / TruncFOp. + if (isa(op.getOperation()) || + isa(op.getOperation())) { + if (op->hasAttr("fastmath")) { + op->removeAttr("fastmath"); + } + } + // Determine attributes for the target op AttrConvert attrConvert(op); From d53358cac11555c32e2ec533ecf6922989303645 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Sat, 8 Jun 2024 01:39:03 +0800 Subject: [PATCH 21/25] cancel default contract --- .../Conversion/LLVMCommon/VectorPattern.h | 10 -- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 18 +-- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 15 +- .../Transforms/EmulateUnsupportedFloats.cpp | 11 +- .../Dialect/Math/Transforms/LegalizeToF32.cpp | 11 +- mlir/test/Dialect/Arith/canonicalize.mlir | 137 ++++++++++++++++++ .../Arith/emulate-unsupported-floats.mlir | 137 +++++++++--------- mlir/test/Transforms/canonicalize.mlir | 118 --------------- 8 files changed, 238 insertions(+), 219 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index a7be4ff0fba7a..964281592cc65 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -10,7 +10,6 @@ #define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -99,15 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); - - // Check if the operation is remove the fastMathAttr on ExtFOp / TruncFOp. - if (isa(op.getOperation()) || - isa(op.getOperation())) { - if (op->hasAttr("fastmath")) { - op->removeAttr("fastmath"); - } - } - // Determine attributes for the target op AttrConvert attrConvert(op); diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 6fff83dc3df7f..2e0a1d8d2f678 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1186,7 +1186,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> { // ExtFOp //===----------------------------------------------------------------------===// -def Arith_ExtFOp : Arith_FToFCastOp<"extf"> { +def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods]> { let summary = "cast from floating-point to wider floating-point"; let description = [{ Cast a floating-point value to a larger floating-point-typed value. @@ -1197,12 +1197,11 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> { let hasFolder = 1; let hasCanonicalizer = 1; - let arguments = (ins FloatLike:$in, DefaultValuedAttr< - Arith_FastMathAttr, - "::mlir::arith::FastMathFlags::contract">:$fastmath); + let arguments = (ins FloatLike:$in, OptionalAttr:$fastmath); let results = (outs FloatLike:$out); - let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)"; + let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)? + attr-dict `:` type($in) `to` type($out) }]; } //===----------------------------------------------------------------------===// @@ -1241,12 +1240,11 @@ def Arith_TruncFOp : Arith_Op<"truncf", [Pure, SameOperandsAndResultShape, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, Arguments<(ins FloatLike:$in, OptionalAttr:$roundingmode, - DefaultValuedAttr< - Arith_FastMathAttr, - "::mlir::arith::FastMathFlags::contract">:$fastmath)>, + OptionalAttr:$fastmath)>, Results<(outs FloatLike:$out)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ @@ -1265,7 +1263,9 @@ def Arith_TruncFOp : let hasFolder = 1; let hasVerifier = 1; - let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)"; + let assemblyFormat = [{ $in ($roundingmode^)? + (`fastmath` `` $fastmath^)? + attr-dict `:` type($in) `to` type($out) }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1a135668a23e6..895d72c6f9ca9 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1416,16 +1416,15 @@ struct SimplifyExtFTruncFOpPair : public OpRewritePattern { LogicalResult matchAndRewrite(ExtFOp extFOp, PatternRewriter &rewriter) const override { if (auto truncFOp = extFOp.getOperand().getDefiningOp()) { - Value input = truncFOp.getOperand(); - Type inTy = getElementTypeOrSelf(input.getType()); - Type outTy = getElementTypeOrSelf(extFOp.getType()); - Type shortTy = getElementTypeOrSelf(truncFOp.getType()); - if (isa(inTy) && isa(outTy) && - (isa(shortTy))) { - arith::FastMathFlags truncFMF = truncFOp.getFastmathAttr().getValue(); + if (truncFOp.getOperand().getType() == extFOp.getType()) { + // RoundingMode roundingMode = + // getRoundingmode().value_or(RoundingMode::to_nearest_even); + arith::FastMathFlags truncFMF = + truncFOp.getFastmath().value_or(arith::FastMathFlags::none); bool isTruncContract = bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); - arith::FastMathFlags extFMF = extFOp.getFastmathAttr().getValue(); + arith::FastMathFlags extFMF = + extFOp.getFastmath().value_or(arith::FastMathFlags::none); bool isExtContract = bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); if (isTruncContract && isExtContract) { diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 4a50da3513f99..8e1cb474feee7 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, SmallVector newResults(expandedOp->getResults()); for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { - if (oldType != newType) - res = rewriter.create(loc, oldType, res); + if (oldType != newType) { + auto truncFOp = rewriter.create(loc, oldType, res); + truncFOp.setFastmath(arith::FastMathFlags::contract); + res = truncFOp.getResult(); + } } rewriter.replaceOp(op, newResults); } @@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - return b.create(loc, target, input); + auto extFOp = b.create(loc, target, input); + extFOp.setFastmath(arith::FastMathFlags::contract); + return extFOp; }); } diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp index 5998133b7eab8..3d99f3033cf56 100644 --- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp +++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp @@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter( }); typeConverter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - return b.create(loc, target, input); + auto extFOp = b.create(loc, target, input); + extFOp.setFastmath(arith::FastMathFlags::contract); + return extFOp; }); } @@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( SmallVector results = (*legalized)->getResults(); for (auto [result, newType, origType] : llvm::zip_equal( results, (*legalized)->getResultTypes(), op->getResultTypes())) { - if (newType != origType) - result = rewriter.create(loc, origType, result); + if (newType != origType) { + auto truncFOp = rewriter.create(loc, origType, result); + truncFOp.setFastmath(arith::FastMathFlags::contract); + result = truncFOp.getResult(); + } } rewriter.replaceOp(op, results); return success(); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 1a387c20c4b29..78d12af4c3054 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -3039,6 +3039,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) { return %mulsi_extended#0, %mulsi_extended#1 : i0, i0 } +// CHECK-LABEL: @sequences_fastmath_contract +// CHECK-SAME: ([[ARG0:%.+]]: bf16) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : bf16 +func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 { + %0 = arith.extf %arg0 fastmath : bf16 to f32 + %1 = math.absf %0 : f32 + %2 = arith.truncf %1 fastmath : f32 to bf16 + %3 = arith.extf %2 fastmath : bf16 to f32 + %4 = math.sin %3 : f32 + %5 = arith.truncf %4 fastmath : f32 to bf16 + return %5 : bf16 +} + +// CHECK-LABEL: @sequences_no_fastmath +// CHECK-SAME: ([[ARG0:%.+]]: bf16) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]] +// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]] +// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : bf16 +func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 { + %0 = arith.extf %arg0 : bf16 to f32 + %1 = math.absf %0 : f32 + %2 = arith.truncf %1 : f32 to bf16 + %3 = arith.extf %2 : bf16 to f32 + %4 = math.sin %3 : f32 + %5 = arith.truncf %4 : f32 to bf16 + return %5 : bf16 +} + +// CHECK-LABEL: @eliminatecastoncastf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 fastmath : f32 to f16 + %1 = arith.extf %0 fastmath : f16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @eliminatecastoncastbf16 +// CHECK: return [[arg0:%.+]] : f32 +func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { + %0 = arith.truncf %arg0 fastmath : f32 to bf16 + %1 = arith.extf %0 fastmath : bf16 to f32 + return %1 : f32 +} + +// CHECK-LABEL: @bf16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> +func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %5 : vector<32x32x32xbf16> +} + +// CHECK-LABEL: @f16_sin_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> +func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { + %0 = arith.extf %arg0 fastmath : vector<32x32x32xf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 fastmath : vector<32x32x32xf32> to vector<32x32x32xf16> + %3 = arith.extf %2 fastmath : vector<32x32x32xf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 fastmath : vector<32x32x32xf32> to vector<32x32x32xf16> + return %5 : vector<32x32x32xf16> +} + +// CHECK-LABEL: @bf16_branch_vector +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]] +// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]] +// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> +func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.cos %3 : vector<32x32x32xf32> + %8 = arith.truncf %7 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %9 = arith.extf %8 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %10 = arith.addf %6, %9 : vector<32x32x32xf32> + %11 = arith.truncf %10 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %11 : vector<32x32x32xbf16> +} + +// CHECK-LABEL: @bf16_fma +// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>) +// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] +// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] +// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] +// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] +// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]] +// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]] +// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]] +// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] +// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> +func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { + %0 = arith.extf %arg0 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %1 = math.absf %0 : vector<32x32x32xf32> + %2 = arith.truncf %1 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %3 = arith.extf %2 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %4 = math.sin %3 : vector<32x32x32xf32> + %5 = arith.truncf %4 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + %6 = arith.extf %5 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16> + %8 = arith.extf %7 fastmath : vector<32x32x32xbf16> to vector<32x32x32xf32> + %9 = arith.addf %8, %6 : vector<32x32x32xf32> + %10 = arith.truncf %9 fastmath : vector<32x32x32xf32> to vector<32x32x32xbf16> + return %10 : vector<32x32x32xbf16> +} + {-# dialect_resources: { builtin: { diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index a69ef131d8d47..75ae4168dd1b1 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -4,84 +4,85 @@ func.func @basic_expansion(%x: bf16) -> bf16 { // CHECK-LABEL: @basic_expansion // CHECK-SAME: [[X:%.+]]: bf16 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 -// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath : bf16 to f32 +// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath : bf16 to f32 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32 -// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16 +// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath : f32 to bf16 // CHECK: return [[Y]] %c = arith.constant 1.0 : bf16 %y = arith.addf %x, %c : bf16 func.return %y : bf16 } -// ----- - -func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { -// CHECK-LABEL: @chained -// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32 -// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32 -// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32 -// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 -// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16 -// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32 -// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] -// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16 -// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32 -// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 -// CHECK: return [[RES]] - %p = arith.addf %x, %y : bf16 - %q = arith.mulf %p, %z : bf16 - %res = arith.cmpf ole, %p, %q : bf16 - func.return %res : i1 + // ----- + + func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { + // CHECK-LABEL: @chained + // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 + // CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath : bf16 to f32 + // CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath : bf16 to f32 + // CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath : bf16 to f32 + // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 + // CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath : f32 to bf16 + // CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath : bf16 to f32 + // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] + // CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath : f32 to bf16 + // CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath : bf16 to f32 + // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 + // CHECK: return [[RES]] + %p = arith.addf %x, %y : bf16 + %q = arith.mulf %p, %z : bf16 + %res = arith.cmpf ole, %p, %q : bf16 + func.return %res : i1 } -// ----- - -func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { -// CHECK-LABEL: @memops -// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> -// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32 -// CHECK: memref.store [[V]] -// CHECK: [[W:%.+]] = memref.load -// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32 -// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 -// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ -// CHECK: memref.store [[X]] - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> - memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> - %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> - %x = arith.addf %v, %w : f8E4M3FNUZ - memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> - func.return + // ----- + + func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { + // CHECK-LABEL: @memops + // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> + // CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath : f8E4M3FNUZ to f32 + // CHECK: memref.store [[V]] + // CHECK: [[W:%.+]] = memref.load + // CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath : f8E4M3FNUZ to f32 + // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 + // CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath : f32 to f8E4M3FNUZ + // CHECK: memref.store [[X]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> + memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> + %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> + %x = arith.addf %v, %w : f8E4M3FNUZ + memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> + func.return } -// ----- - -func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { -// CHECK-LABEL: @vectors -// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> -// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32> -// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> -// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ> -// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> -// CHECK: return [[RET]] - %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> - %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> - func.return %ret : vector<4xf32> -} + // ----- + + func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { + // CHECK-LABEL: @vectors + // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> + // CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath : vector<4xf8E4M3FNUZ> to vector<4xf32> + // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> + // CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath : vector<4xf32> to vector<4xf8E4M3FNUZ> + // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> + // CHECK: return [[RET]] + %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> + %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> + func.return %ret : vector<4xf32> + } -// ----- + // ----- + + func.func @no_expansion(%x: f32) -> f32 { + // CHECK-LABEL: @no_expansion + // CHECK-SAME: [[X:%.+]]: f32 + // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32 + // CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32 + // CHECK: return [[Y]] + %c = arith.constant 1.0 : f32 + %y = arith.addf %x, %c : f32 + func.return %y : f32 + } -func.func @no_expansion(%x: f32) -> f32 { -// CHECK-LABEL: @no_expansion -// CHECK-SAME: [[X:%.+]]: f32 -// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32 -// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32 -// CHECK: return [[Y]] - %c = arith.constant 1.0 : f32 - %y = arith.addf %x, %c : f32 - func.return %y : f32 -} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index cd06cca33c926..d2c2c12d32389 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1243,121 +1243,3 @@ func.func @test_materialize_failure() -> i64 { %u = index.castu %const : index to i64 return %u: i64 } - -// CHECK-LABEL: @sequences -// CHECK-SAME: ([[ARG0:%.+]]: bf16) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : bf16 -func.func @sequences(%arg0: bf16) -> bf16 { - %0 = arith.extf %arg0 : bf16 to f32 - %1 = math.absf %0 : f32 - %2 = arith.truncf %1 : f32 to bf16 - %3 = arith.extf %2 : bf16 to f32 - %4 = math.sin %3 : f32 - %5 = arith.truncf %4 : f32 to bf16 - return %5 : bf16 -} - -// CHECK-LABEL: @eliminatecastoncastf16 -// CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { - %0 = arith.truncf %arg0 : f32 to f16 - %1 = arith.extf %0 : f16 to f32 - return %1 : f32 -} - -// CHECK-LABEL: @eliminatecastoncastbf16 -// CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { - %0 = arith.truncf %arg0 : f32 to bf16 - %1 = arith.extf %0 : bf16 to f32 - return %1 : f32 -} - -// CHECK-LABEL: @bf16_sin_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> -func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %5 : vector<32x32x32xbf16> -} - -// CHECK-LABEL: @f16_sin_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xf16> -func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> { - %0 = arith.extf %arg0 : vector<32x32x32xf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xf16> - %3 = arith.extf %2 : vector<32x32x32xf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xf16> - return %5 : vector<32x32x32xf16> -} - -// CHECK-LABEL: @bf16_branch_vector -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]] -// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]] -// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]] -// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]] -// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16> -func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %7 = math.cos %3 : vector<32x32x32xf32> - %8 = arith.truncf %7 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %9 = arith.extf %8 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %10 = arith.addf %6, %9 : vector<32x32x32xf32> - %11 = arith.truncf %10 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %11 : vector<32x32x32xbf16> -} - -// CHECK-LABEL: @bf16_fma -// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>) -// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]] -// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]] -// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]] -// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]] -// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]] -// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]] -// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]] -// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]] -// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16> -func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> { - %0 = arith.extf %arg0 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %1 = math.absf %0 : vector<32x32x32xf32> - %2 = arith.truncf %1 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %3 = arith.extf %2 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %4 = math.sin %3 : vector<32x32x32xf32> - %5 = arith.truncf %4 : vector<32x32x32xf32> to vector<32x32x32xbf16> - %6 = arith.extf %5 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16> - %8 = arith.extf %7 : vector<32x32x32xbf16> to vector<32x32x32xf32> - %9 = arith.addf %8, %6 : vector<32x32x32xf32> - %10 = arith.truncf %9 : vector<32x32x32xf32> to vector<32x32x32xbf16> - return %10 : vector<32x32x32xbf16> -} From 30e3d66ed7813226d3a8eaabf7532cecda7c03ce Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Sat, 8 Jun 2024 02:07:44 +0800 Subject: [PATCH 22/25] use folder instead --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 1 - mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 48 +++++++------------ 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 2e0a1d8d2f678..29591bab5010e 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1195,7 +1195,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods:$fastmath); let results = (outs FloatLike:$out); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 895d72c6f9ca9..d5f352bad0fa4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1390,6 +1390,22 @@ LogicalResult arith::ExtSIOp::verify() { /// Fold extension of float constants when there is no information loss due the /// difference in fp semantics. OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) { + if (auto truncFOp = getOperand().getDefiningOp()) { + if (truncFOp.getOperand().getType() == getType()) { + arith::FastMathFlags truncFMF = + truncFOp.getFastmath().value_or(arith::FastMathFlags::none); + bool isTruncContract = + bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); + arith::FastMathFlags extFMF = + getFastmath().value_or(arith::FastMathFlags::none); + bool isExtContract = + bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); + if (isTruncContract && isExtContract) { + return truncFOp.getOperand(); + } + } + } + auto resElemType = cast(getElementTypeOrSelf(getType())); const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); return constFoldCastOp( @@ -1410,38 +1426,6 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } -struct SimplifyExtFTruncFOpPair : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtFOp extFOp, - PatternRewriter &rewriter) const override { - if (auto truncFOp = extFOp.getOperand().getDefiningOp()) { - if (truncFOp.getOperand().getType() == extFOp.getType()) { - // RoundingMode roundingMode = - // getRoundingmode().value_or(RoundingMode::to_nearest_even); - arith::FastMathFlags truncFMF = - truncFOp.getFastmath().value_or(arith::FastMathFlags::none); - bool isTruncContract = - bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); - arith::FastMathFlags extFMF = - extFOp.getFastmath().value_or(arith::FastMathFlags::none); - bool isExtContract = - bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); - if (isTruncContract && isExtContract) { - rewriter.replaceOp(extFOp, truncFOp.getOperand()); - return success(); - } - } - } - return failure(); - } -}; - -void arith::ExtFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) { - patterns.add(context); -} - //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// From 618cda445d36b6bac08ab8581a4ee142ef2d430d Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Tue, 11 Jun 2024 10:04:34 +0800 Subject: [PATCH 23/25] fix comment --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 8 +- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 6 +- .../Arith/emulate-unsupported-floats.mlir | 132 +++++++++--------- 3 files changed, 73 insertions(+), 73 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 29591bab5010e..e5c4b9f32354a 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1196,7 +1196,8 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods:$fastmath); + let arguments = (ins FloatLike:$in, DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); let results = (outs FloatLike:$out); let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)? @@ -1242,8 +1243,9 @@ def Arith_TruncFOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, Arguments<(ins FloatLike:$in, - OptionalAttr:$roundingmode, - OptionalAttr:$fastmath)>, + DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath, + OptionalAttr:$roundingmode)>, Results<(outs FloatLike:$out)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d5f352bad0fa4..d304c848ad4dd 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1392,12 +1392,10 @@ LogicalResult arith::ExtSIOp::verify() { OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) { if (auto truncFOp = getOperand().getDefiningOp()) { if (truncFOp.getOperand().getType() == getType()) { - arith::FastMathFlags truncFMF = - truncFOp.getFastmath().value_or(arith::FastMathFlags::none); + arith::FastMathFlags truncFMF = truncFOp.getFastmath(); bool isTruncContract = bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); - arith::FastMathFlags extFMF = - getFastmath().value_or(arith::FastMathFlags::none); + arith::FastMathFlags extFMF = getFastmath(); bool isExtContract = bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); if (isTruncContract && isExtContract) { diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index 75ae4168dd1b1..a34c4cd8979b6 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -14,75 +14,75 @@ func.func @basic_expansion(%x: bf16) -> bf16 { func.return %y : bf16 } - // ----- - - func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { - // CHECK-LABEL: @chained - // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 - // CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath : bf16 to f32 - // CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath : bf16 to f32 - // CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath : bf16 to f32 - // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 - // CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath : f32 to bf16 - // CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath : bf16 to f32 - // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] - // CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath : f32 to bf16 - // CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath : bf16 to f32 - // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 - // CHECK: return [[RES]] - %p = arith.addf %x, %y : bf16 - %q = arith.mulf %p, %z : bf16 - %res = arith.cmpf ole, %p, %q : bf16 - func.return %res : i1 +// ----- + +func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { +// CHECK-LABEL: @chained +// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16 +// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath : bf16 to f32 +// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath : bf16 to f32 +// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath : bf16 to f32 +// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32 +// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath : f32 to bf16 +// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath : bf16 to f32 +// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]] +// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath : f32 to bf16 +// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath : bf16 to f32 +// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 +// CHECK: return [[RES]] + %p = arith.addf %x, %y : bf16 + %q = arith.mulf %p, %z : bf16 + %res = arith.cmpf ole, %p, %q : bf16 + func.return %res : i1 +} + +// ----- + +func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { +// CHECK-LABEL: @memops +// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> +// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath : f8E4M3FNUZ to f32 +// CHECK: memref.store [[V]] +// CHECK: [[W:%.+]] = memref.load +// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath : f8E4M3FNUZ to f32 +// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 +// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath : f32 to f8E4M3FNUZ +// CHECK: memref.store [[X]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> + memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> + %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> + %x = arith.addf %v, %w : f8E4M3FNUZ + memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> + func.return } - // ----- - - func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { - // CHECK-LABEL: @memops - // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ> - // CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath : f8E4M3FNUZ to f32 - // CHECK: memref.store [[V]] - // CHECK: [[W:%.+]] = memref.load - // CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath : f8E4M3FNUZ to f32 - // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 - // CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath : f32 to f8E4M3FNUZ - // CHECK: memref.store [[X]] - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> - memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> - %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> - %x = arith.addf %v, %w : f8E4M3FNUZ - memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> - func.return +// ----- + +func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { +// CHECK-LABEL: @vectors +// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> +// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> +// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath : vector<4xf32> to vector<4xf8E4M3FNUZ> +// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: return [[RET]] + %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> + %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> + func.return %ret : vector<4xf32> } - // ----- - - func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { - // CHECK-LABEL: @vectors - // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> - // CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath : vector<4xf8E4M3FNUZ> to vector<4xf32> - // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> - // CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath : vector<4xf32> to vector<4xf8E4M3FNUZ> - // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> - // CHECK: return [[RET]] - %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> - %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> - func.return %ret : vector<4xf32> - } +// ----- - // ----- - - func.func @no_expansion(%x: f32) -> f32 { - // CHECK-LABEL: @no_expansion - // CHECK-SAME: [[X:%.+]]: f32 - // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32 - // CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32 - // CHECK: return [[Y]] - %c = arith.constant 1.0 : f32 - %y = arith.addf %x, %c : f32 - func.return %y : f32 - } +func.func @no_expansion(%x: f32) -> f32 { +// CHECK-LABEL: @no_expansion +// CHECK-SAME: [[X:%.+]]: f32 +// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32 +// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32 +// CHECK: return [[Y]] + %c = arith.constant 1.0 : f32 + %y = arith.addf %x, %c : f32 + func.return %y : f32 +} From 2532206541a4143895571072df888655e0ed8982 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Tue, 11 Jun 2024 19:49:43 +0800 Subject: [PATCH 24/25] fix ci --- .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 34 ++++++++--------- .../Arith/emulate-unsupported-floats.mlir | 37 +++++++++---------- 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 56ae930e6d627..cacdd801871fb 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -162,11 +162,11 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) { // Checking conversion of integer types to floating point. // CHECK-LABEL: @fpext func.func @fpext(%arg0 : f16, %arg1 : f32) { -// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32 +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : f16 to f32 %0 = arith.extf %arg0: f16 to f32 -// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64 +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : f16 to f64 %1 = arith.extf %arg0: f16 to f64 -// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64 +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : f32 to f64 %2 = arith.extf %arg1: f32 to f64 return } @@ -174,11 +174,11 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) { // Checking conversion of integer types to floating point. // CHECK-LABEL: @fpext func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) { -// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32> +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : vector<2xf16> to vector<2xf32> %0 = arith.extf %arg0: vector<2xf16> to vector<2xf32> -// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64> +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : vector<2xf16> to vector<2xf64> %1 = arith.extf %arg0: vector<2xf16> to vector<2xf64> -// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64> +// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath} : vector<2xf32> to vector<2xf64> %2 = arith.extf %arg1: vector<2xf32> to vector<2xf64> return } @@ -268,11 +268,11 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v // Checking conversion of integer types to floating point. // CHECK-LABEL: @fptrunc func.func @fptrunc(%arg0 : f32, %arg1 : f64) { -// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16 +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : f32 to f16 %0 = arith.truncf %arg0: f32 to f16 -// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16 +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : f64 to f16 %1 = arith.truncf %arg1: f64 to f16 -// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32 +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : f64 to f32 %2 = arith.truncf %arg1: f64 to f32 return } @@ -280,26 +280,26 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) { // Checking conversion of integer types to floating point. // CHECK-LABEL: @fptrunc func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) { -// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16> +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : vector<2xf32> to vector<2xf16> %0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16> -// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16> +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : vector<2xf64> to vector<2xf16> %1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16> -// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32> +// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath} : vector<2xf64> to vector<2xf32> %2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32> return } // CHECK-LABEL: experimental_constrained_fptrunc func.func @experimental_constrained_fptrunc(%arg0 : f64) { -// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath} : f64 to f32 %0 = arith.truncf %arg0 to_nearest_even : f64 to f32 -// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath} : f64 to f32 %1 = arith.truncf %arg0 downward : f64 to f32 -// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath} : f64 to f32 %2 = arith.truncf %arg0 upward : f64 to f32 -// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath} : f64 to f32 %3 = arith.truncf %arg0 toward_zero : f64 to f32 -// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32 +// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath} : f64 to f32 %4 = arith.truncf %arg0 to_nearest_away : f64 to f32 return } diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index a34c4cd8979b6..99790cc45d490 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -30,10 +30,10 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 { // CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath : bf16 to f32 // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32 // CHECK: return [[RES]] - %p = arith.addf %x, %y : bf16 - %q = arith.mulf %p, %z : bf16 - %res = arith.cmpf ole, %p, %q : bf16 - func.return %res : i1 + %p = arith.addf %x, %y : bf16 + %q = arith.mulf %p, %z : bf16 + %res = arith.cmpf ole, %p, %q : bf16 + func.return %res : i1 } // ----- @@ -48,14 +48,14 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32 // CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath : f32 to f8E4M3FNUZ // CHECK: memref.store [[X]] - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> - memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> - %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> - %x = arith.addf %v, %w : f8E4M3FNUZ - memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> - func.return + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ> + memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ> + %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ> + %x = arith.addf %v, %w : f8E4M3FNUZ + memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ> + func.return } // ----- @@ -68,9 +68,9 @@ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { // CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath : vector<4xf32> to vector<4xf8E4M3FNUZ> // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> // CHECK: return [[RET]] - %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> - %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> - func.return %ret : vector<4xf32> + %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> + %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> + func.return %ret : vector<4xf32> } // ----- @@ -81,8 +81,7 @@ func.func @no_expansion(%x: f32) -> f32 { // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32 // CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32 // CHECK: return [[Y]] - %c = arith.constant 1.0 : f32 - %y = arith.addf %x, %c : f32 - func.return %y : f32 + %c = arith.constant 1.0 : f32 + %y = arith.addf %x, %c : f32 + func.return %y : f32 } - From a8cf350b61a2f19fe27b0cb31a89bee69188fe5a Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Wed, 12 Jun 2024 15:10:54 +0800 Subject: [PATCH 25/25] rename --- mlir/test/Dialect/Arith/canonicalize.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index de362b0b83a41..4fe7cfb689be8 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -3067,17 +3067,17 @@ func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 { return %5 : bf16 } -// CHECK-LABEL: @eliminatecastoncastf16 +// CHECK-LABEL: @eliminate_cast_to_f16 // CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastf16(%arg0: f32) -> f32 { +func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 { %0 = arith.truncf %arg0 fastmath : f32 to f16 %1 = arith.extf %0 fastmath : f16 to f32 return %1 : f32 } -// CHECK-LABEL: @eliminatecastoncastbf16 +// CHECK-LABEL: @eliminate_cast_to_bf16 // CHECK: return [[arg0:%.+]] : f32 -func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 { +func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 { %0 = arith.truncf %arg0 fastmath : f32 to bf16 %1 = arith.extf %0 fastmath : bf16 to f32 return %1 : f32