diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..e73df61c96434 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -956,6 +956,69 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      }
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1765,85 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
+      }
+      newLow.push_back(l);
+      newHigh.push_back(h);
+    }
+
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2079,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -1946,6 +2090,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor, %sz0:
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[EXPAND_ARG0]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME: [0, 1], [2, 3]
 // CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
 // CHECK: return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK: return %[[COLLAPSE]]
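For context, a minimal sketch of how the patterns touched by this patch could be driven from a pass. The populate entry points shown are the real MLIR APIs extended above and applyPatternsAndFoldGreedily is the standard greedy driver; the wrapper function runPadReshapeFusionByExpansion and the accept-all control function are illustrative assumptions, not part of the patch:

// Illustrative sketch only (not part of the patch): register the
// expansion-direction reshape-fusion patterns, which now also fold
// tensor.pad ops with tensor.collapse_shape producers, and apply them
// to a fixed point on a function.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void runPadReshapeFusionByExpansion(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // Accept-all control function; a real pass would filter fusion candidates.
  linalg::ControlFusionFn acceptAll = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, acceptAll);
  // The collapsing-direction patterns are populated analogously via
  // linalg::populateFoldReshapeOpsByCollapsingPatterns.
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}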