-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface #145867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-tensor Author: Nicolas Vasilache (nicolasvasilache) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145867.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..821384eb7d15a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
def Tensor_PadOp : Tensor_Op<"pad", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
AttrSizedOperandSegments,
Pure,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..c66110f6915e9 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
/// Convert `arrayAttr` to a vector of OpFoldResult.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
+// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
+// values);
/// Convert int64_t to integer attributes of index type and return them as
/// OpFoldResult.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 72144ec71c5d2..20f48436b1901 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3793,6 +3794,30 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
} // namespace
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+ SmallVector<OpFoldResult> lp = getMixedLowPad();
+ SmallVector<OpFoldResult> hp = getMixedHighPad();
+ for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+ if (!getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+ continue;
+ }
+ Location loc = getLoc();
+ Value dim = b.createOrFold<tensor::DimOp>(
+ loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+ b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+ }
+ return success();
+}
+
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 6610d3180cf02..a96aa3cda51d3 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -213,3 +213,20 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
"test.compare"(%dim0, %dim1) : (index, index) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
+ %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+ %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+ %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+ ^bb0(%a: index, %b: index, %c: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+ // CHECK: arith.constant 256: index
+ %1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
+ return
+}
|
@llvm/pr-subscribers-mlir Author: Nicolas Vasilache (nicolasvasilache) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145867.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..821384eb7d15a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
def Tensor_PadOp : Tensor_Op<"pad", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
AttrSizedOperandSegments,
Pure,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..c66110f6915e9 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
/// Convert `arrayAttr` to a vector of OpFoldResult.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
+// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
+// values);
/// Convert int64_t to integer attributes of index type and return them as
/// OpFoldResult.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 72144ec71c5d2..20f48436b1901 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3793,6 +3794,30 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
} // namespace
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+ SmallVector<OpFoldResult> lp = getMixedLowPad();
+ SmallVector<OpFoldResult> hp = getMixedHighPad();
+ for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+ if (!getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+ continue;
+ }
+ Location loc = getLoc();
+ Value dim = b.createOrFold<tensor::DimOp>(
+ loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+ b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+ }
+ return success();
+}
+
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 6610d3180cf02..a96aa3cda51d3 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -213,3 +213,20 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
"test.compare"(%dim0, %dim1) : (index, index) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) {
+ %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+ %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+ %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+ ^bb0(%a: index, %b: index, %c: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+ // CHECK: arith.constant 256: index
+ %1 = "test.reify_bound"(%padded) {dim = 1, constant} : (tensor<1x?x64xf32>) -> (index)
+ return
+}
|
4345f4a
to
ca1110d
Compare
ca1110d
to
07b70b6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, modulo my comment and fixing failed tests.
07b70b6
to
bfabf8d
Compare
b7aa8b6
to
da51d56
Compare
da51d56
to
de0bdb9
Compare
Co-authored-by: Fabian Mora <fmora.dev@gmail.com>
de0bdb9
to
4bc1938
Compare
…lvm#145867) Co-authored-by: Fabian Mora <fmora.dev@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this, but I am curious how the dim resolution of pad ops worked before. I dont see any tests of the form https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir that shows something that worked before and works now because the interface was added. All I see are tests with operands swapped around.
…lvm#145867) Co-authored-by: Fabian Mora <fmora.dev@gmail.com>
mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir |
No description provided.