Skip to content

Commit bfabf8d

Browse files
[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface
1 parent 6dad1e8 commit bfabf8d

File tree

8 files changed

+53
-10
lines changed

8 files changed

+53
-10
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12561256

12571257
def Tensor_PadOp : Tensor_Op<"pad", [
12581258
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1259+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
12591260
AttrSizedOperandSegments,
12601261
Pure,
12611262
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/Utils.h"
1011
#include "mlir/Dialect/Arith/IR/Arith.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
1213
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3793,6 +3794,30 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
37933794

37943795
} // namespace
37953796

3797+
/// Implements ReifyRankedShapedTypeOpInterface for tensor.pad.
///
/// Populates `reifiedReturnShapes[0]` with one OpFoldResult per result
/// dimension: static dims become index attributes; dynamic dims are computed
/// as `source_dim + low_pad + high_pad` via a composed/folded affine.apply.
/// Always returns success().
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // tensor.pad has a single result; pre-size its shape vector to the rank.
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lowPad = getMixedLowPad();
  SmallVector<OpFoldResult> highPad = getMixedHighPad();
  // Loop-invariant; hoisted out of the loop.
  Location loc = getLoc();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      // Static dimension: reify directly as a constant index attribute.
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    // Dynamic dimension: result size = source size + low pad + high pad.
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    // makeComposedFoldedAffineApply folds constant operands, so fully static
    // pad amounts still produce an attribute rather than an affine.apply op.
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lowPad[i], highPad[i]});
  }
  return success();
}
3820+
37963821
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
37973822
MLIRContext *context) {
37983823
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,

mlir/test/Dialect/Linalg/pad_fusion.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ func.func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : in
3434
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
3535
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
3636
// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
37-
// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
37+
// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D0]], %[[ARG1]], %[[ARG3]]]
3838
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
39-
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
39+
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D1]], %[[ARG2]], %[[ARG4]]]
4040
// CHECK: %[[INIT:.+]] = tensor.empty(%[[TARGET_D0]], %[[TARGET_D1]])
4141
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]]
4242
// CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
@@ -80,7 +80,7 @@ func.func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : ind
8080
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
8181
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
8282
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
83-
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
83+
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[SOURCE_D1]], %[[ARG1]], %[[ARG2]]]
8484
// CHECK: %[[INIT:.+]] = tensor.empty(%[[TARGET_D1]])
8585
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]]
8686
// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]

mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,9 @@ func.func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index
268268
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
269269
// CHECK-DAG: %[[C12:.+]] = arith.constant 12 : index
270270
// CHECK: %[[IN_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
271-
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
271+
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[IN_DIM1]], %[[ARG1]]]
272272
// CHECK: %[[IN_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
273-
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
273+
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[IN_DIM2]], %[[ARG2]]]
274274
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
275275

276276
// -----

mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
1010
// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
1111
// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
12-
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
12+
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
1313
// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
1414
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
1515
// CHECK: linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)

mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ module attributes {transform.with_named_sequence} {
119119
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
120120
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
121121
// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
122-
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
122+
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
123123
// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
124124
// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
125125
// CHECK: %[[generic:.*]] = linalg.generic
@@ -162,7 +162,7 @@ module attributes {transform.with_named_sequence} {
162162
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
163163
// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
164164
// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
165-
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
165+
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
166166
// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
167167
// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
168168
// CHECK: %[[filled:.*]] = linalg.fill ins(%[[c50]] : index) outs(%[[empty]] : tensor<?x?xindex>)
@@ -197,7 +197,7 @@ module attributes {transform.with_named_sequence} {
197197
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index, %[[padding:.*]]: index
198198
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
199199
// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t1]], %[[c0]]
200-
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
200+
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[h1]]]
201201
// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
202202
// CHECK: %[[empty:.*]] = tensor.empty(%[[size0]], %[[size1]]) : tensor<?x?xindex>
203203
// CHECK: %[[filled:.*]] = linalg.fill ins(%[[padding]] : index) outs(%[[empty]] : tensor<?x?xindex>)

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
571571
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
572572
// CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
573573
// CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
574-
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[h1]], %[[dim0]]]
574+
// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[dim0]], %[[h1]]]
575575
// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map_2]]()[%[[l2]], %[[h2]]]
576576
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
577577
// CHECK: %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]

mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,20 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
213213
"test.compare"(%dim0, %dim1) : (index, index) -> ()
214214
return
215215
}
216+
217+
// -----
218+
219+
// CHECK-LABEL: func.func @pad_reification
220+
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
221+
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
222+
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
223+
224+
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
225+
^bb0(%a: index, %b: index, %c: index):
226+
tensor.yield %cst : f32
227+
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
228+
229+
// CHECK: arith.constant 256 : index
230+
%1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
231+
return
232+
}

0 commit comments

Comments (0)