diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td index 81bab1b0c82f7..2be2d019e1122 100644 --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -189,4 +189,15 @@ def TypeConversionCastShapeDynamicDimsOp : Op]> { + let description = [{ + Indicates that tensor.gather ops should be decomposed into a chain of + tensor.extract_slice and linalg.generic to extract the element from source. + }]; + + let assemblyFormat = "attr-dict"; +} + #endif // TENSOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index ae695e0326ca1..fa73f74d0be66 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -102,6 +102,13 @@ using ControlFoldFn = std::function; void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn); +/// Populates `patterns` with patterns that decompose `tensor.gather` into +/// `tensor.empty` and `linalg.geric`, followed by a chain +/// of `tensor.extract_slice` operations on the inputs. This is intended to be +/// used as a tensor -> linalg lowering that decomposes gather such +/// that it can be bufferized into a sequence of bufferized op. +void populateDecomposeTensorGatherPatterns(RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// // Transform helpers //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 99199252710f9..cb2d01df40b8d 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -143,6 +143,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); } +void transform::ApplyDecomposeTensorGatherPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateDecomposeTensorGatherPatterns(patterns); +} + //===----------------------------------------------------------------------===// // TypeConversionCastTensorShapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt index cc6275fee671a..f1a23e5e3bfbf 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms EmptyOpPatterns.cpp ExtractSliceFromReshapeUtils.cpp FoldTensorSubsetOps.cpp + GatherOpPatterns.cpp IndependenceTransforms.cpp MergeConsecutiveInsertExtractSlicePatterns.cpp PackAndUnpackPatterns.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp new file mode 100644 index 0000000000000..5905ee049228a --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp @@ -0,0 +1,166 @@ +//===- GatherOpPatterns.cpp - Patterns related to tensor.concat lowering --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::tensor; + +namespace { + +/// Decompose `tensor.gather` into `linalg.generic`. +/// +/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>, +/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16> +/// +/// Becomes +/// +/// %empty = tensor.empty() : tensor<1x7x128xf16> +/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, +/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = +/// ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x7x1xindex>) +/// outs(%13 : tensor<1x7x128xf16>) { +/// ^bb0(%in: index, %out: f16): +/// %17 = linalg.index 2 : index +/// %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16> +/// linalg.yield %extracted : f16 +/// } -> tensor<1x7x128xf16> +struct DecomposeTensorGatherOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + SmallVector getDstMixedSizes(PatternRewriter &rewriter, + Location loc, + tensor::GatherOp gatherOp) const { + SmallVector dstSize = + tensor::getMixedSizes(rewriter, loc, gatherOp.getResult()); + SmallVector indexSize = + tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices()); + SmallVector srcSize = + tensor::getMixedSizes(rewriter, loc, gatherOp.getSource()); + SmallVector gatherDims(gatherOp.getGatherDims()); + bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() == + dstSize.size() + gatherDims.size(); + for (size_t i = 0; i < indexSize.size() - 1; i++) { + dstSize[i] = indexSize[i]; + } + auto cnt = 0; + for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) { + while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) { + cnt++; + } + dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end() + ? srcSize[cnt] + : getAsIndexOpFoldResult(rewriter.getContext(), 1); + cnt++; + } + return dstSize; + } + + LogicalResult matchAndRewrite(tensor::GatherOp gatherOp, + PatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(gatherOp); + Location loc = gatherOp.getLoc(); + SmallVector gatherDims(gatherOp.getGatherDims()); + + // create destination tensor for linalg out + RankedTensorType dstType = gatherOp.getResultType(); + Value dstTensor = rewriter.create( + loc, getDstMixedSizes(rewriter, loc, gatherOp), + dstType.getElementType()); + + // split index tensor to create the linalg input + SmallVector indexTensors; + Value originIndexTensor = gatherOp.getIndices(); + SmallVector indexTensorSize = + tensor::getMixedSizes(rewriter, loc, originIndexTensor); + SmallVector indexTensorStride( + indexTensorSize.size(), + getAsIndexOpFoldResult(rewriter.getContext(), 1)); + SmallVector indexTensorOffset( + indexTensorSize.size(), + getAsIndexOpFoldResult(rewriter.getContext(), 0)); + indexTensorSize[indexTensorSize.size() - 1] = + getAsIndexOpFoldResult(rewriter.getContext(), 1); + + for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) { + indexTensorOffset[indexTensorSize.size() - 1] = + getAsIndexOpFoldResult(rewriter.getContext(), cnt); + Value indexTensor = rewriter.create( + loc, originIndexTensor, indexTensorOffset, indexTensorSize, + indexTensorStride); + indexTensors.emplace_back(indexTensor); + } + + // create the affine map + SmallVector affineMaps; + SmallVector dimExprs; + size_t dstRank = dstType.getShape().size(); + for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i) + dimExprs.push_back(rewriter.getAffineDimExpr(i)); + dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext())); + + for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) { + AffineMap currentMap = + AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs, + rewriter.getContext()); + affineMaps.emplace_back(currentMap); + } + affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank)); + + // create iterater types array + SmallVector iteratorTypesArray( + dstRank, utils::IteratorType::parallel); + + // check whether the gather op is valid + size_t srcRank = gatherOp.getSourceType().getShape().size(); + assert(((indexTensorSize.size() - 1) + srcRank == dstRank || + (indexTensorSize.size() - 1) + srcRank == + dstRank + gatherDims.size()) && + "Expected: index_size - 1 + source_size == dst_size or dst_szie - " + "gather_size. \n"); + rewriter.replaceOpWithNewOp( + gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor}, + affineMaps, iteratorTypesArray, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indexValues(srcRank); + bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank == + dstRank + gatherDims.size(); + int cnt = 0; + for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) { + while (isShrinkDst && + llvm::find(gatherDims, cnt) != gatherDims.end()) { + cnt++; + } + indexValues[cnt] = b.create(loc, i); + cnt++; + } + for (auto &&[i, dim] : llvm::enumerate(gatherDims)) { + indexValues[dim] = args[i]; + } + + Value extract = b.create(loc, gatherOp.getSource(), + indexValues); + b.create(loc, extract); + }); + return success(); + } +}; + +} // namespace + +void mlir::tensor::populateDecomposeTensorGatherPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Tensor/decompose-gather.mlir b/mlir/test/Dialect/Tensor/decompose-gather.mlir new file mode 100644 index 0000000000000..587dfc8cc7e2f --- /dev/null +++ b/mlir/test/Dialect/Tensor/decompose-gather.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s + +/// CHECK-LABEL: @gather_single_gather_dim +func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> { + /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32> + /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2x2xf32>) + %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> + return %1 : tensor<2x3x2x2x2xf32> +} + +/// CHECK-LABEL: @gather_single_gather_dim_no_shrink +func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> { + /// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32> + /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY1:.*]] : tensor<2x3x2x1x2x2xf32>) + %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> + return %1 : tensor<2x3x2x1x2x2xf32> +} + +/// CHECK-LABEL: @gather_multiple_gather_dim +func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> { + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32> + /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex> + /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex> + /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2xf32>) + %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> + return %1 : tensor<2x3x2x2xf32> +} + +/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink +func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> { + %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> + return %1 : tensor<2x3x2x1x1x2xf32> +} + +/// CHECK-LABEL: @gather_single_gather_dim_dynamic +func.func @gather_single_gather_dim_dynamic(%arg0: tensor, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> { + /// CHECK: %[[DIM1:.*]] = tensor.dim + /// CHECK: %[[DIM2:.*]] = tensor.dim + /// CHECK: %[[DIM3:.*]] = tensor.dim + /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]], %[[DIM3:.*]]) : tensor<2x3x?x?x?xf32> + /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x?x?x?xf32>) + %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> + return %1 : tensor<2x3x?x?x?xf32> +} + +/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic +func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor) -> tensor { + /// CHECK: %[[DIM1:.*]] = tensor.dim + /// CHECK: %[[DIM2:.*]] = tensor.dim + /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]]) : tensor + /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor to tensor + /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor to tensor + /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor, tensor) outs(%[[EMPTY:.*]] : tensor) + %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor) -> tensor + return %1 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.decompose_gather + } : !transform.op<"func.func"> + transform.yield + } +}