Skip to content

Commit ca1110d

Browse files
[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface
1 parent 6dad1e8 commit ca1110d

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12561256

12571257
def Tensor_PadOp : Tensor_Op<"pad", [
12581258
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1259+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
12591260
AttrSizedOperandSegments,
12601261
Pure,
12611262
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
9898
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
9999
/// Convert `arrayAttr` to a vector of OpFoldResult.
100100
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
101+
// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
102+
// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
103+
// values);
101104

102105
/// Convert int64_t to integer attributes of index type and return them as
103106
/// OpFoldResult.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/Utils.h"
1011
#include "mlir/Dialect/Arith/IR/Arith.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
1213
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3793,6 +3794,30 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
37933794

37943795
} // namespace
37953796

3797+
LogicalResult
3798+
PadOp::reifyResultShapes(OpBuilder &b,
3799+
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3800+
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3801+
SmallVector<OpFoldResult> lp = getMixedLowPad();
3802+
SmallVector<OpFoldResult> hp = getMixedHighPad();
3803+
for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3804+
if (!getType().isDynamicDim(i)) {
3805+
reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3806+
continue;
3807+
}
3808+
Location loc = getLoc();
3809+
Value dim = b.createOrFold<tensor::DimOp>(
3810+
loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
3811+
3812+
affine::AffineBuilder ab(b, loc);
3813+
AffineExpr d0, d1, d2;
3814+
bindDims(b.getContext(), d0, d1, d2);
3815+
reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3816+
b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3817+
}
3818+
return success();
3819+
}
3820+
37963821
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
37973822
MLIRContext *context) {
37983823
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,

mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,20 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
213213
"test.compare"(%dim0, %dim1) : (index, index) -> ()
214214
return
215215
}
216+
217+
// -----
218+
219+
// CHECK-LABEL: func.func @pad_reification
220+
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
221+
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
222+
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
223+
224+
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
225+
^bb0(%a: index, %b: index, %c: index):
226+
tensor.yield %cst : f32
227+
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
228+
229+
// CHECK: arith.constant 256 : index
230+
%1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
231+
return
232+
}

0 commit comments

Comments
 (0)