diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index c96e4217d21f0..f8900f3434cca 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -368,7 +368,10 @@ LLT getCoverTy(LLT OrigTy, LLT TargetTy); /// If these are vectors with different element types, this will try to produce /// a vector with a compatible total size, but the element type of \p OrigTy. If /// this can't be satisfied, this will produce a scalar smaller than the -/// original vector elements. +/// original vector elements. It is an error to call this function where +/// one argument is a fixed vector and the other is a scalable vector, since it +/// is illegal to build a G_{MERGE|UNMERGE}_VALUES between fixed and scalable +/// vectors. /// /// In the worst case, this returns LLT::scalar(1) LLVM_READNONE diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index dd99381093b6a..26fd12f9e51c4 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1159,45 +1159,56 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) { } LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) { - const unsigned OrigSize = OrigTy.getSizeInBits(); - const unsigned TargetSize = TargetTy.getSizeInBits(); - - if (OrigSize == TargetSize) + if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits()) return OrigTy; - if (OrigTy.isVector()) { + if (OrigTy.isVector() && TargetTy.isVector()) { LLT OrigElt = OrigTy.getElementType(); - if (TargetTy.isVector()) { - LLT TargetElt = TargetTy.getElementType(); - if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) { - int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements()); - return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt); - } - } else { - // If the source is a vector of pointers, return a pointer element. - if (OrigElt.getSizeInBits() == TargetSize) - return OrigElt; - } - unsigned GCD = std::gcd(OrigSize, TargetSize); + // TODO: The docstring for this function says the intention is to use this + // function to build MERGE/UNMERGE instructions. It won't be the case that + // we generate a MERGE/UNMERGE between fixed and scalable vector types. We + // could implement getGCDType between the two in the future if there was a + // need, but it is not worth it now as this function should not be used in + // that way. + assert(((OrigTy.isScalableVector() && !TargetTy.isFixedVector()) || + (OrigTy.isFixedVector() && !TargetTy.isScalableVector())) && + "getGCDType not implemented between fixed and scalable vectors."); + + unsigned GCD = std::gcd(OrigTy.getSizeInBits().getKnownMinValue(), + TargetTy.getSizeInBits().getKnownMinValue()); if (GCD == OrigElt.getSizeInBits()) - return OrigElt; + return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()), + OrigElt); - // If we can't produce the original element type, we have to use a smaller - // scalar. + // Cannot produce original element type, but both have vscale in common. if (GCD < OrigElt.getSizeInBits()) - return LLT::scalar(GCD); - return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt); - } + return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()), + GCD); - if (TargetTy.isVector()) { - // Try to preserve the original element type. - LLT TargetElt = TargetTy.getElementType(); - if (TargetElt.getSizeInBits() == OrigSize) - return OrigTy; + return LLT::vector( + ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(), + OrigTy.isScalable()), + OrigElt); } - unsigned GCD = std::gcd(OrigSize, TargetSize); + // If one type is vector and the element size matches the scalar size, then + // the gcd is the scalar type. + if (OrigTy.isVector() && + OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits()) + return OrigTy.getElementType(); + if (TargetTy.isVector() && + TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits()) + return OrigTy; + + // At this point, both types are either scalars of different type or one is a + // vector and one is a scalar. If both types are scalars, the GCD type is the + // GCD between the two scalar sizes. If one is vector and one is scalar, then + // the GCD type is the GCD between the scalar and the vector element size. + LLT OrigScalar = OrigTy.getScalarType(); + LLT TargetScalar = TargetTy.getScalarType(); + unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(), + TargetScalar.getSizeInBits().getFixedValue()); return LLT::scalar(GCD); } diff --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp index 92bd0a36b82b4..1ff7fd956d015 100644 --- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp @@ -183,6 +183,62 @@ TEST(GISelUtilsTest, getGCDType) { EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::fixed_vector(3, 4), S8)); EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::fixed_vector(3, 4))); + + // Scalable -> Scalable + EXPECT_EQ(NXV1S1, getGCDType(NXV1S1, NXV1S32)); + EXPECT_EQ(NXV1S32, getGCDType(NXV1S64, NXV1S32)); + EXPECT_EQ(NXV1S32, getGCDType(NXV1S32, NXV1S64)); + EXPECT_EQ(NXV1P0, getGCDType(NXV1P0, NXV1S64)); + EXPECT_EQ(NXV1S64, getGCDType(NXV1S64, NXV1P0)); + + EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV4S32)); + EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV4S32)); + EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV4S64)); + EXPECT_EQ(NXV4P0, getGCDType(NXV4P0, NXV4S64)); + EXPECT_EQ(NXV4S64, getGCDType(NXV4S64, NXV4P0)); + + EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV2S32)); + EXPECT_EQ(NXV1S64, getGCDType(NXV4S64, NXV2S32)); + EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV2S64)); + EXPECT_EQ(NXV2P0, getGCDType(NXV4P0, NXV2S64)); + EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV2P0)); + + EXPECT_EQ(NXV2S1, getGCDType(NXV2S1, NXV4S32)); + EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4S32)); + EXPECT_EQ(NXV2S32, getGCDType(NXV2S32, NXV4S64)); + EXPECT_EQ(NXV2P0, getGCDType(NXV2P0, NXV4S64)); + EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4P0)); + + EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S32)); + EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S32)); + EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S64)); + EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4S64)); + EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4P0)); + + EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S1)); + EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S32)); + EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S64)); + EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4P0)); + + // Scalable, Scalar + + EXPECT_EQ(S1, getGCDType(NXV1S1, S1)); + EXPECT_EQ(S1, getGCDType(NXV1S1, S32)); + EXPECT_EQ(S1, getGCDType(NXV1S32, S1)); + EXPECT_EQ(S32, getGCDType(NXV1S32, S32)); + EXPECT_EQ(S32, getGCDType(NXV1S32, S64)); + EXPECT_EQ(S1, getGCDType(NXV2S32, S1)); + EXPECT_EQ(S32, getGCDType(NXV2S32, S32)); + EXPECT_EQ(S32, getGCDType(NXV2S32, S64)); + + EXPECT_EQ(S1, getGCDType(S1, NXV1S1)); + EXPECT_EQ(S1, getGCDType(S32, NXV1S1)); + EXPECT_EQ(S1, getGCDType(S1, NXV1S32)); + EXPECT_EQ(S32, getGCDType(S32, NXV1S32)); + EXPECT_EQ(S32, getGCDType(S64, NXV1S32)); + EXPECT_EQ(S1, getGCDType(S1, NXV2S32)); + EXPECT_EQ(S32, getGCDType(S32, NXV2S32)); + EXPECT_EQ(S32, getGCDType(S64, NXV2S32)); } TEST(GISelUtilsTest, getLCMType) {