From e1c84a3a243c6b8c963fa98f916fc70612f6093c Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Fri, 7 Jun 2024 00:54:48 +0000 Subject: [PATCH 1/4] [mlir][Arith] Generalize and improve -int-range-optimizations When the integer range analysis was first develop, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened. Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in -test-int-range-inference. This commit moves the former test pass into -int-range-optimizaitons, subsuming its previous contents. It also adds an optimization from rocMLIR where `rem{s,u}i` operations that are noops are replaced by their left operands. --- .../mlir/Dialect/Arith/Transforms/Passes.h | 4 - .../mlir/Dialect/Arith/Transforms/Passes.td | 9 +- .../Transforms/IntRangeOptimizations.cpp | 287 ++++++++---------- .../Dialect/Arith/int-range-interface.mlir | 2 +- mlir/test/Dialect/Arith/int-range-opts.mlir | 36 +++ .../test/Dialect/GPU/int-range-interface.mlir | 2 +- .../Dialect/Index/int-range-inference.mlir | 2 +- .../infer-int-range-test-ops.mlir | 10 +- mlir/test/lib/Transforms/CMakeLists.txt | 1 - .../lib/Transforms/TestIntRangeInference.cpp | 125 -------- mlir/tools/mlir-opt/mlir-opt.cpp | 2 - 11 files changed, 181 insertions(+), 299 deletions(-) delete mode 100644 mlir/test/lib/Transforms/TestIntRangeInference.cpp diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 9dc262cc72ed0..b8a7d0c78d323 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns); /// equivalent. std::unique_ptr createArithUnsignedWhenEquivalentPass(); -/// Add patterns for int range based optimizations. -void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, - DataFlowSolver &solver); - /// Create a pass which do optimizations based on integer range analysis. std::unique_ptr createIntRangeOptimizationsPass(); diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 550c5c0cf4f60..1517f71f1a7c9 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> { let summary = "Do optimizations based on integer range analysis"; let description = [{ This pass runs integer range analysis and apllies optimizations based on its - results. e.g. replace arith.cmpi with const if it can be inferred from - args ranges. + results. It replaces operations with known-constant results with said constants, + rewrites `(0 <= %x < D) mod D` to `%x`. }]; + // Explicitly depend on "arith" because this pass could create operations in + // `arith` out of thin air in some cases. + let dependentDialects = [ + "::mlir::arith::ArithDialect" + ]; } def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> { diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 2473169962b95..e991d0fbe7410 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -13,7 +13,8 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/FoldUtils.h" namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS @@ -24,155 +25,145 @@ using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; -/// Returns true if 2 integer ranges have intersection. -static bool intersects(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) && - (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax()))); +/// Patterned after SCCP +static LogicalResult replaceWithConstant(DataFlowSolver &solver, + RewriterBase &rewriter, + OperationFolder &folder, Value value) { + auto *maybeInferredRange = + solver.lookupState(value); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return failure(); + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + std::optional maybeConstValue = inferredRange.getConstantValue(); + if (!maybeConstValue.has_value()) + return failure(); + + Operation *maybeDefiningOp = value.getDefiningOp(); + Dialect *valueDialect = + maybeDefiningOp ? maybeDefiningOp->getDialect() + : value.getParentRegion()->getParentOp()->getDialect(); + Attribute constAttr = + rewriter.getIntegerAttr(value.getType(), *maybeConstValue); + Value constant = folder.getOrCreateConstant( + rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType()); + // Fall back to arith.constant if the dialect materializer doesn't know what + // to do with an integer constant. + if (!constant) + constant = folder.getOrCreateConstant( + rewriter.getInsertionBlock(), + rewriter.getContext()->getLoadedDialect(), constAttr, + value.getType()); + if (!constant) + return failure(); + + rewriter.replaceAllUsesWith(value, constant); + return success(); } -static FailureOr handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (!intersects(lhs, rhs)) - return false; - - return failure(); -} - -static FailureOr handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (!intersects(lhs, rhs)) - return true; - - return failure(); -} - -static FailureOr handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (lhs.smax().slt(rhs.smin())) - return true; - - if (lhs.smin().sge(rhs.smax())) - return false; - - return failure(); -} - -static FailureOr handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (lhs.smax().sle(rhs.smin())) - return true; - - if (lhs.smin().sgt(rhs.smax())) - return false; - - return failure(); -} - -static FailureOr handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { - return handleSlt(std::move(rhs), std::move(lhs)); -} - -static FailureOr handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) { - return handleSle(std::move(rhs), std::move(lhs)); -} - -static FailureOr handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (lhs.umax().ult(rhs.umin())) - return true; - - if (lhs.umin().uge(rhs.umax())) - return false; - - return failure(); -} - -static FailureOr handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) { - if (lhs.umax().ule(rhs.umin())) - return true; - - if (lhs.umin().ugt(rhs.umax())) - return false; - +/// Rewrite any results of `op` that were inferred to be constant integers to +/// and replace their uses with that constant. Return success() if all results +/// where thus replaced and the operation is erased. +static LogicalResult foldResultsToConstants(DataFlowSolver &solver, + RewriterBase &rewriter, + OperationFolder &folder, + Operation &op) { + bool replacedAll = op.getNumResults() != 0; + for (Value res : op.getResults()) + replacedAll &= + succeeded(replaceWithConstant(solver, rewriter, folder, res)); + + // If all of the results of the operation were replaced, try to erase + // the operation completely. + if (replacedAll && wouldOpBeTriviallyDead(&op)) { + assert(op.use_empty() && "expected all uses to be replaced"); + rewriter.eraseOp(&op); + return success(); + } return failure(); } -static FailureOr handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { - return handleUlt(std::move(rhs), std::move(lhs)); +/// This function hasn't come from anywhere and is relying on the overall +/// tests of the integer range inference implementation for its correctness. +static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver, + RewriterBase &rewriter, + Operation &op) { + if (!isa(op)) + return failure(); + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + auto rhsConstVal = rhs.getDefiningOp(); + if (!rhsConstVal) + return failure(); + int64_t modulus = rhsConstVal.value(); + if (modulus <= 0) + return failure(); + auto *maybeLhsRange = solver.lookupState(lhs); + if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) + return failure(); + const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); + const APInt &min = llvm::isa(op) ? lhsRange.umin() : lhsRange.smin(); + const APInt &max = llvm::isa(op) ? lhsRange.umax() : lhsRange.smax(); + // The minima and maxima here are given as closed ranges, we must be strictly + // less than the modulus. + if (min.isNegative() || min.uge(modulus)) + return failure(); + if (max.isNegative() || max.uge(modulus)) + return failure(); + if (!min.ule(max)) + return failure(); + + // With all those conditions out of the way, we know thas this invocation of + // a remainder is a noop because the input is strictly within the range + // [0, modulus), so get rid of it. + rewriter.replaceOp(&op, ValueRange{lhs}); + return success(); } -static FailureOr handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) { - return handleUle(std::move(rhs), std::move(lhs)); +static void doRewrites(DataFlowSolver &solver, MLIRContext *context, + MutableArrayRef initialRegions) { + SmallVector worklist; + auto addToWorklist = [&](MutableArrayRef regions) { + for (Region ®ion : regions) + for (Block &block : llvm::reverse(region)) + worklist.push_back(&block); + }; + + IRRewriter rewriter(context); + OperationFolder folder(context, rewriter.getListener()); + + addToWorklist(initialRegions); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + for (Operation &op : llvm::make_early_inc_range(*block)) { + if (matchPattern(&op, m_Constant())) { + if (auto arithConstant = dyn_cast(op)) + folder.insertKnownConstant(&op, arithConstant.getValue()); + else + folder.insertKnownConstant(&op); + continue; + } + rewriter.setInsertionPoint(&op); + + // Try rewrites. Success means that the underlying operation was erased. + if (succeeded(foldResultsToConstants(solver, rewriter, folder, op))) + continue; + if (isa(op) && + succeeded(deleteTrivialRemainder(solver, rewriter, op))) + continue; + // Add any the regions of this operation to the worklist. + addToWorklist(op.getRegions()); + } + + // Replace any block arguments with constants. + rewriter.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + (void)replaceWithConstant(solver, rewriter, folder, arg); + } } namespace { -/// This class listens on IR transformations performed during a pass relying on -/// information from a `DataflowSolver`. It erases state associated with the -/// erased operation and its results from the `DataFlowSolver` so that Patterns -/// do not accidentally query old state information for newly created Ops. -class DataFlowListener : public RewriterBase::Listener { -public: - DataFlowListener(DataFlowSolver &s) : s(s) {} - -protected: - void notifyOperationErased(Operation *op) override { - s.eraseState(op); - for (Value res : op->getResults()) - s.eraseState(res); - } - - DataFlowSolver &s; -}; - -struct ConvertCmpOp : public OpRewritePattern { - - ConvertCmpOp(MLIRContext *context, DataFlowSolver &s) - : OpRewritePattern(context), solver(s) {} - - LogicalResult matchAndRewrite(arith::CmpIOp op, - PatternRewriter &rewriter) const override { - auto *lhsResult = - solver.lookupState(op.getLhs()); - if (!lhsResult || lhsResult->getValue().isUninitialized()) - return failure(); - - auto *rhsResult = - solver.lookupState(op.getRhs()); - if (!rhsResult || rhsResult->getValue().isUninitialized()) - return failure(); - - using HandlerFunc = - FailureOr (*)(ConstantIntRanges, ConstantIntRanges); - std::array - handlers{}; - using Pred = arith::CmpIPredicate; - handlers[static_cast(Pred::eq)] = &handleEq; - handlers[static_cast(Pred::ne)] = &handleNe; - handlers[static_cast(Pred::slt)] = &handleSlt; - handlers[static_cast(Pred::sle)] = &handleSle; - handlers[static_cast(Pred::sgt)] = &handleSgt; - handlers[static_cast(Pred::sge)] = &handleSge; - handlers[static_cast(Pred::ult)] = &handleUlt; - handlers[static_cast(Pred::ule)] = &handleUle; - handlers[static_cast(Pred::ugt)] = &handleUgt; - handlers[static_cast(Pred::uge)] = &handleUge; - - HandlerFunc handler = handlers[static_cast(op.getPredicate())]; - if (!handler) - return failure(); - - ConstantIntRanges lhsValue = lhsResult->getValue().getValue(); - ConstantIntRanges rhsValue = rhsResult->getValue().getValue(); - FailureOr result = handler(lhsValue, rhsValue); - - if (failed(result)) - return failure(); - - rewriter.replaceOpWithNewOp( - op, static_cast(*result), /*width*/ 1); - return success(); - } - -private: - DataFlowSolver &solver; -}; - struct IntRangeOptimizationsPass : public arith::impl::ArithIntRangeOptsBase { @@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass if (failed(solver.initializeAndRun(op))) return signalPassFailure(); - DataFlowListener listener(solver); - - RewritePatternSet patterns(ctx); - populateIntRangeOptimizationsPatterns(patterns, solver); - - GreedyRewriteConfig config; - config.listener = &listener; - - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - signalPassFailure(); + doRewrites(solver, ctx, op->getRegions()); } }; } // namespace -void mlir::arith::populateIntRangeOptimizationsPatterns( - RewritePatternSet &patterns, DataFlowSolver &solver) { - patterns.add(patterns.getContext(), solver); -} - std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { return std::make_unique(); } diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir index 60f0ab41afa48..e00b7692fe396 100644 --- a/mlir/test/Dialect/Arith/int-range-interface.mlir +++ b/mlir/test/Dialect/Arith/int-range-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s +// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s // CHECK-LABEL: func @add_min_max // CHECK: %[[c3:.*]] = arith.constant 3 : index diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index dd62a481a1246..ea5969a100258 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -96,3 +96,39 @@ func.func @test() -> i8 { return %1: i8 } +// ----- + +// CHECK-LABEL: func @trivial_rem +// CHECK: [[val:%.+]] = test.with_bounds +// CHECK: return [[val]] +func.func @trivial_rem() -> i8 { + %c64 = arith.constant 64 : i8 + %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8 + %mod = arith.remsi %val, %c64 : i8 + return %mod : i8 +} + +// ----- + +// CHECK-LABEL: func @non_const_rhs +// CHECK: [[mod:%.+]] = arith.remui +// CHECK: return [[mod]] +func.func @non_const_rhs() -> i8 { + %c64 = arith.constant 64 : i8 + %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8 + %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8 + %mod = arith.remui %val, %rhs : i8 + return %mod : i8 +} + +// ----- + +// CHECK-LABEL: func @wraps +// CHECK: [[mod:%.+]] = arith.remsi +// CHECK: return [[mod]] +func.func @wraps() -> i8 { + %c64 = arith.constant 64 : i8 + %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8 + %mod = arith.remsi %val, %c64 : i8 + return %mod : i8 +} diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir index 980f7e5873e0c..a0917a2fdf110 100644 --- a/mlir/test/Dialect/GPU/int-range-interface.mlir +++ b/mlir/test/Dialect/GPU/int-range-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s +// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s // CHECK-LABEL: func @launch_func func.func @launch_func(%arg0 : index) { diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir index 2784d5fd5cf70..951624d573a64 100644 --- a/mlir/test/Dialect/Index/int-range-inference.mlir +++ b/mlir/test/Dialect/Index/int-range-inference.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s +// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s // Most operations are covered by the `arith` tests, which use the same code // Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir index 2106eeefdca4d..1ec3441b1fde8 100644 --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s +// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s // CHECK-LABEL: func @constant // CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index} @@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index { // CHECK-LABEL: func @propagate_across_while_loop_false() func.func @propagate_across_while_loop_false() -> index { - // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0 - // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1 + // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1 %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } : index %1 = scf.while : () -> index { %false = arith.constant false - // CHECK: scf.condition(%{{.*}}) %[[C0]] scf.condition(%false) %0 : index } do { ^bb0(%i1: index): @@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index { // CHECK-LABEL: func @propagate_across_while_loop func.func @propagate_across_while_loop(%arg0 : i1) -> index { - // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0 - // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1 + // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1 %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } : index %1 = scf.while : () -> index { - // CHECK: scf.condition(%{{.*}}) %[[C0]] scf.condition(%arg0) %0 : index } do { ^bb0(%i1: index): diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 975a41ac3d5fe..66b1faf78e2d8 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp - TestIntRangeInference.cpp TestMakeIsolatedFromAbove.cpp ${MLIRTestTransformsPDLSrc} diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp deleted file mode 100644 index 5758f6acf2f0f..0000000000000 --- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp +++ /dev/null @@ -1,125 +0,0 @@ -//===- TestIntRangeInference.cpp - Create consts from range inference ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// TODO: This pass is needed to test integer range inference until that -// functionality has been integrated into SCCP. -//===----------------------------------------------------------------------===// - -#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/TypeID.h" -#include "mlir/Transforms/FoldUtils.h" -#include - -using namespace mlir; -using namespace mlir::dataflow; - -/// Patterned after SCCP -static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, - OperationFolder &folder, Value value) { - auto *maybeInferredRange = - solver.lookupState(value); - if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) - return failure(); - const ConstantIntRanges &inferredRange = - maybeInferredRange->getValue().getValue(); - std::optional maybeConstValue = inferredRange.getConstantValue(); - if (!maybeConstValue.has_value()) - return failure(); - - Operation *maybeDefiningOp = value.getDefiningOp(); - Dialect *valueDialect = - maybeDefiningOp ? maybeDefiningOp->getDialect() - : value.getParentRegion()->getParentOp()->getDialect(); - Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); - Value constant = folder.getOrCreateConstant( - b.getInsertionBlock(), valueDialect, constAttr, value.getType()); - if (!constant) - return failure(); - - value.replaceAllUsesWith(constant); - return success(); -} - -static void rewrite(DataFlowSolver &solver, MLIRContext *context, - MutableArrayRef initialRegions) { - SmallVector worklist; - auto addToWorklist = [&](MutableArrayRef regions) { - for (Region ®ion : regions) - for (Block &block : llvm::reverse(region)) - worklist.push_back(&block); - }; - - OpBuilder builder(context); - OperationFolder folder(context); - - addToWorklist(initialRegions); - while (!worklist.empty()) { - Block *block = worklist.pop_back_val(); - - for (Operation &op : llvm::make_early_inc_range(*block)) { - builder.setInsertionPoint(&op); - - // Replace any result with constants. - bool replacedAll = op.getNumResults() != 0; - for (Value res : op.getResults()) - replacedAll &= - succeeded(replaceWithConstant(solver, builder, folder, res)); - - // If all of the results of the operation were replaced, try to erase - // the operation completely. - if (replacedAll && wouldOpBeTriviallyDead(&op)) { - assert(op.use_empty() && "expected all uses to be replaced"); - op.erase(); - continue; - } - - // Add any the regions of this operation to the worklist. - addToWorklist(op.getRegions()); - } - - // Replace any block arguments with constants. - builder.setInsertionPointToStart(block); - for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(solver, builder, folder, arg); - } -} - -namespace { -struct TestIntRangeInference - : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) - - StringRef getArgument() const final { return "test-int-range-inference"; } - StringRef getDescription() const final { - return "Test integer range inference analysis"; - } - - void runOnOperation() override { - Operation *op = getOperation(); - DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(); - if (failed(solver.initializeAndRun(op))) - return signalPassFailure(); - rewrite(solver, op->getContext(), op->getRegions()); - } -}; -} // end anonymous namespace - -namespace mlir { -namespace test { -void registerTestIntRangeInference() { - PassRegistration(); -} -} // end namespace test -} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 0e8b161d51345..b50cae1056ba4 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -97,7 +97,6 @@ void registerTestExpandMathPass(); void registerTestFooAnalysisPass(); void registerTestComposeSubView(); void registerTestMultiBuffering(); -void registerTestIntRangeInference(); void registerTestIRVisitorsPass(); void registerTestGenericIRVisitorsPass(); void registerTestInterfaces(); @@ -226,7 +225,6 @@ void registerTestPasses() { mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); - mlir::test::registerTestIntRangeInference(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); From dbccd530cb80df0711e0533f10efe7460a37b5e7 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Fri, 7 Jun 2024 20:54:21 +0000 Subject: [PATCH 2/4] Go to patterns instead --- .../mlir/Dialect/Arith/Transforms/Passes.h | 4 + .../Transforms/IntRangeOptimizations.cpp | 247 ++++++++++-------- 2 files changed, 142 insertions(+), 109 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index b8a7d0c78d323..9dc262cc72ed0 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -64,6 +64,10 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns); /// equivalent. std::unique_ptr createArithUnsignedWhenEquivalentPass(); +/// Add patterns for int range based optimizations. +void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, + DataFlowSolver &solver); + /// Create a pass which do optimizations based on integer range analysis. std::unique_ptr createIntRangeOptimizationsPass(); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index e991d0fbe7410..c3938ed1be15a 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -8,13 +8,17 @@ #include +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS @@ -25,17 +29,22 @@ using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; -/// Patterned after SCCP -static LogicalResult replaceWithConstant(DataFlowSolver &solver, - RewriterBase &rewriter, - OperationFolder &folder, Value value) { +std::optional getMaybeConstantValue(DataFlowSolver &solver, + Value value) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) - return failure(); + return std::nullopt; const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); - std::optional maybeConstValue = inferredRange.getConstantValue(); + return inferredRange.getConstantValue(); +} + +/// Patterned after SCCP +static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, + PatternRewriter &rewriter, + Value value) { + std::optional maybeConstValue = getMaybeConstantValue(solver, value); if (!maybeConstValue.has_value()) return failure(); @@ -45,125 +54,130 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, : value.getParentRegion()->getParentOp()->getDialect(); Attribute constAttr = rewriter.getIntegerAttr(value.getType(), *maybeConstValue); - Value constant = folder.getOrCreateConstant( - rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType()); + Operation *constOp = valueDialect->materializeConstant( + rewriter, constAttr, value.getType(), value.getLoc()); // Fall back to arith.constant if the dialect materializer doesn't know what // to do with an integer constant. - if (!constant) - constant = folder.getOrCreateConstant( - rewriter.getInsertionBlock(), - rewriter.getContext()->getLoadedDialect(), constAttr, - value.getType()); - if (!constant) + if (!constOp) + constOp = rewriter.getContext() + ->getLoadedDialect() + ->materializeConstant(rewriter, constAttr, value.getType(), + value.getLoc()); + if (!constOp) return failure(); - rewriter.replaceAllUsesWith(value, constant); + rewriter.replaceAllUsesWith(value, constOp->getResult(0)); return success(); } +namespace { +class DataFlowListener : public RewriterBase::Listener { +public: + DataFlowListener(DataFlowSolver &s) : s(s) {} + +protected: + void notifyOperationErased(Operation *op) override { + s.eraseState(op); + for (Value res : op->getResults()) + s.eraseState(res); + } + + DataFlowSolver &s; +}; + /// Rewrite any results of `op` that were inferred to be constant integers to /// and replace their uses with that constant. Return success() if all results -/// where thus replaced and the operation is erased. -static LogicalResult foldResultsToConstants(DataFlowSolver &solver, - RewriterBase &rewriter, - OperationFolder &folder, - Operation &op) { - bool replacedAll = op.getNumResults() != 0; - for (Value res : op.getResults()) - replacedAll &= - succeeded(replaceWithConstant(solver, rewriter, folder, res)); - - // If all of the results of the operation were replaced, try to erase - // the operation completely. - if (replacedAll && wouldOpBeTriviallyDead(&op)) { - assert(op.use_empty() && "expected all uses to be replaced"); - rewriter.eraseOp(&op); - return success(); +/// where thus replaced and the operation is erased. Also replace any block +/// arguments with their constant values. +struct MaterializeKnownConstantValues : public RewritePattern { + MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) + : RewritePattern(Pattern::MatchAnyOpTypeTag(), 1, context), solver(s) {} + + LogicalResult match(Operation *op) const override { + if (matchPattern(op, m_Constant())) + return failure(); + + auto needsReplacing = [&](Value v) { + return getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); + }; + bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); + if (op->getNumRegions() == 0) + return success(hasConstantResults); + bool hasConstantRegionArgs = false; + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) { + hasConstantRegionArgs |= + llvm::any_of(block.getArguments(), needsReplacing); + } + } + return success(hasConstantResults || hasConstantRegionArgs); } - return failure(); -} - -/// This function hasn't come from anywhere and is relying on the overall -/// tests of the integer range inference implementation for its correctness. -static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver, - RewriterBase &rewriter, - Operation &op) { - if (!isa(op)) - return failure(); - Value lhs = op.getOperand(0); - Value rhs = op.getOperand(1); - auto rhsConstVal = rhs.getDefiningOp(); - if (!rhsConstVal) - return failure(); - int64_t modulus = rhsConstVal.value(); - if (modulus <= 0) - return failure(); - auto *maybeLhsRange = solver.lookupState(lhs); - if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) - return failure(); - const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); - const APInt &min = llvm::isa(op) ? lhsRange.umin() : lhsRange.smin(); - const APInt &max = llvm::isa(op) ? lhsRange.umax() : lhsRange.smax(); - // The minima and maxima here are given as closed ranges, we must be strictly - // less than the modulus. - if (min.isNegative() || min.uge(modulus)) - return failure(); - if (max.isNegative() || max.uge(modulus)) - return failure(); - if (!min.ule(max)) - return failure(); - // With all those conditions out of the way, we know thas this invocation of - // a remainder is a noop because the input is strictly within the range - // [0, modulus), so get rid of it. - rewriter.replaceOp(&op, ValueRange{lhs}); - return success(); -} + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + bool replacedAll = (op->getNumResults() != 0); + for (Value v : op->getResults()) + replacedAll &= succeeded(maybeReplaceWithConstant(solver, rewriter, v)); + if (replacedAll && isOpTriviallyDead(op)) { + rewriter.eraseOp(op); + return; + } -static void doRewrites(DataFlowSolver &solver, MLIRContext *context, - MutableArrayRef initialRegions) { - SmallVector worklist; - auto addToWorklist = [&](MutableArrayRef regions) { - for (Region ®ion : regions) - for (Block &block : llvm::reverse(region)) - worklist.push_back(&block); - }; - - IRRewriter rewriter(context); - OperationFolder folder(context, rewriter.getListener()); - - addToWorklist(initialRegions); - while (!worklist.empty()) { - Block *block = worklist.pop_back_val(); - - for (Operation &op : llvm::make_early_inc_range(*block)) { - if (matchPattern(&op, m_Constant())) { - if (auto arithConstant = dyn_cast(op)) - folder.insertKnownConstant(&op, arithConstant.getValue()); - else - folder.insertKnownConstant(&op); - continue; + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) { + for (BlockArgument &arg : block.getArguments()) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + (void)maybeReplaceWithConstant(solver, rewriter, arg); + } } - rewriter.setInsertionPoint(&op); - - // Try rewrites. Success means that the underlying operation was erased. - if (succeeded(foldResultsToConstants(solver, rewriter, folder, op))) - continue; - if (isa(op) && - succeeded(deleteTrivialRemainder(solver, rewriter, op))) - continue; - // Add any the regions of this operation to the worklist. - addToWorklist(op.getRegions()); } + } - // Replace any block arguments with constants. - rewriter.setInsertionPointToStart(block); - for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(solver, rewriter, folder, arg); +private: + DataFlowSolver &solver; +}; + +template +struct DeleteTrivialRem : public OpRewritePattern { + DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) + : OpRewritePattern(context), solver(s) {} + + LogicalResult matchAndRewrite(RemOp op, + PatternRewriter &rewriter) const override { + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + auto rhsConstVal = rhs.getDefiningOp(); + if (!rhsConstVal) + return failure(); + int64_t modulus = rhsConstVal.value(); + if (modulus <= 0) + return failure(); + auto *maybeLhsRange = solver.lookupState(lhs); + if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) + return failure(); + const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); + const APInt &min = isa(op) ? lhsRange.umin() : lhsRange.smin(); + const APInt &max = isa(op) ? lhsRange.umax() : lhsRange.smax(); + // The minima and maxima here are given as closed ranges, we must be + // strictly less than the modulus. + if (min.isNegative() || min.uge(modulus)) + return failure(); + if (max.isNegative() || max.uge(modulus)) + return failure(); + if (!min.ule(max)) + return failure(); + + // With all those conditions out of the way, we know thas this invocation of + // a remainder is a noop because the input is strictly within the range + // [0, modulus), so get rid of it. + rewriter.replaceOp(op, ValueRange{lhs}); + return success(); } -} -namespace { +private: + DataFlowSolver &solver; +}; + struct IntRangeOptimizationsPass : public arith::impl::ArithIntRangeOptsBase { @@ -176,11 +190,26 @@ struct IntRangeOptimizationsPass if (failed(solver.initializeAndRun(op))) return signalPassFailure(); - doRewrites(solver, ctx, op->getRegions()); + DataFlowListener listener(solver); + + RewritePatternSet patterns(ctx); + populateIntRangeOptimizationsPatterns(patterns, solver); + + GreedyRewriteConfig config; + config.listener = &listener; + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + signalPassFailure(); } }; } // namespace +void mlir::arith::populateIntRangeOptimizationsPatterns( + RewritePatternSet &patterns, DataFlowSolver &solver) { + patterns.add, + DeleteTrivialRem>(patterns.getContext(), solver); +} + std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { return std::make_unique(); } From 5a1d684e5d986c52b95d7778c6dbc26aaff5cb72 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Sun, 9 Jun 2024 22:38:14 -0700 Subject: [PATCH 3/4] Address review comments --- .../Transforms/IntRangeOptimizations.cpp | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index c3938ed1be15a..8005f9103b235 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -14,6 +14,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -29,8 +30,8 @@ using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; -std::optional getMaybeConstantValue(DataFlowSolver &solver, - Value value) { +static std::optional getMaybeConstantValue(DataFlowSolver &solver, + Value value) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) @@ -44,6 +45,8 @@ std::optional getMaybeConstantValue(DataFlowSolver &solver, static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value) { + if (value.use_empty()) + return failure(); std::optional maybeConstValue = getMaybeConstantValue(solver, value); if (!maybeConstValue.has_value()) return failure(); @@ -91,7 +94,8 @@ class DataFlowListener : public RewriterBase::Listener { /// arguments with their constant values. struct MaterializeKnownConstantValues : public RewritePattern { MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) - : RewritePattern(Pattern::MatchAnyOpTypeTag(), 1, context), solver(s) {} + : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context), + solver(s) {} LogicalResult match(Operation *op) const override { if (matchPattern(op, m_Constant())) @@ -116,17 +120,19 @@ struct MaterializeKnownConstantValues : public RewritePattern { void rewrite(Operation *op, PatternRewriter &rewriter) const override { bool replacedAll = (op->getNumResults() != 0); for (Value v : op->getResults()) - replacedAll &= succeeded(maybeReplaceWithConstant(solver, rewriter, v)); + replacedAll &= + (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) || + v.use_empty()); if (replacedAll && isOpTriviallyDead(op)) { rewriter.eraseOp(op); return; } + PatternRewriter::InsertionGuard guard(rewriter); for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { + rewriter.setInsertionPointToStart(&block); for (BlockArgument &arg : block.getArguments()) { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); (void)maybeReplaceWithConstant(solver, rewriter, arg); } } @@ -146,10 +152,10 @@ struct DeleteTrivialRem : public OpRewritePattern { PatternRewriter &rewriter) const override { Value lhs = op.getOperand(0); Value rhs = op.getOperand(1); - auto rhsConstVal = rhs.getDefiningOp(); - if (!rhsConstVal) + auto maybeModulus = getConstantIntValue(rhs); + if (!maybeModulus.has_value()) return failure(); - int64_t modulus = rhsConstVal.value(); + int64_t modulus = *maybeModulus; if (modulus <= 0) return failure(); auto *maybeLhsRange = solver.lookupState(lhs); From 0dc4182d17880132b38cd820430c43e486d83be7 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 10 Jun 2024 00:35:43 -0700 Subject: [PATCH 4/4] Actually solve the merge conflicts --- mlir/tools/mlir-opt/mlir-opt.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index b3b0bc7e1e1a4..d0de74dd6eaf4 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -228,15 +228,11 @@ void registerTestPasses() { mlir::test::registerTestEmulateNarrowTypePass(); mlir::test::registerTestExpandMathPass(); mlir::test::registerTestFooAnalysisPass(); -<<<<<<< HEAD mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestIRVisitorsPass(); -======= ->>>>>>> main mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); - mlir::test::registerTestIntRangeInference(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgDecomposeOps();