diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index ccd91a928e1dd..248ef9f855b14 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -31,7 +31,7 @@ class XeGPUTypeDef traits = [], } def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", - [ShapedTypeInterface], "::mlir::TensorType"> { + [TensorTypeInterface]> { let summary = "TensorDesc describing regions of interested data."; let description = [{ TensorDesc is a type designed to describe regions of the interested data as well as some @@ -105,7 +105,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", ]; let extraClassDeclaration = [{ - using TensorType::clone; using mlir::ShapedType::Trait::getElementTypeBitWidth; using mlir::ShapedType::Trait::getRank; using mlir::ShapedType::Trait::getNumElements; @@ -115,8 +114,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", using mlir::ShapedType::Trait::getDimSize; using mlir::ShapedType::Trait::getDynamicDimIndex; + TensorDescType cloneWith(std::optional> shape, Type elementType) const; + bool hasRank() const { return true; } + TensorDescType clone(::mlir::Type elementType) { - return llvm::cast(cloneWith(getShape(), elementType)); + return cloneWith(getShape(), elementType); } BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const { @@ -144,7 +146,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return MemorySpace::Global; } - int getArrayLength() { + int getArrayLength() const { auto attr = getEncoding(); auto block_attr = mlir::dyn_cast_if_present(attr); assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr."); @@ -154,7 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return 1; } - bool getBoundaryCheck() { + bool getBoundaryCheck() const { auto attr = getEncoding(); auto block_attr = mlir::dyn_cast_if_present(attr); assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr."); diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 8aa2c55570153..41496ed52d51c 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -143,21 +143,21 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { /// Return the number of elements present in the given shape. static int64_t getNumElements(ArrayRef shape); + }]; + let extraSharedClassDeclaration = [{ /// Return a clone of this type with the given new shape and element type. /// The returned type is ranked, even if this type is unranked. auto clone(::llvm::ArrayRef shape, Type elementType) { - return cloneWith(shape, elementType); + return $_type.cloneWith(shape, elementType); } /// Return a clone of this type with the given new shape. The returned type /// is ranked, even if this type is unranked. auto clone(::llvm::ArrayRef shape) { - return cloneWith(shape, getElementType()); + return $_type.cloneWith(shape, $_type.getElementType()); } - }]; - let extraSharedClassDeclaration = [{ /// Return a clone of this type with the given new element type. The /// returned type is ranked if and only if this type is ranked. In that /// case, the returned type has the same shape as this type. @@ -227,4 +227,76 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { }]; } +//===----------------------------------------------------------------------===// +// TensorTypeInterface +//===----------------------------------------------------------------------===// + +def TensorTypeInterface : TypeInterface<"TensorType", [ShapedTypeInterface]> { + let cppNamespace = "::mlir"; + let description = [{ + This interface provides a shared interface type for ranked, unranked and any + user-specified tensor types. + + This interface attaches the ShapedTypeInterface to act as a mixin to + provide many useful utility functions. + }]; + + let extraClassDeclaration = [{ + /// Return true if the specified element type is ok in a tensor. + static bool isValidElementType(::mlir::Type type); + }]; + + // Note: in trait to apply to derived types. + let extraTraitClassDeclaration = [{ + operator ShapedType() const { return llvm::cast($_type); } + operator TensorType() const { return llvm::cast($_type); } + }]; + + let extraClassOf = [{ + return $_type.hasTrait<::mlir::TensorType::Trait>(); + }]; +} + +//===----------------------------------------------------------------------===// +// BaseMemRefTypeInterface +//===----------------------------------------------------------------------===// + +def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [ShapedTypeInterface]> { + let cppNamespace = "::mlir"; + let description = [{ + This interface provides a shared interface type for ranked, unranked and any + user-specified memref types. + + This interface attaches the ShapedTypeInterface to act as a mixin to + provide many useful utility functions. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns the memory space in which data referred to by this memref resides. + }], + "::mlir::Attribute", "getMemorySpace">, + InterfaceMethod<[{ + [deprecated] Returns the memory space in old raw integer representation. + New `Attribute getMemorySpace()` method should be used instead. + }], + "unsigned", "getMemorySpaceAsInt">, + ]; + + let extraClassDeclaration = [{ + /// Return true if the specified element type is ok in a memref. + static bool isValidElementType(::mlir::Type type); + }]; + + // Note: in trait to apply to derived types. + let extraTraitClassDeclaration = [{ + operator ShapedType() const { return llvm::cast($_type); } + operator BaseMemRefType() const { return llvm::cast($_type); } + }]; + + let extraClassOf = [{ + return $_type.hasTrait<::mlir::BaseMemRefType::Trait>(); + }]; +} + #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index df1e02732617d..4f3365492f720 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -43,108 +43,6 @@ template class ValueSemantics : public TypeTrait::TraitBase {}; -//===----------------------------------------------------------------------===// -// TensorType -//===----------------------------------------------------------------------===// - -/// Tensor types represent multi-dimensional arrays, and have two variants: -/// RankedTensorType and UnrankedTensorType. -/// Note: This class attaches the ShapedType trait to act as a mixin to -/// provide many useful utility functions. This inheritance has no effect -/// on derived tensor types. -class TensorType : public Type, public ShapedType::Trait { -public: - using Type::Type; - - /// Returns the element type of this tensor type. - Type getElementType() const; - - /// Returns if this type is ranked, i.e. it has a known number of dimensions. - bool hasRank() const; - - /// Returns the shape of this tensor type. - ArrayRef getShape() const; - - /// Clone this type with the given shape and element type. If the - /// provided shape is `std::nullopt`, the current shape of the type is used. - TensorType cloneWith(std::optional> shape, - Type elementType) const; - - // Make sure that base class overloads are visible. - using ShapedType::Trait::clone; - - /// Return a clone of this type with the given new shape and element type. - /// The returned type is ranked, even if this type is unranked. - RankedTensorType clone(ArrayRef shape, Type elementType) const; - - /// Return a clone of this type with the given new shape. The returned type - /// is ranked, even if this type is unranked. - RankedTensorType clone(ArrayRef shape) const; - - /// Return true if the specified element type is ok in a tensor. - static bool isValidElementType(Type type); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type); - - /// Allow implicit conversion to ShapedType. - operator ShapedType() const { return llvm::cast(*this); } -}; - -//===----------------------------------------------------------------------===// -// BaseMemRefType -//===----------------------------------------------------------------------===// - -/// This class provides a shared interface for ranked and unranked memref types. -/// Note: This class attaches the ShapedType trait to act as a mixin to -/// provide many useful utility functions. This inheritance has no effect -/// on derived memref types. -class BaseMemRefType : public Type, public ShapedType::Trait { -public: - using Type::Type; - - /// Returns the element type of this memref type. - Type getElementType() const; - - /// Returns if this type is ranked, i.e. it has a known number of dimensions. - bool hasRank() const; - - /// Returns the shape of this memref type. - ArrayRef getShape() const; - - /// Clone this type with the given shape and element type. If the - /// provided shape is `std::nullopt`, the current shape of the type is used. - BaseMemRefType cloneWith(std::optional> shape, - Type elementType) const; - - // Make sure that base class overloads are visible. - using ShapedType::Trait::clone; - - /// Return a clone of this type with the given new shape and element type. - /// The returned type is ranked, even if this type is unranked. - MemRefType clone(ArrayRef shape, Type elementType) const; - - /// Return a clone of this type with the given new shape. The returned type - /// is ranked, even if this type is unranked. - MemRefType clone(ArrayRef shape) const; - - /// Return true if the specified element type is ok in a memref. - static bool isValidElementType(Type type); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type); - - /// Returns the memory space in which data referred to by this memref resides. - Attribute getMemorySpace() const; - - /// [deprecated] Returns the memory space in old raw integer representation. - /// New `Attribute getMemorySpace()` method should be used instead. - unsigned getMemorySpaceAsInt() const; - - /// Allow implicit conversion to ShapedType. - operator ShapedType() const { return llvm::cast(*this); } -}; - } // namespace mlir //===----------------------------------------------------------------------===// @@ -390,10 +288,6 @@ class FixedVectorType : public VectorType { // Deferred Method Definitions //===----------------------------------------------------------------------===// -inline bool BaseMemRefType::classof(Type type) { - return llvm::isa(type); -} - inline bool BaseMemRefType::isValidElementType(Type type) { return type.isIntOrIndexOrFloat() || llvm::isa( @@ -401,10 +295,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) { llvm::isa(type); } -inline bool TensorType::classof(Type type) { - return llvm::isa(type); -} - //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index af474b3e3ec47..575ae6a263b1b 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -542,8 +542,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> { //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ - ShapedTypeInterface - ], "BaseMemRefType"> { + BaseMemRefTypeInterface + ]> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -794,7 +794,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ "unsigned":$memorySpaceInd)> ]; let extraClassDeclaration = [{ - using BaseMemRefType::clone; + using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; using ShapedType::Trait::getRank; using ShapedType::Trait::getNumElements; @@ -854,6 +854,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ /// Return "true" if the last dimension has a static unit stride. Also /// return "true" for types with no strides. bool isLastDimUnitStride(); + + /// Returns if this type is ranked (always true). + bool hasRank() const { return true; } + + /// Returns a clone of this type with the given shape and element + /// type. If a shape is not provided, the current shape of the type is used. + MemRefType cloneWith(std::optional> shape, Type elementType) const; }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; @@ -934,8 +941,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> { //===----------------------------------------------------------------------===// def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ - ShapedTypeInterface, ValueSemantics - ], "TensorType"> { + TensorTypeInterface, ValueSemantics + ]> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -1016,7 +1023,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ }]> ]; let extraClassDeclaration = [{ - using TensorType::clone; + using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; using ShapedType::Trait::getRank; using ShapedType::Trait::getNumElements; @@ -1033,7 +1040,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ /// Return a clone of this type with the given new element type and the same /// shape as this type. RankedTensorType clone(::mlir::Type elementType) { - return ::llvm::cast(cloneWith(getShape(), elementType)); + return cloneWith(getShape(), elementType); } /// Return a clone of this type without the encoding. @@ -1041,6 +1048,13 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ return RankedTensorType::get(getShape(), getElementType()); } + /// Returns if this type is ranked (always true). + bool hasRank() const { return true; } + + /// Returns a clone of this type with the given shape and element + /// type. If a shape is not provided, the current shape of the type is used. + RankedTensorType cloneWith(std::optional> shape, Type elementType) const; + /// Return a clone of this type with the given new encoding and the same /// shape and element type as this type. RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) { @@ -1123,8 +1137,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ - ShapedTypeInterface - ], "BaseMemRefType"> { + BaseMemRefTypeInterface + ]> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -1170,7 +1184,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ }]> ]; let extraClassDeclaration = [{ - using BaseMemRefType::clone; + using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; using ShapedType::Trait::getRank; using ShapedType::Trait::getNumElements; @@ -1186,11 +1200,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; - /// Return a clone of this type with the given new element type and the same - /// shape as this type. - MemRefType clone(::mlir::Type elementType) { - return ::llvm::cast(cloneWith(getShape(), elementType)); - } + /// Returns if this type is ranked (always false). + bool hasRank() const { return false; } + + /// Returns a clone of this type with the given shape and element + /// type. If a shape is not provided, the current shape of the type is used. + BaseMemRefType cloneWith(std::optional> shape, Type elementType) const; }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; @@ -1201,8 +1216,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ //===----------------------------------------------------------------------===// def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ - ShapedTypeInterface, ValueSemantics - ], "TensorType"> { + TensorTypeInterface, ValueSemantics + ]> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: @@ -1229,7 +1244,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ }]> ]; let extraClassDeclaration = [{ - using TensorType::clone; + using ShapedType::Trait::clone; using ShapedType::Trait::getElementTypeBitWidth; using ShapedType::Trait::getRank; using ShapedType::Trait::getNumElements; @@ -1240,6 +1255,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ using ShapedType::Trait::getDynamicDimIndex; ArrayRef getShape() const { return std::nullopt; } + + /// Returns if this type is ranked (always false). + bool hasRank() const { return false; } + + /// Returns a clone of this type with the given shape and element + /// type. If a shape is not provided, the current shape of the type is used. + TensorType cloneWith(std::optional> shape, Type elementType) const; }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 5f23a33049f87..1d1bcee8600a8 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -41,7 +41,7 @@ TensorType inferReshapeInputType(TypedValue input, // 0D tensor. While such construct is not incorrect on its own, bufferization // cannot properly handle it at the moment, so we avoid it. SmallVector shape(input.getType().getRank(), 1); - return input.getType().clone(shape); + return mlir::cast(input.getType().clone(shape)); } // Infer the result type of 'tensor.expand_shape' in the collapse-expand @@ -51,7 +51,7 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Special case for 0D output tensor. Note: Watch out when using Type::clone() // with just '{}', as it will invoke the incorrect overload. if (newShape.empty()) - return inputType.clone(ArrayRef{}); + return mlir::cast(inputType.clone(ArrayRef{})); // Check if the input is static, and if so, get its total size bool inputIsStatic = inputType.hasStaticShape(); @@ -98,7 +98,7 @@ TensorType inferReshapeExpandedType(TensorType inputType, assert(!inputIsStatic || resultIsStatic); // Create result type - return inputType.clone(resultShape); + return mlir::cast(inputType.clone(resultShape)); } // Infer the result type of 'tensor.collapse_shape' in the collapse-expand @@ -108,11 +108,11 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) { auto rhsShape = rhsType.getShape(); if (lhsShape.empty() || rhsShape.empty()) - return lhsType.clone(ArrayRef{}); + return mlir::cast(lhsType.clone(ArrayRef{})); if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape)) - return lhsType.clone({ShapedType::kDynamic}); + return mlir::cast(lhsType.clone({ShapedType::kDynamic})); SmallVector intermediateShape; unsigned currLhsDim = 0, currRhsDim = 0; @@ -149,7 +149,7 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) { assert(rhsShape[currRhsDim] == 1); } - return lhsType.clone(intermediateShape); + return mlir::cast(lhsType.clone(intermediateShape)); } SmallVector diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 99ffa62c41a4d..e15f81a1ef433 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -735,8 +735,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, if (llvm::isa(opResult.getType())) { // The OpResult is a tensor. Such values are replaced with memrefs during // bufferization. - assert((llvm::isa(replacement.getType()) || - llvm::isa(replacement.getType())) && + assert(llvm::isa(replacement.getType()) && "tensor op result should be replaced with a memref value"); // The existing uses of the OpResult still expect a tensor. Insert a // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 4ac6eca586961..0965f61279703 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -77,9 +77,9 @@ struct CastOpInterface // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not // change. auto rankedResultType = cast(castOp.getType()); - return MemRefType::get( + return BaseMemRefType(MemRefType::get( rankedResultType.getShape(), rankedResultType.getElementType(), - llvm::cast(*maybeSrcBufferType).getLayout(), memorySpace); + llvm::cast(*maybeSrcBufferType).getLayout(), memorySpace)); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -157,8 +157,8 @@ struct CollapseShapeOpInterface tensorResultType, srcBufferType.getMemorySpace()); } - return memref::CollapseShapeOp::computeCollapsedType( - srcBufferType, collapseShapeOp.getReassociationIndices()); + return BaseMemRefType(memref::CollapseShapeOp::computeCollapsedType( + srcBufferType, collapseShapeOp.getReassociationIndices())); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -325,7 +325,7 @@ struct ExpandShapeOpInterface expandShapeOp.getReassociationIndices()); if (failed(maybeResultType)) return failure(); - return *maybeResultType; + return BaseMemRefType(*maybeResultType); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -405,10 +405,10 @@ struct ExtractSliceOpInterface SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); - return memref::SubViewOp::inferRankReducedResultType( + return BaseMemRefType(memref::SubViewOp::inferRankReducedResultType( extractSliceOp.getType().getShape(), llvm::cast(*srcMemrefType), mixedOffsets, mixedSizes, - mixedStrides); + mixedStrides)); } }; @@ -746,9 +746,10 @@ struct PadOpInterface if (failed(maybeSrcBufferType)) return failure(); MemRefLayoutAttrInterface layout; - return MemRefType::get(padOp.getResultType().getShape(), - padOp.getResultType().getElementType(), layout, - maybeSrcBufferType->getMemorySpace()); + return BaseMemRefType( + MemRefType::get(padOp.getResultType().getShape(), + padOp.getResultType().getElementType(), layout, + maybeSrcBufferType->getMemorySpace())); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 78c242571935c..1fcf9df052ae3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -396,6 +396,13 @@ FailureOr TensorDescType::getDistributedVectorType() { getElementType()); } +TensorDescType TensorDescType::cloneWith(std::optional> shape, + Type elementType) const { + return TensorDescType::get(shape.value_or(this->getShape()), elementType, + this->getArrayLength(), this->getBoundaryCheck(), + this->getMemorySpace(), this->getSgMap()); +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 3924d082f0628..02e7038a75fff 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TensorEncoding.h" @@ -256,45 +257,6 @@ VectorType VectorType::cloneWith(std::optional> shape, // TensorType //===----------------------------------------------------------------------===// -Type TensorType::getElementType() const { - return llvm::TypeSwitch(*this) - .Case( - [](auto type) { return type.getElementType(); }); -} - -bool TensorType::hasRank() const { - return !llvm::isa(*this); -} - -ArrayRef TensorType::getShape() const { - return llvm::cast(*this).getShape(); -} - -TensorType TensorType::cloneWith(std::optional> shape, - Type elementType) const { - if (llvm::dyn_cast(*this)) { - if (shape) - return RankedTensorType::get(*shape, elementType); - return UnrankedTensorType::get(elementType); - } - - auto rankedTy = llvm::cast(*this); - if (!shape) - return RankedTensorType::get(rankedTy.getShape(), elementType, - rankedTy.getEncoding()); - return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, - rankedTy.getEncoding()); -} - -RankedTensorType TensorType::clone(::llvm::ArrayRef shape, - Type elementType) const { - return ::llvm::cast(cloneWith(shape, elementType)); -} - -RankedTensorType TensorType::clone(::llvm::ArrayRef shape) const { - return ::llvm::cast(cloneWith(shape, getElementType())); -} - // Check if "elementType" can be an element type of a tensor. static LogicalResult checkTensorElementType(function_ref emitError, @@ -317,6 +279,12 @@ bool TensorType::isValidElementType(Type type) { //===----------------------------------------------------------------------===// // RankedTensorType //===----------------------------------------------------------------------===// +RankedTensorType +RankedTensorType::cloneWith(std::optional> shape, + Type elementType) const { + return RankedTensorType::get(shape.value_or(this->getShape()), elementType, + this->getEncoding()); +} LogicalResult RankedTensorType::verify(function_ref emitError, @@ -335,6 +303,13 @@ RankedTensorType::verify(function_ref emitError, // UnrankedTensorType //===----------------------------------------------------------------------===// +TensorType UnrankedTensorType::cloneWith(std::optional> shape, + Type elementType) const { + if (shape) + return RankedTensorType::get(*shape, elementType); + return UnrankedTensorType::get(elementType); +} + LogicalResult UnrankedTensorType::verify(function_ref emitError, Type elementType) { @@ -342,65 +317,18 @@ UnrankedTensorType::verify(function_ref emitError, } //===----------------------------------------------------------------------===// -// BaseMemRefType +// MemRefType //===----------------------------------------------------------------------===// -Type BaseMemRefType::getElementType() const { - return llvm::TypeSwitch(*this) - .Case( - [](auto type) { return type.getElementType(); }); -} - -bool BaseMemRefType::hasRank() const { - return !llvm::isa(*this); -} - -ArrayRef BaseMemRefType::getShape() const { - return llvm::cast(*this).getShape(); -} - -BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, - Type elementType) const { - if (llvm::dyn_cast(*this)) { - if (!shape) - return UnrankedMemRefType::get(elementType, getMemorySpace()); - MemRefType::Builder builder(*shape, elementType); - builder.setMemorySpace(getMemorySpace()); - return builder; - } - - MemRefType::Builder builder(llvm::cast(*this)); +MemRefType MemRefType::cloneWith(std::optional> shape, + Type elementType) const { + MemRefType::Builder builder(*this); if (shape) builder.setShape(*shape); builder.setElementType(elementType); - return builder; -} - -MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape, - Type elementType) const { - return ::llvm::cast(cloneWith(shape, elementType)); -} - -MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape) const { - return ::llvm::cast(cloneWith(shape, getElementType())); + return MemRefType(builder); } -Attribute BaseMemRefType::getMemorySpace() const { - if (auto rankedMemRefTy = llvm::dyn_cast(*this)) - return rankedMemRefTy.getMemorySpace(); - return llvm::cast(*this).getMemorySpace(); -} - -unsigned BaseMemRefType::getMemorySpaceAsInt() const { - if (auto rankedMemRefTy = llvm::dyn_cast(*this)) - return rankedMemRefTy.getMemorySpaceAsInt(); - return llvm::cast(*this).getMemorySpaceAsInt(); -} - -//===----------------------------------------------------------------------===// -// MemRefType -//===----------------------------------------------------------------------===// - std::optional> mlir::computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape, @@ -888,6 +816,17 @@ bool MemRefType::isLastDimUnitStride() { // UnrankedMemRefType //===----------------------------------------------------------------------===// +BaseMemRefType +UnrankedMemRefType::cloneWith(std::optional> shape, + Type elementType) const { + if (!shape) + return UnrankedMemRefType::get(elementType, getMemorySpace()); + + MemRefType::Builder builder(*shape, elementType); + builder.setMemorySpace(getMemorySpace()); + return MemRefType(builder); +} + unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { return detail::getMemorySpaceAsInt(getMemorySpace()); } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index f1c31658c13ac..61fab9c889be7 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -403,4 +403,50 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", let mnemonic = "op_asm_type_interface"; } +def TestTensorType : Test_Type<"TestTensor", [TensorTypeInterface]> { + let mnemonic = "test_tensor"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestTensorType cloneWith(std::optional> shape, mlir::Type elementType) const { + return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType); + } + }]; +} + +def TestMemrefType : Test_Type<"TestMemref", [BaseMemRefTypeInterface]> { + let mnemonic = "test_memref"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType, + DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestMemrefType cloneWith(std::optional> shape, mlir::Type elementType) const { + return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace()); + } + + // BaseMemRefTypeInterface: + mlir::Attribute getMemorySpace() const { + return getMemSpace(); + } + // [deprecated] + unsigned getMemorySpaceAsInt() const { return 0; } + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 42196b003e7da..f8b4cf026d7b9 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" #include "gtest/gtest.h" @@ -84,3 +85,50 @@ TEST(InterfaceTest, TestImplicitConversion) { typeA = typeB; EXPECT_EQ(typeA, typeB); } + +TEST(InterfaceTest, TestCustomTensorIsTensorType) { + MLIRContext context; + context.loadDialect(); + + auto customTensorType = test::TestTensorType::get( + &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32)); + EXPECT_TRUE(mlir::isa(customTensorType)); + + auto customCloneType = customTensorType.clone({3, 4, 5}); + EXPECT_EQ(customTensorType.getElementType(), + customCloneType.getElementType()); + EXPECT_TRUE(mlir::isa(customCloneType)); + EXPECT_TRUE(mlir::isa(customCloneType)); + + // user-specified conversions + TensorType baseCopy = customTensorType; + std::ignore = baseCopy; + + ShapedType shapedBaseCopy = customTensorType; + std::ignore = shapedBaseCopy; +} + +TEST(InterfaceTest, TestCustomMemrefIsBaseMemref) { + MLIRContext context; + context.loadDialect(); + + auto customMemrefType = test::TestMemrefType::get( + &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32), + mlir::StringAttr::get(&context, "some_memspace")); + EXPECT_TRUE(mlir::isa(customMemrefType)); + + auto customCloneType = customMemrefType.clone({3, 4, 5}); + EXPECT_EQ(customMemrefType.getElementType(), + customCloneType.getElementType()); + EXPECT_TRUE(mlir::isa(customCloneType)); + EXPECT_TRUE(mlir::isa(customCloneType)); + EXPECT_EQ(customMemrefType.getMemorySpace(), + mlir::cast(customCloneType).getMemorySpace()); + + // user-specified conversions + BaseMemRefType baseCopy = customMemrefType; + std::ignore = baseCopy; + + ShapedType shapedBaseCopy = customMemrefType; + std::ignore = shapedBaseCopy; +}