diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda5..234b4f43f08c0 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -34,6 +34,13 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
     that is an operation frequently implemented at low precisions.
   }];
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+  let options = [
+    Option<"useCanonicalizeF32Promotion", "use-canonicalize-f32-promotion",
+           "bool", /*default=*/"true",
+           "Eliminate redundant truncf/extf pairs to improve performance. "
+           "This may introduce numerical differences because the "
+           "intermediate f32->f16/bf16 rounding is removed.">
+  ];
 }
 
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..883238fba9fbf 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -37,6 +38,8 @@ struct LegalizeToF32RewritePattern final : ConversionPattern {
 
 struct LegalizeToF32Pass final
     : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
+  // Inherit the generated constructors so that pass options are honored.
+  using MathLegalizeToF32Base::MathLegalizeToF32Base;
   void runOnOperation() override;
 };
 } // namespace
@@ -97,6 +100,36 @@ void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
                                             patterns.getContext());
 }
 
+namespace {
+/// Folds `arith.extf(arith.truncf(x))` to `x` when `x` and the extended result
+/// are f32 (or shaped types of f32) and the intermediate type is f16 or bf16.
+/// Skipping the round trip drops the intermediate rounding, so results may
+/// differ numerically from the unfolded IR.
+struct CanonicalizeF32PromotionRewritePattern final
+    : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto truncOp = op.getOperand().getDefiningOp<arith::TruncFOp>();
+    if (!truncOp)
+      return failure();
+
+    Value truncInput = truncOp.getOperand();
+    Type outerTy = getElementTypeOrSelf(op.getType());
+    Type intermediateTy = getElementTypeOrSelf(truncOp.getType());
+    Type innerTy = getElementTypeOrSelf(truncInput.getType());
+    // Report failure when nothing is rewritten: returning success() without
+    // changing the IR breaks the rewrite driver's contract.
+    if (!outerTy.isF32() || !innerTy.isF32() ||
+        !(intermediateTy.isF16() || intermediateTy.isBF16()))
+      return failure();
+
+    rewriter.replaceOp(op, truncInput);
+    return success();
+  }
+};
+} // namespace
+
 void LegalizeToF32Pass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext &ctx = getContext();
@@ -109,4 +142,17 @@ void LegalizeToF32Pass::runOnOperation() {
   math::populateLegalizeToF32Patterns(patterns, typeConverter);
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     return signalPassFailure();
+
+  if (useCanonicalizeF32Promotion) {
+    RewritePatternSet canoPatterns(&ctx);
+    canoPatterns.add<CanonicalizeF32PromotionRewritePattern>(&ctx);
+    FrozenRewritePatternSet frozenCanoPatterns(std::move(canoPatterns));
+    // Collect the candidates first so the IR is not mutated while walking it.
+    SmallVector<Operation *> extOps;
+    op->walk([&](arith::ExtFOp extOp) { extOps.push_back(extOp); });
+    if (failed(applyOpPatternsAndFold(extOps, frozenCanoPatterns))) {
+      op->emitError("fail to do implicit rounding removement");
+      return signalPassFailure();
+    }
+  }
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b4..1b7bb51e771fb 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32=use-canonicalize-f32-promotion=true | FileCheck %s
 
 // CHECK-LABEL: @sin
 // CHECK-SAME: ([[ARG0:%.+]]: f16)
@@ -70,16 +70,76 @@ func.func @fastmath(%arg0: f16) -> f16 {
 }
 
 // CHECK-LABEL: @sequences
-// CHECK-SAME: ([[ARG0:%.+]]: f16)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF1]] : f16
-func.func @sequences(%arg0: f16) -> f16 {
-  %0 = math.absf %arg0 : f16
-  %1 = math.sin %0 : f16
-  return %1 : f16
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences(%arg0: bf16) -> bf16 {
+  %0 = math.absf %arg0 : bf16
+  %1 = math.sin %0 : bf16
+  return %1 : bf16
+}
+
+// CHECK-LABEL: @eliminatecastoncastf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: return [[ARG0]] : f32
+func.func @eliminatecastoncastf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to f16
+  %1 = arith.extf %0 : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminatecastoncastbf16
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: return [[ARG0]] : f32
+func.func @eliminatecastoncastbf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  %1 = arith.extf %0 : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  return %1 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xf16>
+  %1 = math.sin %0 : vector<32x32x32xf16>
+  return %1 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[COS]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[TRUNCF0]], [[TRUNCF1]]
+// CHECK: return [[ADDF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = math.absf %arg0 : vector<32x32x32xbf16>
+  %1 = math.sin %0 : vector<32x32x32xbf16>
+  %2 = math.cos %0 : vector<32x32x32xbf16>
+  %3 = arith.addf %1, %2 : vector<32x32x32xbf16>
+  return %3 : vector<32x32x32xbf16>
 }