@@ -3791,6 +3791,47 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3791
3791
}
3792
3792
};
3793
3793
3794
+ struct FoldReifiedShape : public OpRewritePattern <tensor::PadOp> {
3795
+ using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3796
+
3797
+ LogicalResult matchAndRewrite (tensor::PadOp padOp,
3798
+ PatternRewriter &rewriter) const override {
3799
+ if (padOp.getNofold ()) {
3800
+ return rewriter.notifyMatchFailure (padOp, " skipping unfoldable pad" );
3801
+ }
3802
+
3803
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
3804
+ if (failed (reifyResultShapes (rewriter, padOp, reifiedResultShapes)))
3805
+ return failure ();
3806
+
3807
+ SmallVector<int64_t > newShape;
3808
+ for (const auto &[s, ofr] : llvm::zip_equal (
3809
+ padOp.getResultType ().getShape (), reifiedResultShapes.front ())) {
3810
+ std::optional<int64_t > maybeCst = getConstantIntValue (ofr);
3811
+ // Reification does not add static information, just use existing shape.
3812
+ if (!maybeCst.has_value ()) {
3813
+ newShape.push_back (s);
3814
+ continue ;
3815
+ }
3816
+ int64_t cst = *maybeCst;
3817
+ assert ((ShapedType::isDynamic (s) || s == cst) && " constants must agree!" );
3818
+ newShape.push_back (cst);
3819
+ }
3820
+ if (newShape == padOp.getResultType ().getShape ())
3821
+ return failure ();
3822
+
3823
+ Type oldType = padOp.getResultType ();
3824
+ Type newType =
3825
+ RankedTensorType::Builder (padOp.getResultType ()).setShape (newShape);
3826
+ Location loc = padOp->getLoc ();
3827
+ Operation *newPad = rewriter.clone (*padOp);
3828
+ newPad->getResult (0 ).setType (newType);
3829
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(padOp, oldType,
3830
+ newPad->getResult (0 ));
3831
+ return success ();
3832
+ }
3833
+ };
3834
+
3794
3835
} // namespace
3795
3836
3796
3837
LogicalResult
@@ -3820,7 +3861,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3820
3861
MLIRContext *context) {
3821
3862
results.add <FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3822
3863
FoldOrthogonalPaddings, FoldStaticPadding,
3823
- FoldConsecutiveConstantPadding>(context);
3864
+ FoldConsecutiveConstantPadding, FoldReifiedShape >(context);
3824
3865
}
3825
3866
3826
3867
// / Return the padding value of the PadOp if it constant. In this context,
0 commit comments