diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156..d6d038ef65bdf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
 // The SparseAssembler pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseAssembler(RewritePatternSet &patterns);
+void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 
 std::unique_ptr<Pass> createSparseAssembler();
+std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c..4706d5ba2f218 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
     sparse tensors as numpy arrays from and to Python. Note that eventual
     bufferization decisions (e.g. who [de]allocates the underlying memory)
     should be resolved in agreement with the external runtime.
+
+    By default, the pass uses the [dis]assemble operations to input and output
+    sparse tensors. When the direct-out option is set, however, the output
+    directly returns the MLIR allocated buffers to the external runtime.
   }];
   let constructor = "mlir::createSparseAssembler()";
   let dependentDialects = [
+    "bufferization::BufferizationDialect",
     "sparse_tensor::SparseTensorDialect",
     "tensor::TensorDialect",
   ];
+  let options = [
+    Option<"directOut", "direct-out", "bool",
+           "false", "Directly returns buffers externally">,
+  ];
 }
 
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a91d32a23cac9..eafbe95b7aebe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -8,6 +8,7 @@
 
 #include "Utils/CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,7 +25,7 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
@@ -32,31 +33,33 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+                                                 SparseTensorFieldKind kind,
+                                                 Level, LevelType) {
+          if (kind == SparseTensorFieldKind::PosMemRef ||
+              kind == SparseTensorFieldKind::CrdMemRef ||
+              kind == SparseTensorFieldKind::ValMemRef) {
+            auto rtp = t.cast<ShapedType>();
+            if (!directOut) {
+              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+              if (extraTypes)
+                extraTypes->push_back(rtp);
+            }
+            convTypes.push_back(rtp);
+          }
+          return true;
+        });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      ValueRange fromVals, ValueRange extraVals,
-                     SmallVectorImpl<Value> &toVals, unsigned extra,
-                     bool isIn) {
+                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+                     bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -73,18 +76,29 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn)
       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
-    // Collect the external representations of the pos/crd arrays.
+    // Collect the external representations of the pos/crd/val arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
-                                                     Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
+                                                     Level lv, LevelType) {
+      if (kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::CrdMemRef ||
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
+        } else if (directOut) {
+          Value mem;
+          if (kind == SparseTensorFieldKind::PosMemRef)
+            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                               lv);
+          else if (kind == SparseTensorFieldKind::CrdMemRef)
+            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc,
+                                                                 inputs[0], lv);
+          else
+            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+          toVals.push_back(mem);
         } else {
-          ShapedType st = t.cast<ShapedType>();
-          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+          ShapedType rtp = t.cast<ShapedType>();
+          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
           cntTypes.push_back(builder.getIndexType());
@@ -97,7 +111,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
       toVals.push_back(a.getResult());
-    } else {
+    } else if (!directOut) {
       // Disassemble a single sparse input into multiple outputs.
       // Note that this includes the counters, which are dropped.
       unsigned len = retTypes.size();
@@ -144,11 +158,14 @@ namespace {
 //   return ..., t1..tn, ...
 // }
 //
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
 //
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  SparseFuncAssembler(MLIRContext *context, bool dO)
+      : OpRewritePattern(context), directOut(dO) {}
+
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
     // Only rewrite public entry methods.
@@ -159,8 +176,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
+    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +209,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert inputs.
     SmallVector<Value> inputs;
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
-             ValueRange(), inputs, 0, /*isIn=*/true);
+             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
 
     // Call the original, now private method. A subsequent inlining pass can
     // determine whether cloning the method body in place is worthwhile.
@@ -203,7 +220,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert outputs and return.
     SmallVector<Value> outputs;
     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
-             body->getArguments(), outputs, extra, /*isIn=*/false);
+             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
     // Finally, migrate a potential c-interface property.
@@ -215,6 +232,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     }
     return success();
   }
+
+private:
+  const bool directOut;
 };
 
 } // namespace
@@ -223,6 +243,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
-  patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+                                   bool directOut) {
+  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c52fa3751e6b4..f0d162bdb84d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -767,6 +767,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 };
 
 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
+/// Note that the current implementation simply exposes the buffers to
+/// the external client. This assumes the client only reads the buffers
+/// (usually copying it to the external data structures, such as numpy
+/// arrays). The semantics of the disassemble operation technically
+/// require that the copying is done here already using the out-levels
+/// and out-values clause.
 class SparseTensorDisassembleConverter
     : public OpConversionPattern<DisassembleOp> {
 public:
@@ -774,9 +780,6 @@ class SparseTensorDisassembleConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // We simply expose the buffers to the external client. This
This - // assumes the client only reads the buffers (usually copying it - // to the external data structures, such as numpy arrays). Location loc = op->getLoc(); auto stt = getSparseTensorType(op.getTensor()); SmallVector retVal; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index acea25f023980..b42d58634a36c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -50,11 +50,12 @@ namespace { struct SparseAssembler : public impl::SparseAssemblerBase { SparseAssembler() = default; SparseAssembler(const SparseAssembler &pass) = default; + SparseAssembler(bool dO) { directOut = dO; } void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseAssembler(patterns); + populateSparseAssembler(patterns, directOut); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir new file mode 100644 index 0000000000000..78c4a295686b3 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s + +// ----- + +// CHECK-LABEL: func.func @sparse_in( +// CHECK-SAME: %[[B:.*0]]: tensor, +// CHECK-SAME: %[[C:.*1]]: tensor, +// CHECK-SAME: %[[A:.*]]: tensor) -> tensor<64x64xf32> { +// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]] +// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]]) +// CHECK: return %[[F]] : tensor<64x64xf32> +// CHECK: } +// CHECK: func.func private @_internal_sparse_in +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32> + return %0 : tensor<64x64xf32> +} + +// ----- + +// CHECK-LABEL: func.func @sparse_out( +// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>) +// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]]) +// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]] +// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]] +// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]] +// CHECK: return %[[P]], %[[C]], %[[V]] +// CHECK: } +// CHECK: func.func private @_internal_sparse_out +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse> + return %0 : tensor<64x64xf32, #sparse> +} + +// ----- + +// CHECK-LABEL: func.func @sparse_out2( +// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>) +// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]]) +// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]]#1 +// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1 +// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]]#1 +// CHECK: return %[[F]]#0, %[[P]], %[[C]], %[[V]] +// CHECK: } +// CHECK: func.func private @_internal_sparse_out2 +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse> + return 
+}
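
For context: besides the mlir-opt invocation exercised by the new test ("--sparse-assembler=direct-out=True"), the new overload declared in Passes.h can also be scheduled programmatically. A minimal sketch, assuming an already parsed ModuleOp; the helper name runDirectOutAssembler is illustrative only and not part of the patch:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"

  // Sketch: run the assembler wrapper pass so that sparse outputs are handed
  // back to the external runtime as the MLIR-allocated pos/crd/val buffers.
  static mlir::LogicalResult runDirectOutAssembler(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    // Overload added by this patch; 'true' selects the direct-out behavior.
    pm.addPass(mlir::createSparseAssembler(/*directOut=*/true));
    return pm.run(module.getOperation());
  }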