[mlir][Vector] Add a rewrite pattern for gather over a strided memref #72991

Merged · 6 commits · Nov 30, 2023
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (93 additions, 2 deletions)

@@ -96,6 +96,87 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
}
};

/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
/// : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
/// : memref<300xf32> (...)
/// ```
///
/// At the moment this is effectively limited to reading a 1D vector from a 2D
/// MemRef, but it should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
Value base = op.getBase();

    // TODO: Strided accesses might also come from other ops.
auto subview = base.getDefiningOp<memref::SubViewOp>();
if (!subview)
return failure();

auto sourceType = subview.getSource().getType();

// TODO: Allow ranks > 2.
if (sourceType.getRank() != 2)
return failure();

    // Get the strides from the layout of the subview result.
auto layout = subview.getResult().getType().getLayout();
auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
if (!stridedLayoutAttr)
return failure();

// TODO: Allow the access to be strided in multiple dimensions.
if (stridedLayoutAttr.getStrides().size() != 1)
return failure();

int64_t srcTrailingDim = sourceType.getShape().back();

// Assume that the stride matches the trailing dimension of the source
// memref.
// TODO: Relax this assumption.
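    // E.g. for the `memref<100x3xf32>` source in the example above, only a
    // subview with `strided<[3]>` is handled.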
if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
return failure();

// 1. Collapse the input memref so that it's "flat".
SmallVector<ReassociationIndices> reassoc = {{0, 1}};
Value collapsed = rewriter.create<memref::CollapseShapeOp>(
op.getLoc(), subview.getSource(), reassoc);

    // 2. Generate new gather indices that model the strided access.
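    // Element `i` of the strided view corresponds to element `i * stride` of
    // the collapsed memref, so each gather index is scaled by the stride.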
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
VectorType vType = op.getIndexVec().getType();
Value mulCst = rewriter.create<arith::ConstantOp>(
op.getLoc(), vType, DenseElementsAttr::get(vType, stride));

Value newIdxs =
rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);

// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = rewriter.create<vector::GatherOp>(
op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);

return success();
}
};

/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops. For a single lane,
/// the result looks roughly as follows (an illustrative sketch; the exact IR
/// is exercised by the tests):
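///
/// ```mlir
/// %mask0 = vector.extract %mask[0] : i1 from vector<4xi1>
/// %idx0 = vector.extract %idxs[0] : index from vector<4xindex>
/// %res0 = scf.if %mask0 -> (vector<4xf32>) {
///   %ld = vector.load %base[%idx0] : memref<?xf32>, vector<1xf32>
///   %v = vector.extract %ld[0] : f32 from vector<1xf32>
///   %w = vector.insert %v, %pass_thru [0] : f32 into vector<4xf32>
///   scf.yield %w : vector<4xf32>
/// } else {
///   scf.yield %pass_thru : vector<4xf32>
/// }
/// ```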
@@ -115,6 +196,16 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {

Value condMask = op.getMask();
Value base = op.getBase();

    // vector.load requires the most minor memref dimension to have unit
    // stride, so bail out for strided bases such as
    // `memref<100xf32, strided<[3]>>`.
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1)
return failure();
}
}

Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndexVec());
@@ -168,6 +259,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {

void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}
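
Below is a minimal sketch of how a client pass might apply these patterns. The pass skeleton and its name are assumptions for illustration; `populateVectorGatherLoweringPatterns` and `applyPatternsAndFoldGreedily` are the actual in-tree APIs:

```c++
// Hypothetical standalone pass that lowers vector.gather ops using the
// patterns registered above.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct LowerGatherPass
    : public mlir::PassWrapper<LowerGatherPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Adds FlattenGather, RemoveStrideFromGatherSource and
    // Gather1DToConditionalLoads with the default benefit.
    mlir::vector::populateVectorGatherLoweringPatterns(patterns);
    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```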
mlir/test/Dialect/Vector/vector-gather-lowering.mlir (55 additions)

@@ -151,3 +151,58 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}

// Check that vector.gather of a strided memref is replaced with a
// vector.gather with indices encoding the original strides. Note that multiple
// patterns are run for this example, e.g.:
// 1. "remove stride from gather source"
// 2. "flatten gather"
// However, the main goal is to test Pattern 1 above.
#map = affine_map<()[s0] -> (s0 * 4096)>
func.func @strided_gather(%base : memref<100x3xf32>,
%idxs : vector<4xindex>,
%x : index, %y : index) -> vector<4xf32> {
%c0 = arith.constant 0 : index
%x_1 = affine.apply #map()[%x]
// Strided MemRef
%subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
%mask = arith.constant dense<true> : vector<4xi1>
%pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
// Gather of a strided MemRef
%res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: func.func @strided_gather(
// CHECK-SAME: %[[base:.*]]: memref<100x3xf32>,
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>

// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>

// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>

// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>

// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>

// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>