diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h new file mode 100644 index 0000000000000..f6b296eccd748 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -0,0 +1,18 @@ +//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ + +//===----------------------------------------------------------------------===// +// Bufferization Type Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc" + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td new file mode 100644 index 0000000000000..f19224a295648 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -0,0 +1,42 @@ +//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for type interfaces used in Bufferization. +// +//===----------------------------------------------------------------------===// + +#ifndef BUFFERIZATION_TYPE_INTERFACES +#define BUFFERIZATION_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +def Bufferization_TensorLikeTypeInterface + : TypeInterface<"TensorLikeType"> { + let cppNamespace = "::mlir::bufferization"; + let description = [{ + Indicates that this type is a tensor type (similarly to a MLIR builtin + tensor) for bufferization purposes. + + The interface currently has no methods as it is used by types to opt into + being supported by the bufferization procedures. + }]; +} + +def Bufferization_BufferLikeTypeInterface + : TypeInterface<"BufferLikeType"> { + let cppNamespace = "::mlir::bufferization"; + let description = [{ + Indicates that this type is a buffer type (similarly to a MLIR builtin + memref) for bufferization purposes. + + The interface currently has no methods as it is used by types to opt into + being supported by the bufferization procedures. + }]; +} + +#endif // BUFFERIZATION_TYPE_INTERFACES diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt index 13a5bc370a4fc..3ead52148c208 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls) mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRBufferizationEnumsIncGen) add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td) +mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen) +add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index f53f569070f09..ee33476f441ee 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> { Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place", "Number of out-of-place tensor OpOperands">, ]; + + let dependentDialects = [ + "bufferization::BufferizationDialect", "memref::MemRefDialect" + ]; } def PromoteBuffersToStackPass diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index e5a0c3c45b09e..6b9253a5d71da 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Transforms/InliningUtils.h" @@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface { return true; } }; + +template +struct BuiltinTensorExternalModel + : TensorLikeType::ExternalModel, + Tensor> {}; + +template +struct BuiltinMemRefExternalModel + : BufferLikeType::ExternalModel, + MemRef> {}; } // namespace //===----------------------------------------------------------------------===// @@ -63,6 +75,20 @@ void mlir::bufferization::BufferizationDialect::initialize() { #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" >(); addInterfaces(); + + // Note: Unlike with other external models, declaring bufferization's + // "promised interfaces" in builtins for TensorLike and BufferLike type + // interfaces is not possible (due to builtins being independent of + // bufferization). Thus, the compromise is to attach these interfaces directly + // during dialect initialization. + RankedTensorType::attachInterface< + BuiltinTensorExternalModel>(*getContext()); + UnrankedTensorType::attachInterface< + BuiltinTensorExternalModel>(*getContext()); + MemRefType::attachInterface>( + *getContext()); + UnrankedMemRefType::attachInterface< + BuiltinMemRefExternalModel>(*getContext()); } LogicalResult BufferizationDialect::verifyRegionArgAttribute( diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index e97b34b20ff72..0b60c44ece5fd 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -57,11 +57,6 @@ struct OneShotBufferizePass OneShotBufferizePass> { using Base::Base; - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - void runOnOperation() override { OneShotBufferizationOptions opt; if (!options) { diff --git a/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir new file mode 100644 index 0000000000000..f8691e110aad1 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -test-tensorlike-bufferlike -split-input-file | FileCheck %s + +// CHECK: func.func @builtin_unranked +// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}} +func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>) +{ + %0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32> + return %0 : memref<*xf32> +} + +// ----- + +// CHECK: func.func @builtin_ranked +// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}} +func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>) +{ + %0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32> + return %0 : memref<42xf32> +} + +// ----- + +// CHECK: func.func @custom_tensor +// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}} +func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> () +{ + return +} + +// ----- + +// CHECK: func.func @custom_memref +// CHECK-SAME: {found = {operand_0 = "is_buffer_like"}} +func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> () +{ + return +} diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index c14a9f2cc9bb0..226e0bb97732d 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses TestTensorCopyInsertion.cpp + TestTensorLikeAndBufferLike.cpp EXCLUDE_FROM_LIBMLIR ) @@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC MLIRBufferizationTransforms MLIRIR MLIRPass + MLIRTestDialect ) + +target_include_directories(MLIRBufferizationTestPasses + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp new file mode 100644 index 0000000000000..60e60849f3e6c --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp @@ -0,0 +1,99 @@ +//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +#include + +using namespace mlir; + +namespace { +std::string getImplementationStatus(Type type) { + if (isa(type)) { + return "is_tensor_like"; + } + if (isa(type)) { + return "is_buffer_like"; + } + return {}; +} + +DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) { + llvm::SmallVector attributes; + + const auto funcType = funcOp.getFunctionType(); + for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) { + const auto status = getImplementationStatus(inputType); + if (status.empty()) { + continue; + } + + attributes.push_back( + NamedAttribute(StringAttr::get(funcOp.getContext(), + "operand_" + std::to_string(index)), + StringAttr::get(funcOp.getContext(), status))); + } + + for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) { + const auto status = getImplementationStatus(resultType); + if (status.empty()) { + continue; + } + + attributes.push_back(NamedAttribute( + StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)), + StringAttr::get(funcOp.getContext(), status))); + } + + return mlir::DictionaryAttr::get(funcOp.getContext(), attributes); +} + +/// This pass tests whether specified types implement TensorLike and (or) +/// BufferLike type interfaces defined in bufferization. +/// +/// The pass analyses operation signature. When the aforementioned interface +/// implementation found, an attribute is added to the operation, signifying the +/// associated operand / result. +struct TestTensorLikeAndBufferLikePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-tensorlike-bufferlike"; } + StringRef getDescription() const final { + return "Module pass to test custom types that implement TensorLike / " + "BufferLike interfaces"; + } + + void runOnOperation() override { + auto op = getOperation(); + + op.walk([](func::FuncOp funcOp) { + const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp); + if (!dict.empty()) { + funcOp->setAttr("found", dict); + } + }); + } +}; +} // namespace + +namespace mlir::test { +void registerTestTensorLikeAndBufferLikePass() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index a48ac24ca056d..6e608e4772391 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC MLIRTransformUtils MLIRTransforms MLIRValueBoundsOpInterface + MLIRBufferizationDialect ) add_mlir_translation_library(MLIRTestFromLLVMIRTranslation diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index f1c31658c13ac..e9785594d3332 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -19,6 +19,7 @@ include "TestAttrDefs.td" include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" // All of the types will extend this class. class Test_Type traits = []> @@ -403,4 +404,49 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", let mnemonic = "op_asm_type_interface"; } +def TestTensorType : Test_Type<"TestTensor", + [Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> { + let mnemonic = "test_tensor"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestTensorType cloneWith(std::optional> shape, + mlir::Type elementType) const { + return test::TestTensorType::get( + getContext(), shape.value_or(getShape()), elementType); + } + }]; +} + +def TestMemrefType : Test_Type<"TestMemref", + [Bufferization_BufferLikeTypeInterface, ShapedTypeInterface]> { + let mnemonic = "test_memref"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType, + DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestMemrefType cloneWith(std::optional> shape, + mlir::Type elementType) const { + return test::TestMemrefType::get( + getContext(), shape.value_or(getShape()), elementType, getMemSpace()); + } + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index cef3f056a7986..6499a96f495d0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -18,6 +18,7 @@ #include #include "TestTraits.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d06ff8070e7cf..3a5019fe8ee54 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -149,6 +149,7 @@ void registerTestSPIRVCPURunnerPipeline(); void registerTestSPIRVFuncSignatureConversion(); void registerTestSPIRVVectorUnrolling(); void registerTestTensorCopyInsertionPass(); +void registerTestTensorLikeAndBufferLikePass(); void registerTestTensorTransforms(); void registerTestTopologicalSortAnalysisPass(); void registerTestTransformDialectEraseSchedulePass(); @@ -291,6 +292,7 @@ void registerTestPasses() { mlir::test::registerTestSPIRVFuncSignatureConversion(); mlir::test::registerTestSPIRVVectorUnrolling(); mlir::test::registerTestTensorCopyInsertionPass(); + mlir::test::registerTestTensorLikeAndBufferLikePass(); mlir::test::registerTestTensorTransforms(); mlir::test::registerTestTopologicalSortAnalysisPass(); mlir::test::registerTestTransformDialectEraseSchedulePass();