diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 4d91d98a9ba89..e46b6a4a6bb69 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -91,23 +91,12 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - if (llvm::isa(*this)) - return 6; - if (llvm::isa(*this)) - return 8; - if (llvm::isa(*this)) - return 16; - if (llvm::isa(*this)) + // The actual width of TF32 is 19 bits. However, since it is a truncated + // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth + // for compatibility. + if (llvm::isa(*this)) return 32; - if (llvm::isa(*this)) - return 64; - if (llvm::isa(*this)) - return 80; - if (llvm::isa(*this)) - return 128; - llvm_unreachable("unexpected float type"); + return APFloat::semanticsSizeInBits(getFloatSemantics()); } /// Returns the floating semantics for the given type.