Skip to content

Commit 1f4fba7

Browse files
committed
rename transform to reify-result-shapes
1 parent 465c660 commit 1f4fba7

File tree

7 files changed

+221
-154
lines changed

7 files changed

+221
-154
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,40 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
182182
];
183183
}
184184

185-
def InferStaticShapesPass : Pass<"infer-static-shapes"> {
186-
let summary = "Resolve memref.dim of result values";
185+
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
186+
let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
187187
let description = [{
188-
The pass resolves memref.dim of result of operations that
189-
implement the `InferShapedTypeOpInterface` or
190-
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
191-
operands.
188+
This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
189+
operation with ranked `memref` and `tensor` results. Replacing the
190+
operations with their reified versions, and inserting casts when results
191+
shapes are updated.
192+
193+
Example:
194+
```mlir
195+
#map = affine_map<(d0) -> (-d0 + 256)>
196+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
197+
%0 = affine.apply #map(%arg1)
198+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
199+
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
200+
^bb0(%arg3: index, %arg4: index, %arg5: index):
201+
tensor.yield %arg0 : f32
202+
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
203+
return %padded : tensor<1x?x64xf32>
204+
}
205+
206+
// mlir-opt --reify-result-shapes
207+
#map = affine_map<()[s0] -> (-s0 + 256)>
208+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
209+
%0 = affine.apply #map()[%arg1]
210+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
211+
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
212+
^bb0(%arg3: index, %arg4: index, %arg5: index):
213+
tensor.yield %arg0 : f32
214+
} : tensor<1x?x64xf32> to tensor<1x256x64xf32>
215+
%cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
216+
return %cast : tensor<1x?x64xf32>
217+
}
218+
```
192219
}];
193220
let dependentDialects = [
194221
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class RewritePatternSet;
2323
class RewriterBase;
2424
class Value;
2525
class ValueRange;
26+
class ReifyRankedShapedTypeOpInterface;
2627

2728
namespace arith {
2829
class WideIntEmulationConverter;
@@ -57,10 +58,6 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
5758
/// terms of shapes of its input operands.
5859
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
5960

60-
/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
61-
/// shapes more static.
62-
void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
63-
6461
/// Appends patterns for expanding memref operations that modify the metadata
6562
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
6663
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
@@ -213,6 +210,17 @@ memref::AllocaOp allocToAlloca(
213210
RewriterBase &rewriter, memref::AllocOp alloc,
214211
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
215212

213+
/// Reifies the results of `op`, potentially replacing `op` with a reified
214+
/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
215+
/// otherwise it always succeeds. Users of this transform should always expect
216+
/// it to modify the IR, even when it fails. If any of the result types changes,
217+
/// the transform will insert cast operations to the old type to keep the IR
218+
/// consistent.
219+
///
220+
/// Note: This transform only works on ranked `memref` or `tensor` results,
221+
/// other types are ignored.
222+
LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
223+
ReifyRankedShapedTypeOpInterface op);
216224
} // namespace memref
217225
} // namespace mlir
218226

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
1313
IndependenceTransforms.cpp
1414
MultiBuffer.cpp
1515
NormalizeMemRefs.cpp
16+
ReifyResultShapes.cpp
1617
ResolveShapedTypeResultDims.cpp
1718
RuntimeOpVerification.cpp
1819

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
10+
// operations with ranked `memref` and `tensor` results.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/Interfaces/InferTypeOpInterface.h"
21+
#include "llvm/Support/InterleavedRange.h"
22+
23+
#define DEBUG_TYPE "reify-result-shapes"
24+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
25+
26+
namespace mlir {
27+
namespace memref {
28+
#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
29+
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
30+
} // namespace memref
31+
} // namespace mlir
32+
33+
using namespace mlir;
34+
35+
LogicalResult
36+
mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
37+
ReifyRankedShapedTypeOpInterface op) {
38+
LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
39+
// Get the reified out shapes.
40+
ReifiedRankedShapedTypeDims reifiedResultShapes;
41+
if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
42+
reifiedResultShapes.empty()) {
43+
return op.emitError() << "failed to get the reified shapes";
44+
}
45+
46+
bool modified = false;
47+
// Compute the new output types.
48+
SmallVector<Type> outTypes;
49+
for (const auto &[oldTy, reifiedShape] :
50+
llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
51+
// Skip if it's not a memref or tensor type.
52+
if (!isa<RankedTensorType, MemRefType>(oldTy)) {
53+
outTypes.push_back(oldTy);
54+
continue;
55+
}
56+
57+
ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
58+
59+
SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
60+
for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
61+
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
62+
// If the reified dim is dynamic set it appropriately.
63+
if (!maybeCst.has_value()) {
64+
dim = ShapedType::kDynamic;
65+
continue;
66+
}
67+
// Set the static dim.
68+
dim = *maybeCst;
69+
}
70+
71+
// If the shape didn't change continue.
72+
if (shape == shapedTy.getShape()) {
73+
outTypes.push_back(oldTy);
74+
continue;
75+
}
76+
modified = true;
77+
outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
78+
}
79+
80+
// Return if we don't need to update.
81+
if (!modified) {
82+
LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
83+
return success();
84+
}
85+
86+
LLVM_DEBUG({
87+
DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
88+
<< " \n";
89+
DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
90+
});
91+
92+
// We now have outTypes that need to be turned to cast ops.
93+
Location loc = op->getLoc();
94+
SmallVector<Value> newResults;
95+
Operation *newOp = rewriter.clone(*op);
96+
for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
97+
OpResult newRes = newOp->getResult(oldRes.getResultNumber());
98+
Type oldTy = oldRes.getType();
99+
// Continue if the type remained invariant or is not shaped.
100+
if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
101+
newResults.push_back(newRes);
102+
continue;
103+
}
104+
105+
// Update the type.
106+
newRes.setType(reifiedTy);
107+
if (isa<RankedTensorType>(reifiedTy)) {
108+
newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
109+
} else {
110+
assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
111+
newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
112+
}
113+
}
114+
115+
LLVM_DEBUG({
116+
DBGS() << "- reified results " << llvm::interleaved_array(newResults)
117+
<< "\n";
118+
});
119+
rewriter.replaceOp(op, newResults);
120+
return success();
121+
}
122+
123+
//===----------------------------------------------------------------------===//
124+
// Pass registration
125+
//===----------------------------------------------------------------------===//
126+
127+
namespace {
128+
struct ReifyResultShapesPass final
129+
: public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
130+
void runOnOperation() override;
131+
};
132+
} // namespace
133+
134+
void ReifyResultShapesPass::runOnOperation() {
135+
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
136+
getOperation()->walk(
137+
[&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
138+
IRRewriter rewriter(&getContext());
139+
for (ReifyRankedShapedTypeOpInterface op : ops) {
140+
rewriter.setInsertionPoint(op);
141+
if (failed(memref::reifyOpResultShapes(rewriter, op)))
142+
return signalPassFailure();
143+
}
144+
}

0 commit comments

Comments
 (0)