diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index bf02911e19351..c96e4217d21f0 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -343,10 +343,13 @@ Register getFunctionLiveInPhysReg(MachineFunction &MF, const TargetRegisterClass &RC, const DebugLoc &DL, LLT RegTy = LLT()); -/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the -/// number of vector elements or scalar bitwidth. The intent is a +/// Return the least common multiple type of \p OrigTy and \p TargetTy, by +/// changing the number of vector elements or scalar bitwidth. The intent is a /// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from -/// \p OrigTy elements, and unmerged into \p TargetTy +/// \p OrigTy elements, and unmerged into \p TargetTy. It is an error to call +/// this function when one argument is a fixed vector and the other is a +/// scalable vector, since it is illegal to build a G_{MERGE|UNMERGE}_VALUES +/// between fixed and scalable vectors. 
LLVM_READNONE LLT getLCMType(LLT OrigTy, LLT TargetTy); diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index aed826a9cbc54..2f6b861819b66 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1071,49 +1071,70 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) { } LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) { - const unsigned OrigSize = OrigTy.getSizeInBits(); - const unsigned TargetSize = TargetTy.getSizeInBits(); - - if (OrigSize == TargetSize) + if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits()) return OrigTy; - if (OrigTy.isVector()) { - const LLT OrigElt = OrigTy.getElementType(); - - if (TargetTy.isVector()) { - const LLT TargetElt = TargetTy.getElementType(); + if (OrigTy.isVector() && TargetTy.isVector()) { + LLT OrigElt = OrigTy.getElementType(); + LLT TargetElt = TargetTy.getElementType(); - if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) { - int GCDElts = - std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements()); - // Prefer the original element type. - ElementCount Mul = OrigTy.getElementCount() * TargetTy.getNumElements(); - return LLT::vector(Mul.divideCoefficientBy(GCDElts), - OrigTy.getElementType()); - } - } else { - if (OrigElt.getSizeInBits() == TargetSize) - return OrigTy; + // TODO: Per the docstring, this function exists to build MERGE/UNMERGE + // instructions, and a MERGE/UNMERGE between fixed and scalable vector + // types is never generated. getLCMType could be implemented for that + // mix in the future if the need arises, but it is not worth it now since + // this function should not be used that way. Until then, mixing the two + // is rejected by the assert below. 
+ assert(((OrigTy.isScalableVector() && !TargetTy.isFixedVector()) || + (OrigTy.isFixedVector() && !TargetTy.isScalableVector())) && + "getLCMType not implemented between fixed and scalable vectors."); + + if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) { + int GCDMinElts = std::gcd(OrigTy.getElementCount().getKnownMinValue(), + TargetTy.getElementCount().getKnownMinValue()); + // Prefer the original element type. + ElementCount Mul = OrigTy.getElementCount().multiplyCoefficientBy( + TargetTy.getElementCount().getKnownMinValue()); + return LLT::vector(Mul.divideCoefficientBy(GCDMinElts), + OrigTy.getElementType()); } - - unsigned LCMSize = std::lcm(OrigSize, TargetSize); - return LLT::fixed_vector(LCMSize / OrigElt.getSizeInBits(), OrigElt); + unsigned LCM = std::lcm(OrigTy.getSizeInBits().getKnownMinValue(), + TargetTy.getSizeInBits().getKnownMinValue()); + return LLT::vector( + ElementCount::get(LCM / OrigElt.getSizeInBits(), OrigTy.isScalable()), + OrigElt); } - if (TargetTy.isVector()) { - unsigned LCMSize = std::lcm(OrigSize, TargetSize); - return LLT::fixed_vector(LCMSize / OrigSize, OrigTy); + // One type is a scalar and the other is a vector. + if (OrigTy.isVector() || TargetTy.isVector()) { + LLT VecTy = OrigTy.isVector() ? OrigTy : TargetTy; + LLT ScalarTy = OrigTy.isVector() ? TargetTy : OrigTy; + LLT EltTy = VecTy.getElementType(); + LLT OrigEltTy = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy; + + // Same scalar size: keep VecTy's element count and OrigTy's scalar type. + if (EltTy.getSizeInBits() == ScalarTy.getSizeInBits()) + return LLT::vector(VecTy.getElementCount(), OrigEltTy); + + // Different size scalars. Create a vector whose minimum total size is the + // LCM of both sizes. The result takes fixed/scalable from VecTy. 
+ unsigned LCM = std::lcm(EltTy.getSizeInBits().getFixedValue() * + VecTy.getElementCount().getKnownMinValue(), + ScalarTy.getSizeInBits().getFixedValue()); + // Prefer the element type from OrigTy. + return LLT::vector(ElementCount::get(LCM / OrigEltTy.getSizeInBits(), + VecTy.getElementCount().isScalable()), + OrigEltTy); } - unsigned LCMSize = std::lcm(OrigSize, TargetSize); - + // At this point, both types are scalars of different sizes. + unsigned LCM = std::lcm(OrigTy.getSizeInBits().getFixedValue(), + TargetTy.getSizeInBits().getFixedValue()); // Preserve pointer types. - if (LCMSize == OrigSize) + if (LCM == OrigTy.getSizeInBits()) return OrigTy; - if (LCMSize == TargetSize) + if (LCM == TargetTy.getSizeInBits()) return TargetTy; - - return LLT::scalar(LCMSize); + return LLT::scalar(LCM); } LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) { diff --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp index 8fda332d5c054..92bd0a36b82b4 100644 --- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp @@ -46,6 +46,37 @@ static const LLT V6P0 = LLT::fixed_vector(6, P0); static const LLT V2P1 = LLT::fixed_vector(2, P1); static const LLT V4P1 = LLT::fixed_vector(4, P1); +static const LLT NXV1S1 = LLT::scalable_vector(1, S1); +static const LLT NXV2S1 = LLT::scalable_vector(2, S1); +static const LLT NXV3S1 = LLT::scalable_vector(3, S1); +static const LLT NXV4S1 = LLT::scalable_vector(4, S1); +static const LLT NXV12S1 = LLT::scalable_vector(12, S1); +static const LLT NXV32S1 = LLT::scalable_vector(32, S1); +static const LLT NXV64S1 = LLT::scalable_vector(64, S1); +static const LLT NXV128S1 = LLT::scalable_vector(128, S1); +static const LLT NXV384S1 = LLT::scalable_vector(384, S1); + +static const LLT NXV1S32 = LLT::scalable_vector(1, S32); +static const LLT NXV2S32 = LLT::scalable_vector(2, S32); +static const LLT NXV3S32 = LLT::scalable_vector(3, S32); +static 
const LLT NXV4S32 = LLT::scalable_vector(4, S32); +static const LLT NXV8S32 = LLT::scalable_vector(8, S32); +static const LLT NXV12S32 = LLT::scalable_vector(12, S32); +static const LLT NXV24S32 = LLT::scalable_vector(24, S32); + +static const LLT NXV1S64 = LLT::scalable_vector(1, S64); +static const LLT NXV2S64 = LLT::scalable_vector(2, S64); +static const LLT NXV3S64 = LLT::scalable_vector(3, S64); +static const LLT NXV4S64 = LLT::scalable_vector(4, S64); +static const LLT NXV6S64 = LLT::scalable_vector(6, S64); +static const LLT NXV12S64 = LLT::scalable_vector(12, S64); + +static const LLT NXV1P0 = LLT::scalable_vector(1, P0); +static const LLT NXV2P0 = LLT::scalable_vector(2, P0); +static const LLT NXV3P0 = LLT::scalable_vector(3, P0); +static const LLT NXV4P0 = LLT::scalable_vector(4, P0); +static const LLT NXV12P0 = LLT::scalable_vector(12, P0); + TEST(GISelUtilsTest, getGCDType) { EXPECT_EQ(S1, getGCDType(S1, S1)); EXPECT_EQ(S32, getGCDType(S32, S32)); @@ -244,6 +275,62 @@ TEST(GISelUtilsTest, getLCMType) { EXPECT_EQ(V2S64, getLCMType(V2S64, P1)); EXPECT_EQ(V4P1, getLCMType(P1, V2S64)); + + // Scalable, Scalable + EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, NXV1S32)); + EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1S32)); + EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, NXV1S64)); + EXPECT_EQ(NXV1P0, getLCMType(NXV1P0, NXV1S64)); + EXPECT_EQ(NXV1S64, getLCMType(NXV1S64, NXV1P0)); + + EXPECT_EQ(NXV128S1, getLCMType(NXV4S1, NXV4S32)); + EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4S32)); + EXPECT_EQ(NXV8S32, getLCMType(NXV4S32, NXV4S64)); + EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV4S64)); + EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV4P0)); + + EXPECT_EQ(NXV64S1, getLCMType(NXV4S1, NXV2S32)); + EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2S32)); + EXPECT_EQ(NXV4S32, getLCMType(NXV4S32, NXV2S64)); + EXPECT_EQ(NXV4P0, getLCMType(NXV4P0, NXV2S64)); + EXPECT_EQ(NXV4S64, getLCMType(NXV4S64, NXV2P0)); + + EXPECT_EQ(NXV128S1, getLCMType(NXV2S1, NXV4S32)); + EXPECT_EQ(NXV2S64, 
getLCMType(NXV2S64, NXV4S32)); + EXPECT_EQ(NXV8S32, getLCMType(NXV2S32, NXV4S64)); + EXPECT_EQ(NXV4P0, getLCMType(NXV2P0, NXV4S64)); + EXPECT_EQ(NXV4S64, getLCMType(NXV2S64, NXV4P0)); + + EXPECT_EQ(NXV384S1, getLCMType(NXV3S1, NXV4S32)); + EXPECT_EQ(NXV6S64, getLCMType(NXV3S64, NXV4S32)); + EXPECT_EQ(NXV24S32, getLCMType(NXV3S32, NXV4S64)); + EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4S64)); + EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4P0)); + + EXPECT_EQ(NXV12S1, getLCMType(NXV3S1, NXV4S1)); + EXPECT_EQ(NXV12S32, getLCMType(NXV3S32, NXV4S32)); + EXPECT_EQ(NXV12S64, getLCMType(NXV3S64, NXV4S64)); + EXPECT_EQ(NXV12P0, getLCMType(NXV3P0, NXV4P0)); + + // Scalable, Scalar (in both argument orders) + + EXPECT_EQ(NXV1S1, getLCMType(NXV1S1, S1)); + EXPECT_EQ(NXV32S1, getLCMType(NXV1S1, S32)); + EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S1)); + EXPECT_EQ(NXV1S32, getLCMType(NXV1S32, S32)); + EXPECT_EQ(NXV2S32, getLCMType(NXV1S32, S64)); + EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S1)); + EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S32)); + EXPECT_EQ(NXV2S32, getLCMType(NXV2S32, S64)); + + EXPECT_EQ(NXV1S1, getLCMType(S1, NXV1S1)); + EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S1)); + EXPECT_EQ(NXV32S1, getLCMType(S1, NXV1S32)); + EXPECT_EQ(NXV1S32, getLCMType(S32, NXV1S32)); + EXPECT_EQ(NXV1S64, getLCMType(S64, NXV1S32)); + EXPECT_EQ(NXV64S1, getLCMType(S1, NXV2S32)); + EXPECT_EQ(NXV2S32, getLCMType(S32, NXV2S32)); + EXPECT_EQ(NXV1S64, getLCMType(S64, NXV2S32)); } TEST_F(AArch64GISelMITest, ConstFalseTest) {