Skip to content

Commit eb33654

Browse files
[mlir][tensor] Add a PadOp::FoldReifiedShape canonicalization
1 parent 23384cd commit eb33654

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3791,6 +3791,47 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
37913791
}
37923792
};
37933793

3794+
/// Canonicalize tensor.pad by refining its result type with reified shape
/// information: any result dimension whose reified extent is a compile-time
/// constant becomes static. The refined pad is wrapped in a tensor.cast back
/// to the original (less static) type so all existing users remain valid.
struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // A `nofold` pad must be preserved as-is; do not rewrite it.
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    ReifiedRankedShapedTypeDims reifiedResultShapes;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
      return failure();

    // Merge the existing static shape with constants discovered by
    // reification, dimension by dimension.
    RankedTensorType resultType = padOp.getResultType();
    SmallVector<int64_t> newShape;
    for (const auto &[s, ofr] : llvm::zip_equal(resultType.getShape(),
                                                reifiedResultShapes.front())) {
      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
      // Reification does not add static information, just use existing shape.
      if (!maybeCst.has_value()) {
        newShape.push_back(s);
        continue;
      }
      int64_t cst = *maybeCst;
      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
      newShape.push_back(cst);
    }
    // No dimension became more static; nothing to do.
    if (newShape == resultType.getShape())
      return failure();

    // Clone the pad with the refined result type, then cast back to the
    // original type for the benefit of existing users.
    Type newType = RankedTensorType::Builder(resultType).setShape(newShape);
    Operation *newPad = rewriter.clone(*padOp);
    newPad->getResult(0).setType(newType);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, resultType,
                                                newPad->getResult(0));
    return success();
  }
};
3834+
37943835
} // namespace
37953836

37963837
LogicalResult
@@ -3820,7 +3861,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
38203861
MLIRContext *context) {
38213862
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
38223863
FoldOrthogonalPaddings, FoldStaticPadding,
3823-
FoldConsecutiveConstantPadding>(context);
3864+
FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
38243865
}
38253866

38263867
/// Return the padding value of the PadOp if it constant. In this context,

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,3 +2561,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
25612561
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
25622562
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
25632563
// CHECK: return %[[RES]]
2564+
2565+
// -----
2566+
2567+
// CHECK-LABEL: func.func @pad_reification
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
    -> tensor<1x?x64xf32> {
  // Padding amount 256 - %idx, so the padded dim reifies to the constant 256.
  %high_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
  %sliced = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>

  // CHECK: tensor.pad
  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
  %res = tensor.pad %sliced low[0, 0, 0] high[0, %high_amt, 0] {
  ^bb0(%a: index, %b: index, %c: index):
    tensor.yield %cst : f32
  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>

  return %res : tensor<1x?x64xf32>
}

0 commit comments

Comments
 (0)