diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index de74524c4b6fe..53b84edafbb66 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -65,6 +65,128 @@ enum SCEVTypes : unsigned short; extern bool VerifySCEV; +class SCEV; + +struct SCEVUse : public PointerIntPair { + SCEVUse() : PointerIntPair(nullptr, 0) {} + SCEVUse(const SCEV *S) : PointerIntPair(S, 0) {} + SCEVUse(const SCEV *S, int Flags) : PointerIntPair(S, Flags) {} + + operator const SCEV *() const { return getPointer(); } + const SCEV *operator->() const { return getPointer(); } + const SCEV *operator->() { return getPointer(); } + + void *getRawPointer() const { return getOpaqueValue(); } + + bool isCanonical() const; + + const SCEV *getCanonical() const; + + unsigned getFlags() const { return getInt(); } + + bool operator==(const SCEVUse &RHS) const { + assert(isCanonical() && RHS.isCanonical()); + return getRawPointer() == RHS.getRawPointer() || + getCanonical() == RHS.getCanonical(); + } + + bool operator==(const SCEV *RHS) const { return getRawPointer() == RHS; } + + /// Print out the internal representation of this scalar to the specified + /// stream. This should really only be used for debugging purposes. + void print(raw_ostream &OS) const; + + /// This method is used for debugging. + void dump() const; + + static const SCEV *computeCanonical(ScalarEvolution &SE, const SCEV *); +}; + +/// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with +/// SmallPtrSet, among others. +template <> struct PointerLikeTypeTraits { + static inline void *getAsVoidPointer(SCEVUse U) { return U.getOpaqueValue(); } + static inline SCEVUse getFromVoidPointer(void *P) { + SCEVUse U; + U.setFromOpaqueValue(P); + return U; + } + + /// The Low bits are used by the PointerIntPair. 
+ static constexpr int NumLowBitsAvailable = 0; +}; + +template <> struct DenseMapInfo { + // The following should hold, but it would require T to be complete: + // static_assert(alignof(T) <= (1 << Log2MaxAlign), + // "DenseMap does not support pointer keys requiring more than " + // "Log2MaxAlign bits of alignment"); + static constexpr uintptr_t Log2MaxAlign = 12; + + static inline SCEVUse getEmptyKey() { + uintptr_t Val = static_cast(-1); + Val <<= Log2MaxAlign; + return PointerLikeTypeTraits::getFromVoidPointer((void *)Val); + } + + static inline SCEVUse getTombstoneKey() { + uintptr_t Val = static_cast(-2); + Val <<= Log2MaxAlign; + return PointerLikeTypeTraits::getFromVoidPointer((void *)Val); + } + + static unsigned getHashValue(SCEVUse U) { + void *PtrVal = PointerLikeTypeTraits::getAsVoidPointer(U); + return (unsigned((uintptr_t)PtrVal) >> 4) ^ + (unsigned((uintptr_t)PtrVal) >> 9); + } + + static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) { + return LHS.getRawPointer() == RHS.getRawPointer(); + } +}; + +inline bool SCEVUse::isCanonical() const { + if (getInt() != 0) + return false; + if (!getRawPointer() || + DenseMapInfo::getEmptyKey().getRawPointer() == getRawPointer() || + DenseMapInfo::getTombstoneKey().getRawPointer() == + getRawPointer()) + return true; + return getCanonical() == getPointer(); +} + +template [[nodiscard]] inline decltype(auto) dyn_cast(SCEVUse U) { + assert(detail::isPresent(U.getPointer()) && + "dyn_cast on a non-existent value"); + return CastInfo::doCastIfPossible(U.getPointer()); +} + +template +[[nodiscard]] inline decltype(auto) dyn_cast_if_present(SCEVUse U) { + assert(detail::isPresent(U.getPointer()) && + "dyn_cast on a non-existent value"); + return CastInfo::doCastIfPossible(U.getPointer()); +} + +template [[nodiscard]] inline decltype(auto) cast(SCEVUse U) { + assert(detail::isPresent(U.getPointer()) && + "dyn_cast on a non-existent value"); + return CastInfo::doCast(U.getPointer()); +} + +template [[nodiscard]] inline bool isa(SCEVUse U) { + return CastInfo::isPossible(U.getPointer()); +} + +template auto dyn_cast_or_null(SCEVUse U) { + const SCEV *Val = U.getPointer(); + if (!detail::isPresent(Val)) + return CastInfo::castFailed(); + return CastInfo::doCastIfPossible(detail::unwrapValue(Val)); +} + /// This class represents an analyzed expression in the program. These are /// opaque objects that the client is not allowed to do much with directly. /// @@ -86,6 +208,8 @@ class SCEV : public FoldingSetNode { /// miscellaneous information. unsigned short SubclassData = 0; + const SCEV *CanonicalSCEV; + public: /// NoWrapFlags are bitfield indices into SubclassData. /// @@ -143,7 +267,7 @@ class SCEV : public FoldingSetNode { Type *getType() const; /// Return operands of this SCEV expression. - ArrayRef operands() const; + ArrayRef operands() const; /// Return true if the expression is a constant zero. bool isZero() const; @@ -176,6 +300,10 @@ class SCEV : public FoldingSetNode { /// This method is used for debugging. 
void dump() const; + + void setCanonical(const SCEV *S) { CanonicalSCEV = S; } + + const SCEV *getCanonical() const { return CanonicalSCEV; } }; // Specialize FoldingSetTrait for SCEV to avoid needing to compute @@ -198,6 +326,15 @@ inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) { return OS; } +inline raw_ostream &operator<<(raw_ostream &OS, const SCEVUse &S) { + S.print(OS); + return OS; +} + +inline const SCEV *SCEVUse::getCanonical() const { + return getPointer()->getCanonical(); +} + /// An object of this class is returned by queries that could not be answered. /// For example, if you ask for the number of iterations of a linked-list /// traversal loop, you will get one of these. None of the standard SCEV @@ -207,6 +344,7 @@ struct SCEVCouldNotCompute : public SCEV { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S); + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents an assumption made using SCEV expressions which can @@ -277,13 +415,13 @@ struct FoldingSetTrait : DefaultFoldingSetTrait { class SCEVComparePredicate final : public SCEVPredicate { /// We assume that LHS Pred RHS is true. const ICmpInst::Predicate Pred; - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; public: SCEVComparePredicate(const FoldingSetNodeIDRef ID, - const ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + const ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Implementation of the SCEVPredicate interface bool implies(const SCEVPredicate *N) const override; @@ -293,10 +431,10 @@ class SCEVComparePredicate final : public SCEVPredicate { ICmpInst::Predicate getPredicate() const { return Pred; } /// Returns the left hand side of the predicate. - const SCEV *getLHS() const { return LHS; } + SCEVUse getLHS() const { return LHS; } /// Returns the right hand side of the predicate. - const SCEV *getRHS() const { return RHS; } + SCEVUse getRHS() const { return RHS; } /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEVPredicate *P) { @@ -411,8 +549,7 @@ class SCEVWrapPredicate final : public SCEVPredicate { /// ScalarEvolution::Preds folding set. This is why the \c add function is sound. class SCEVUnionPredicate final : public SCEVPredicate { private: - using PredicateMap = - DenseMap>; + using PredicateMap = DenseMap>; /// Vector with references to all predicates in this union. SmallVector Preds; @@ -519,18 +656,17 @@ class ScalarEvolution { /// loop { v2 = load @global2; } /// } /// No SCEV with operand V1, and v2 can exist in this program. - bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B); + bool instructionCouldExistWithOperands(SCEVUse A, SCEVUse B); /// Return true if the SCEV is a scAddRecExpr or it contains /// scAddRecExpr. The result will be cached in HasRecMap. - bool containsAddRecurrence(const SCEV *S); + bool containsAddRecurrence(SCEVUse S); /// Is operation \p BinOp between \p LHS and \p RHS provably does not have /// a signed/unsigned overflow (\p Signed)? If \p CtxI is specified, the /// no-overflow fact should be true in the context of this instruction. 
- bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, - const SCEV *LHS, const SCEV *RHS, - const Instruction *CtxI = nullptr); + bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, SCEVUse LHS, + SCEVUse RHS, const Instruction *CtxI = nullptr); /// Parse NSW/NUW flags from add/sub/mul IR binary operation \p Op into /// SCEV no-wrap flags, and deduce flag[s] that aren't known yet. @@ -541,78 +677,84 @@ class ScalarEvolution { getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO); /// Notify this ScalarEvolution that \p User directly uses SCEVs in \p Ops. - void registerUser(const SCEV *User, ArrayRef Ops); + void registerUser(SCEVUse User, ArrayRef Ops); /// Return true if the SCEV expression contains an undef value. - bool containsUndefs(const SCEV *S) const; + bool containsUndefs(SCEVUse S) const; /// Return true if the SCEV expression contains a Value that has been /// optimised out and is now a nullptr. - bool containsErasedValue(const SCEV *S) const; + bool containsErasedValue(SCEVUse S) const; /// Return a SCEV expression for the full generality of the specified /// expression. - const SCEV *getSCEV(Value *V); + SCEVUse getSCEV(Value *V); /// Return an existing SCEV for V if there is one, otherwise return nullptr. - const SCEV *getExistingSCEV(Value *V); - - const SCEV *getConstant(ConstantInt *V); - const SCEV *getConstant(const APInt &Val); - const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false); - const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0); - const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty); - const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); - const SCEV *getVScale(Type *Ty); - const SCEV *getElementCount(Type *Ty, ElementCount EC); - const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); - const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth = 0); - const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); - const SCEV *getSignExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth = 0); - const SCEV *getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty); - const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty); - const SCEV *getAddExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0); - const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0) { - SmallVector Ops = {LHS, RHS}; + SCEVUse getExistingSCEV(Value *V); + + SCEVUse getConstant(ConstantInt *V); + SCEVUse getConstant(const APInt &Val); + SCEVUse getConstant(Type *Ty, uint64_t V, bool isSigned = false); + SCEVUse getLosslessPtrToIntExpr(SCEVUse Op, unsigned Depth = 0); + SCEVUse getPtrToIntExpr(SCEVUse Op, Type *Ty); + SCEVUse getTruncateExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0); + SCEVUse getVScale(Type *Ty); + SCEVUse getElementCount(Type *Ty, ElementCount EC); + SCEVUse getZeroExtendExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0); + SCEVUse getZeroExtendExprImpl(SCEVUse Op, Type *Ty, unsigned Depth = 0); + SCEVUse getSignExtendExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0); + SCEVUse getSignExtendExprImpl(SCEVUse Op, Type *Ty, unsigned Depth = 0); + SCEVUse getCastExpr(SCEVTypes Kind, SCEVUse Op, Type *Ty); + SCEVUse getAnyExtendExpr(SCEVUse Op, Type *Ty); + SCEVUse getAddExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); + SCEVUse getAddExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags 
Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); + SCEVUse getAddExpr(SCEVUse LHS, SCEVUse RHS, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0) { + SmallVector Ops = {LHS, RHS}; return getAddExpr(Ops, Flags, Depth); } - const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0) { - SmallVector Ops = {Op0, Op1, Op2}; + SCEVUse getAddExpr(SCEVUse Op0, SCEVUse Op1, SCEVUse Op2, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0) { + SmallVector Ops = {Op0, Op1, Op2}; return getAddExpr(Ops, Flags, Depth); } - const SCEV *getMulExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0); - const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0) { - SmallVector Ops = {LHS, RHS}; + SCEVUse getMulExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); + SCEVUse getMulExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); + SCEVUse getMulExpr(SCEVUse LHS, SCEVUse RHS, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0) { + SmallVector Ops = {LHS, RHS}; return getMulExpr(Ops, Flags, Depth); } - const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0) { - SmallVector Ops = {Op0, Op1, Op2}; + SCEVUse getMulExpr(SCEVUse Op0, SCEVUse Op1, SCEVUse Op2, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0) { + SmallVector Ops = {Op0, Op1, Op2}; return getMulExpr(Ops, Flags, Depth); } - const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getURemExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, - SCEV::NoWrapFlags Flags); - const SCEV *getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L, SCEV::NoWrapFlags Flags); - const SCEV *getAddRecExpr(const SmallVectorImpl &Operands, - const Loop *L, SCEV::NoWrapFlags Flags) { - SmallVector NewOp(Operands.begin(), Operands.end()); + SCEVUse getUDivExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getUDivExactExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getURemExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, + SCEV::NoWrapFlags Flags); + SCEVUse getAddRecExpr(ArrayRef Operands, const Loop *L, + SCEV::NoWrapFlags Flags); + SCEVUse getAddRecExpr(SmallVectorImpl &Operands, const Loop *L, + SCEV::NoWrapFlags Flags); + SCEVUse getAddRecExpr(const SmallVectorImpl &Operands, const Loop *L, + SCEV::NoWrapFlags Flags) { + SmallVector NewOp(Operands.begin(), Operands.end()); return getAddRecExpr(NewOp, L, Flags); } @@ -620,7 +762,7 @@ class ScalarEvolution { /// Predicates. If successful return these ; /// The function is intended to be called from PSCEV (the caller will decide /// whether to actually add the predicates and carry out the rewrites). - std::optional>> + std::optional>> createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI); /// Returns an expression for a GEP @@ -628,61 +770,61 @@ class ScalarEvolution { /// \p GEP The GEP. The indices contained in the GEP itself are ignored, /// instead we use IndexExprs. /// \p IndexExprs The expressions for the indices. 
- const SCEV *getGEPExpr(GEPOperator *GEP, - const SmallVectorImpl &IndexExprs); - const SCEV *getAbsExpr(const SCEV *Op, bool IsNSW); - const SCEV *getMinMaxExpr(SCEVTypes Kind, - SmallVectorImpl &Operands); - const SCEV *getSequentialMinMaxExpr(SCEVTypes Kind, - SmallVectorImpl &Operands); - const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getSMaxExpr(SmallVectorImpl &Operands); - const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getUMaxExpr(SmallVectorImpl &Operands); - const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS); - const SCEV *getSMinExpr(SmallVectorImpl &Operands); - const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS, - bool Sequential = false); - const SCEV *getUMinExpr(SmallVectorImpl &Operands, - bool Sequential = false); - const SCEV *getUnknown(Value *V); - const SCEV *getCouldNotCompute(); + SCEVUse getGEPExpr(GEPOperator *GEP, ArrayRef IndexExprs); + SCEVUse getGEPExpr(GEPOperator *GEP, + const SmallVectorImpl &IndexExprs); + SCEVUse getAbsExpr(SCEVUse Op, bool IsNSW); + SCEVUse getMinMaxExpr(SCEVTypes Kind, ArrayRef Operands); + SCEVUse getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl &Operands); + SCEVUse getSequentialMinMaxExpr(SCEVTypes Kind, + SmallVectorImpl &Operands); + SCEVUse getSMaxExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getSMaxExpr(SmallVectorImpl &Operands); + SCEVUse getUMaxExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getUMaxExpr(SmallVectorImpl &Operands); + SCEVUse getSMinExpr(SCEVUse LHS, SCEVUse RHS); + SCEVUse getSMinExpr(SmallVectorImpl &Operands); + SCEVUse getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential = false); + SCEVUse getUMinExpr(SmallVectorImpl &Operands, + bool Sequential = false); + SCEVUse getUnknown(Value *V); + SCEVUse getCouldNotCompute(); /// Return a SCEV for the constant 0 of a specific type. - const SCEV *getZero(Type *Ty) { return getConstant(Ty, 0); } + SCEVUse getZero(Type *Ty) { return getConstant(Ty, 0); } /// Return a SCEV for the constant 1 of a specific type. - const SCEV *getOne(Type *Ty) { return getConstant(Ty, 1); } + SCEVUse getOne(Type *Ty) { return getConstant(Ty, 1); } /// Return a SCEV for the constant \p Power of two. - const SCEV *getPowerOfTwo(Type *Ty, unsigned Power) { + SCEVUse getPowerOfTwo(Type *Ty, unsigned Power) { assert(Power < getTypeSizeInBits(Ty) && "Power out of range"); return getConstant(APInt::getOneBitSet(getTypeSizeInBits(Ty), Power)); } /// Return a SCEV for the constant -1 of a specific type. - const SCEV *getMinusOne(Type *Ty) { + SCEVUse getMinusOne(Type *Ty) { return getConstant(Ty, -1, /*isSigned=*/true); } /// Return an expression for a TypeSize. - const SCEV *getSizeOfExpr(Type *IntTy, TypeSize Size); + SCEVUse getSizeOfExpr(Type *IntTy, TypeSize Size); /// Return an expression for the alloc size of AllocTy that is type IntTy - const SCEV *getSizeOfExpr(Type *IntTy, Type *AllocTy); + SCEVUse getSizeOfExpr(Type *IntTy, Type *AllocTy); /// Return an expression for the store size of StoreTy that is type IntTy - const SCEV *getStoreSizeOfExpr(Type *IntTy, Type *StoreTy); + SCEVUse getStoreSizeOfExpr(Type *IntTy, Type *StoreTy); /// Return an expression for offsetof on the given field with type IntTy - const SCEV *getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo); + SCEVUse getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo); /// Return the SCEV object corresponding to -V. 
- const SCEV *getNegativeSCEV(const SCEV *V, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); + SCEVUse getNegativeSCEV(SCEVUse V, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); /// Return the SCEV object corresponding to ~V. - const SCEV *getNotSCEV(const SCEV *V); + SCEVUse getNotSCEV(SCEVUse V); /// Return LHS-RHS. Minus is represented in SCEV as A+B*-1. /// @@ -691,9 +833,9 @@ class ScalarEvolution { /// To compute the difference between two unrelated pointers, you can /// explicitly convert the arguments using getPtrToIntExpr(), for pointer /// types that support it. - const SCEV *getMinusSCEV(const SCEV *LHS, const SCEV *RHS, - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, - unsigned Depth = 0); + SCEVUse getMinusSCEV(SCEVUse LHS, SCEVUse RHS, + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, + unsigned Depth = 0); /// Compute ceil(N / D). N and D are treated as unsigned values. /// @@ -703,59 +845,59 @@ class ScalarEvolution { /// umin(N, 1) + floor((N - umin(N, 1)) / D) /// /// A denominator of zero or poison is handled the same way as getUDivExpr(). - const SCEV *getUDivCeilSCEV(const SCEV *N, const SCEV *D); + SCEVUse getUDivCeilSCEV(SCEVUse N, SCEVUse D); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is zero extended. - const SCEV *getTruncateOrZeroExtend(const SCEV *V, Type *Ty, - unsigned Depth = 0); + SCEVUse getTruncateOrZeroExtend(SCEVUse V, Type *Ty, unsigned Depth = 0); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is sign extended. - const SCEV *getTruncateOrSignExtend(const SCEV *V, Type *Ty, - unsigned Depth = 0); + SCEVUse getTruncateOrSignExtend(SCEVUse V, Type *Ty, unsigned Depth = 0); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is zero extended. The /// conversion must not be narrowing. - const SCEV *getNoopOrZeroExtend(const SCEV *V, Type *Ty); + SCEVUse getNoopOrZeroExtend(SCEVUse V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is sign extended. The /// conversion must not be narrowing. - const SCEV *getNoopOrSignExtend(const SCEV *V, Type *Ty); + SCEVUse getNoopOrSignExtend(SCEVUse V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is extended with /// unspecified bits. The conversion must not be narrowing. - const SCEV *getNoopOrAnyExtend(const SCEV *V, Type *Ty); + SCEVUse getNoopOrAnyExtend(SCEVUse V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. The conversion must not be widening. - const SCEV *getTruncateOrNoop(const SCEV *V, Type *Ty); + SCEVUse getTruncateOrNoop(SCEVUse V, Type *Ty); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umax operation with them. - const SCEV *getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS); + SCEVUse getUMaxFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. 
- const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, - bool Sequential = false); + SCEVUse getUMinFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS, + bool Sequential = false); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. N-ary function. - const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl &Ops, - bool Sequential = false); + SCEVUse getUMinFromMismatchedTypes(ArrayRef Ops, + bool Sequential = false); + SCEVUse getUMinFromMismatchedTypes(SmallVectorImpl &Ops, + bool Sequential = false); /// Transitively follow the chain of pointer-type operands until reaching a /// SCEV that does not have a single pointer operand. This returns a /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner /// cases do exist. - const SCEV *getPointerBase(const SCEV *V); + SCEVUse getPointerBase(SCEVUse V); /// Compute an expression equivalent to S - getPointerBase(S). - const SCEV *removePointerBase(const SCEV *S); + SCEVUse removePointerBase(SCEVUse S); /// Return a SCEV expression for the specified value at the specified scope /// in the program. The L value specifies a loop nest to evaluate the @@ -767,31 +909,31 @@ class ScalarEvolution { /// /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. - const SCEV *getSCEVAtScope(const SCEV *S, const Loop *L); + SCEVUse getSCEVAtScope(SCEVUse S, const Loop *L); /// This is a convenience function which does getSCEVAtScope(getSCEV(V), L). - const SCEV *getSCEVAtScope(Value *V, const Loop *L); + SCEVUse getSCEVAtScope(Value *V, const Loop *L); /// Test whether entry to the loop is protected by a conditional between LHS /// and RHS. This is used to help avoid max expressions in loop trip /// counts, and to eliminate casts. bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + SCEVUse LHS, SCEVUse RHS); /// Test whether entry to the basic block is protected by a conditional /// between LHS and RHS. bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, - ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Test whether the backedge of the loop is protected by a conditional /// between LHS and RHS. This is used to eliminate casts. bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + SCEVUse LHS, SCEVUse RHS); /// A version of getTripCountFromExitCount below which always picks an /// evaluation type which can not result in overflow. - const SCEV *getTripCountFromExitCount(const SCEV *ExitCount); + SCEVUse getTripCountFromExitCount(SCEVUse ExitCount); /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip /// count". A "trip count" is the number of times the header of the loop @@ -800,8 +942,8 @@ class ScalarEvolution { /// expression can overflow if ExitCount = UINT_MAX. If EvalTy is not wide /// enough to hold the result without overflow, result unsigned wraps with /// 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8) - const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, Type *EvalTy, - const Loop *L); + SCEVUse getTripCountFromExitCount(SCEVUse ExitCount, Type *EvalTy, + const Loop *L); /// Returns the exact trip count of the loop if we can compute it, and /// the result is a small constant. 
'0' is used to represent an unknown @@ -835,8 +977,7 @@ class ScalarEvolution { /// unknown or not guaranteed to be the multiple of a constant., Will also /// return 1 if the trip count is very large (>= 2^32). /// Note that the argument is an exit count for loop L, NOT a trip count. - unsigned getSmallConstantTripMultiple(const Loop *L, - const SCEV *ExitCount); + unsigned getSmallConstantTripMultiple(const Loop *L, SCEVUse ExitCount); /// Returns the largest constant divisor of the trip count of the /// loop. Will return 1 if no trip count could be computed, or if a @@ -871,12 +1012,12 @@ class ScalarEvolution { /// getBackedgeTakenCount. The loop is guaranteed to exit (via *some* exit) /// before the backedge is executed (ExitCount + 1) times. Note that there /// is no guarantee about *which* exit is taken on the exiting iteration. - const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock, - ExitCountKind Kind = Exact); + SCEVUse getExitCount(const Loop *L, const BasicBlock *ExitingBlock, + ExitCountKind Kind = Exact); /// Same as above except this uses the predicated backedge taken info and /// may require predicates. - const SCEV * + SCEVUse getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl *Predicates, ExitCountKind Kind = Exact); @@ -891,20 +1032,20 @@ class ScalarEvolution { /// Note that it is not valid to call this method on a loop without a /// loop-invariant backedge-taken count (see /// hasLoopInvariantBackedgeTakenCount). - const SCEV *getBackedgeTakenCount(const Loop *L, ExitCountKind Kind = Exact); + SCEVUse getBackedgeTakenCount(const Loop *L, ExitCountKind Kind = Exact); /// Similar to getBackedgeTakenCount, except it will add a set of /// SCEV predicates to Predicates that are required to be true in order for /// the answer to be correct. Predicates can be checked with run-time /// checks and can be used to perform loop versioning. - const SCEV *getPredicatedBackedgeTakenCount( + SCEVUse getPredicatedBackedgeTakenCount( const Loop *L, SmallVectorImpl &Predicates); /// When successful, this returns a SCEVConstant that is greater than or equal /// to (i.e. a "conservative over-approximation") of the value returend by /// getBackedgeTakenCount. If such a value cannot be computed, it returns the /// SCEVCouldNotCompute object. - const SCEV *getConstantMaxBackedgeTakenCount(const Loop *L) { + SCEVUse getConstantMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenCount(L, ConstantMaximum); } @@ -912,14 +1053,14 @@ class ScalarEvolution { /// SCEV predicates to Predicates that are required to be true in order for /// the answer to be correct. Predicates can be checked with run-time /// checks and can be used to perform loop versioning. - const SCEV *getPredicatedConstantMaxBackedgeTakenCount( + SCEVUse getPredicatedConstantMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Predicates); /// When successful, this returns a SCEV that is greater than or equal /// to (i.e. a "conservative over-approximation") of the value returend by /// getBackedgeTakenCount. If such a value cannot be computed, it returns the /// SCEVCouldNotCompute object. - const SCEV *getSymbolicMaxBackedgeTakenCount(const Loop *L) { + SCEVUse getSymbolicMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenCount(L, SymbolicMaximum); } @@ -927,7 +1068,7 @@ class ScalarEvolution { /// SCEV predicates to Predicates that are required to be true in order for /// the answer to be correct. 
Predicates can be checked with run-time /// checks and can be used to perform loop versioning. - const SCEV *getPredicatedSymbolicMaxBackedgeTakenCount( + SCEVUse getPredicatedSymbolicMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Predicates); /// Return true if the backedge taken count is either the value returned by @@ -984,60 +1125,60 @@ class ScalarEvolution { /// (at every loop iteration). It is, at the same time, the minimum number /// of times S is divisible by 2. For example, given {4,+,8} it returns 2. /// If S is guaranteed to be 0, it returns the bitwidth of S. - uint32_t getMinTrailingZeros(const SCEV *S); + uint32_t getMinTrailingZeros(SCEVUse S); /// Returns the max constant multiple of S. - APInt getConstantMultiple(const SCEV *S); + APInt getConstantMultiple(SCEVUse S); // Returns the max constant multiple of S. If S is exactly 0, return 1. - APInt getNonZeroConstantMultiple(const SCEV *S); + APInt getNonZeroConstantMultiple(SCEVUse S); /// Determine the unsigned range for a particular SCEV. /// NOTE: This returns a copy of the reference returned by getRangeRef. - ConstantRange getUnsignedRange(const SCEV *S) { + ConstantRange getUnsignedRange(SCEVUse S) { return getRangeRef(S, HINT_RANGE_UNSIGNED); } /// Determine the min of the unsigned range for a particular SCEV. - APInt getUnsignedRangeMin(const SCEV *S) { + APInt getUnsignedRangeMin(SCEVUse S) { return getRangeRef(S, HINT_RANGE_UNSIGNED).getUnsignedMin(); } /// Determine the max of the unsigned range for a particular SCEV. - APInt getUnsignedRangeMax(const SCEV *S) { + APInt getUnsignedRangeMax(SCEVUse S) { return getRangeRef(S, HINT_RANGE_UNSIGNED).getUnsignedMax(); } /// Determine the signed range for a particular SCEV. /// NOTE: This returns a copy of the reference returned by getRangeRef. - ConstantRange getSignedRange(const SCEV *S) { + ConstantRange getSignedRange(SCEVUse S) { return getRangeRef(S, HINT_RANGE_SIGNED); } /// Determine the min of the signed range for a particular SCEV. - APInt getSignedRangeMin(const SCEV *S) { + APInt getSignedRangeMin(SCEVUse S) { return getRangeRef(S, HINT_RANGE_SIGNED).getSignedMin(); } /// Determine the max of the signed range for a particular SCEV. - APInt getSignedRangeMax(const SCEV *S) { + APInt getSignedRangeMax(SCEVUse S) { return getRangeRef(S, HINT_RANGE_SIGNED).getSignedMax(); } /// Test if the given expression is known to be negative. - bool isKnownNegative(const SCEV *S); + bool isKnownNegative(SCEVUse S); /// Test if the given expression is known to be positive. - bool isKnownPositive(const SCEV *S); + bool isKnownPositive(SCEVUse S); /// Test if the given expression is known to be non-negative. - bool isKnownNonNegative(const SCEV *S); + bool isKnownNonNegative(SCEVUse S); /// Test if the given expression is known to be non-positive. - bool isKnownNonPositive(const SCEV *S); + bool isKnownNonPositive(SCEVUse S); /// Test if the given expression is known to be non-zero. - bool isKnownNonZero(const SCEV *S); + bool isKnownNonZero(SCEVUse S); /// Test if the given expression is known to be a power of 2. OrNegative /// allows matching negative power of 2s, and OrZero allows matching 0. @@ -1060,8 +1201,7 @@ class ScalarEvolution { /// 0 (initial value) for the first element and to {1, +, 1} (post /// increment value) for the second one. In both cases AddRec expression /// related to L2 remains the same. 
- std::pair SplitIntoInitAndPostInc(const Loop *L, - const SCEV *S); + std::pair SplitIntoInitAndPostInc(const Loop *L, SCEVUse S); /// We'd like to check the predicate on every iteration of the most dominated /// loop between loops used in LHS and RHS. @@ -1081,46 +1221,43 @@ class ScalarEvolution { /// so we can assert on that. /// e. Return true if isLoopEntryGuardedByCond(Pred, E(LHS), E(RHS)) && /// isLoopBackedgeGuardedByCond(Pred, B(LHS), B(RHS)) - bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownViaInduction(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS); /// Test if the given expression is known to satisfy the condition described /// by Pred, LHS, and RHS. - bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS); /// Check whether the condition described by Pred, LHS, and RHS is true or /// false. If we know it, return the evaluation of this condition. If neither /// is proved, return std::nullopt. - std::optional evaluatePredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + std::optional evaluatePredicate(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Test if the given expression is known to satisfy the condition described /// by Pred, LHS, and RHS in the given Context. - bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const Instruction *CtxI); + bool isKnownPredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + const Instruction *CtxI); /// Check whether the condition described by Pred, LHS, and RHS is true or /// false in the given \p Context. If we know it, return the evaluation of /// this condition. If neither is proved, return std::nullopt. - std::optional evaluatePredicateAt(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const Instruction *CtxI); + std::optional evaluatePredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, const Instruction *CtxI); /// Test if the condition described by Pred, LHS, RHS is known to be true on /// every iteration of the loop of the recurrency LHS. bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, - const SCEVAddRecExpr *LHS, const SCEV *RHS); + const SCEVAddRecExpr *LHS, SCEVUse RHS); /// Information about the number of loop iterations for which a loop exit's /// branch condition evaluates to the not-taken path. This is a temporary /// pair of exact and max expressions that are eventually summarized in /// ExitNotTakenInfo and BackedgeTakenInfo. struct ExitLimit { - const SCEV *ExactNotTaken; // The exit is not taken exactly this many times - const SCEV *ConstantMaxNotTaken; // The exit is not taken at most this many - // times - const SCEV *SymbolicMaxNotTaken; + SCEVUse ExactNotTaken; // The exit is not taken exactly this many times + SCEVUse ConstantMaxNotTaken; // The exit is not taken at most this many + // times + SCEVUse SymbolicMaxNotTaken; // Not taken either exactly ConstantMaxNotTaken or zero times bool MaxOrZero = false; @@ -1133,14 +1270,14 @@ class ScalarEvolution { /// Construct either an exact exit limit from a constant, or an unknown /// one from a SCEVCouldNotCompute. No other types of SCEVs are allowed /// as arguments and asserts enforce that internally. 
- /*implicit*/ ExitLimit(const SCEV *E); + /*implicit*/ ExitLimit(SCEVUse E); - ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, + ExitLimit(SCEVUse, SCEVUse ConstantMaxNotTaken, SCEVUse SymbolicMaxNotTaken, + bool MaxOrZero, ArrayRef> PredLists = {}); - ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, + ExitLimit(SCEVUse E, SCEVUse ConstantMaxNotTaken, + SCEVUse SymbolicMaxNotTaken, bool MaxOrZero, ArrayRef PredList); /// Test whether this ExitLimit contains any computed information, or @@ -1191,20 +1328,18 @@ class ScalarEvolution { struct LoopInvariantPredicate { ICmpInst::Predicate Pred; - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; - LoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS) + LoopInvariantPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS) : Pred(Pred), LHS(LHS), RHS(RHS) {} }; /// If the result of the predicate LHS `Pred` RHS is loop invariant with /// respect to L, return a LoopInvariantPredicate with LHS and RHS being /// invariants, available at L's entry. Otherwise, return std::nullopt. std::optional - getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const Loop *L, - const Instruction *CtxI = nullptr); + getLoopInvariantPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + const Loop *L, const Instruction *CtxI = nullptr); /// If the result of the predicate LHS `Pred` RHS is loop invariant with /// respect to L at given Context during at least first MaxIter iterations, @@ -1213,59 +1348,61 @@ class ScalarEvolution { /// should be the loop's exit condition. std::optional getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS, const Loop *L, + SCEVUse LHS, SCEVUse RHS, + const Loop *L, const Instruction *CtxI, - const SCEV *MaxIter); + SCEVUse MaxIter); std::optional - getLoopInvariantExitCondDuringFirstIterationsImpl( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, - const Instruction *CtxI, const SCEV *MaxIter); + getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, + SCEVUse LHS, SCEVUse RHS, + const Loop *L, + const Instruction *CtxI, + SCEVUse MaxIter); /// Simplify LHS and RHS in a comparison with predicate Pred. Return true /// iff any changes were made. If the operands are provably equal or /// unequal, LHS and RHS are set to the same value and Pred is set to either /// ICMP_EQ or ICMP_NE. - bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, - const SCEV *&RHS, unsigned Depth = 0); + bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, SCEVUse &LHS, + SCEVUse &RHS, unsigned Depth = 0); /// Return the "disposition" of the given SCEV with respect to the given /// loop. - LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L); + LoopDisposition getLoopDisposition(SCEVUse S, const Loop *L); /// Return true if the value of the given SCEV is unchanging in the /// specified loop. - bool isLoopInvariant(const SCEV *S, const Loop *L); + bool isLoopInvariant(SCEVUse S, const Loop *L); /// Determine if the SCEV can be evaluated at loop's entry. It is true if it /// doesn't depend on a SCEVUnknown of an instruction which is dominated by /// the header of loop L. 
- bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L); + bool isAvailableAtLoopEntry(SCEVUse S, const Loop *L); /// Return true if the given SCEV changes value in a known way in the /// specified loop. This property being true implies that the value is /// variant in the loop AND that we can emit an expression to compute the /// value of the expression at any particular loop iteration. - bool hasComputableLoopEvolution(const SCEV *S, const Loop *L); + bool hasComputableLoopEvolution(SCEVUse S, const Loop *L); /// Return the "disposition" of the given SCEV with respect to the given /// block. - BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB); + BlockDisposition getBlockDisposition(SCEVUse S, const BasicBlock *BB); /// Return true if elements that makes up the given SCEV dominate the /// specified basic block. - bool dominates(const SCEV *S, const BasicBlock *BB); + bool dominates(SCEVUse S, const BasicBlock *BB); /// Return true if elements that makes up the given SCEV properly dominate /// the specified basic block. - bool properlyDominates(const SCEV *S, const BasicBlock *BB); + bool properlyDominates(SCEVUse S, const BasicBlock *BB); /// Test whether the given SCEV has Op as a direct or indirect operand. - bool hasOperand(const SCEV *S, const SCEV *Op) const; + bool hasOperand(SCEVUse S, SCEVUse Op) const; /// Return the size of an element read or written by Inst. - const SCEV *getElementSize(Instruction *Inst); + SCEVUse getElementSize(Instruction *Inst); void print(raw_ostream &OS) const; void verify() const; @@ -1276,22 +1413,21 @@ class ScalarEvolution { /// operating on. const DataLayout &getDataLayout() const { return DL; } - const SCEVPredicate *getEqualPredicate(const SCEV *LHS, const SCEV *RHS); + const SCEVPredicate *getEqualPredicate(SCEVUse LHS, SCEVUse RHS); const SCEVPredicate *getComparePredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + SCEVUse LHS, SCEVUse RHS); const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags); /// Re-writes the SCEV according to the Predicates in \p A. - const SCEV *rewriteUsingPredicate(const SCEV *S, const Loop *L, - const SCEVPredicate &A); + SCEVUse rewriteUsingPredicate(SCEVUse S, const Loop *L, + const SCEVPredicate &A); /// Tries to convert the \p S expression to an AddRec expression, /// adding additional predicates to \p Preds as required. const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates( - const SCEV *S, const Loop *L, - SmallVectorImpl &Preds); + SCEVUse S, const Loop *L, SmallVectorImpl &Preds); /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a /// constant, and std::nullopt if it isn't. @@ -1300,8 +1436,7 @@ class ScalarEvolution { /// frugal here since we just bail out of actually constructing and /// canonicalizing an expression in the cases where the result isn't going /// to be a constant. - std::optional computeConstantDifference(const SCEV *LHS, - const SCEV *RHS); + std::optional computeConstantDifference(SCEVUse LHS, SCEVUse RHS); /// Update no-wrap flags of an AddRec. 
This may drop the cached info about /// this AddRec (such as range info) in case if new flags may potentially @@ -1309,7 +1444,7 @@ class ScalarEvolution { void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags); class LoopGuards { - DenseMap RewriteMap; + DenseMap RewriteMap; bool PreserveNUW = false; bool PreserveNSW = false; ScalarEvolution &SE; @@ -1345,8 +1480,8 @@ class ScalarEvolution { }; /// Try to apply information from loop guards for \p L to \p Expr. - const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L); - const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards); + SCEVUse applyLoopGuards(const SCEVUse Expr, const Loop *L); + SCEVUse applyLoopGuards(const SCEVUse Expr, const LoopGuards &Guards); /// Return true if the loop has no abnormal exits. That is, if the loop /// is not infinite, it must exit through an explicit edge in the CFG. @@ -1364,22 +1499,22 @@ class ScalarEvolution { /// being poison as well. The returned set may be incomplete, i.e. there can /// be additional Values that also result in S being poison. void getPoisonGeneratingValues(SmallPtrSetImpl &Result, - const SCEV *S); + SCEVUse S); /// Check whether it is poison-safe to represent the expression S using the /// instruction I. If such a replacement is performed, the poison flags of /// instructions in DropPoisonGeneratingInsts must be dropped. bool canReuseInstruction( - const SCEV *S, Instruction *I, + SCEVUse S, Instruction *I, SmallVectorImpl &DropPoisonGeneratingInsts); class FoldID { - const SCEV *Op = nullptr; + SCEVUse Op = nullptr; const Type *Ty = nullptr; unsigned short C; public: - FoldID(SCEVTypes C, const SCEV *Op, const Type *Ty) : Op(Op), Ty(Ty), C(C) { + FoldID(SCEVTypes C, SCEVUse Op, const Type *Ty) : Op(Op), Ty(Ty), C(C) { assert(Op); assert(Ty); } @@ -1388,12 +1523,15 @@ class ScalarEvolution { unsigned computeHash() const { return detail::combineHashValue( - C, detail::combineHashValue(reinterpret_cast(Op), - reinterpret_cast(Ty))); + C, detail::combineHashValue( + reinterpret_cast(Op.getRawPointer()), + reinterpret_cast(Ty))); } bool operator==(const FoldID &RHS) const { - return std::tie(Op, Ty, C) == std::tie(RHS.Op, RHS.Ty, RHS.C); + void *RawPtr = Op.getRawPointer(); + void *RawPtrRHS = RHS.Op.getRawPointer(); + return std::tie(RawPtr, Ty, C) == std::tie(RawPtrRHS, RHS.Ty, RHS.C); } }; @@ -1441,14 +1579,14 @@ class ScalarEvolution { std::unique_ptr CouldNotCompute; /// The type for HasRecMap. - using HasRecMapType = DenseMap; + using HasRecMapType = DenseMap; /// This is a cache to record whether a SCEV contains any scAddRecExpr. HasRecMapType HasRecMap; /// The type for ExprValueMap. using ValueSetVector = SmallSetVector; - using ExprValueMapType = DenseMap; + using ExprValueMapType = DenseMap; /// ExprValueMap -- This map records the original values from which /// the SCEV expr is generated from. @@ -1456,15 +1594,15 @@ class ScalarEvolution { /// The type for ValueExprMap. using ValueExprMapType = - DenseMap>; + DenseMap>; /// This is a cache of the values we have analyzed so far. ValueExprMapType ValueExprMap; /// This is a cache for expressions that got folded to a different existing /// SCEV. - DenseMap FoldCache; - DenseMap> FoldCacheUser; + DenseMap FoldCache; + DenseMap> FoldCacheUser; /// Mark predicate values currently being processed by isImpliedCond. 
SmallPtrSet PendingLoopPredicates; @@ -1487,27 +1625,27 @@ class ScalarEvolution { bool ProvingSplitPredicate = false; /// Memoized values for the getConstantMultiple - DenseMap ConstantMultipleCache; + DenseMap ConstantMultipleCache; /// Return the Value set from which the SCEV expr is generated. - ArrayRef getSCEVValues(const SCEV *S); + ArrayRef getSCEVValues(SCEVUse S); /// Private helper method for the getConstantMultiple method. - APInt getConstantMultipleImpl(const SCEV *S); + APInt getConstantMultipleImpl(SCEVUse S); /// Information about the number of times a particular loop exit may be /// reached before exiting the loop. struct ExitNotTakenInfo { PoisoningVH ExitingBlock; - const SCEV *ExactNotTaken; - const SCEV *ConstantMaxNotTaken; - const SCEV *SymbolicMaxNotTaken; + SCEVUse ExactNotTaken; + SCEVUse ConstantMaxNotTaken; + SCEVUse SymbolicMaxNotTaken; SmallVector Predicates; explicit ExitNotTakenInfo(PoisoningVH ExitingBlock, - const SCEV *ExactNotTaken, - const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, + SCEVUse ExactNotTaken, + SCEVUse ConstantMaxNotTaken, + SCEVUse SymbolicMaxNotTaken, ArrayRef Predicates) : ExitingBlock(ExitingBlock), ExactNotTaken(ExactNotTaken), ConstantMaxNotTaken(ConstantMaxNotTaken), @@ -1531,7 +1669,7 @@ class ScalarEvolution { /// Expression indicating the least constant maximum backedge-taken count of /// the loop that is known, or a SCEVCouldNotCompute. This expression is /// only valid if the predicates associated with all loop exits are true. - const SCEV *ConstantMax = nullptr; + SCEVUse ConstantMax = nullptr; /// Indicating if \c ExitNotTaken has an element for every exiting block in /// the loop. @@ -1539,13 +1677,13 @@ class ScalarEvolution { /// Expression indicating the least maximum backedge-taken count of the loop /// that is known, or a SCEVCouldNotCompute. Lazily computed on first query. - const SCEV *SymbolicMax = nullptr; + SCEVUse SymbolicMax = nullptr; /// True iff the backedge is taken either exactly Max or zero times. bool MaxOrZero = false; bool isComplete() const { return IsComplete; } - const SCEV *getConstantMax() const { return ConstantMax; } + SCEVUse getConstantMax() const { return ConstantMax; } const ExitNotTakenInfo *getExitNotTaken( const BasicBlock *ExitingBlock, @@ -1560,7 +1698,7 @@ class ScalarEvolution { /// Initialize BackedgeTakenInfo from a list of exact exit counts. BackedgeTakenInfo(ArrayRef ExitCounts, bool IsComplete, - const SCEV *ConstantMax, bool MaxOrZero); + SCEVUse ConstantMax, bool MaxOrZero); /// Test whether this BackedgeTakenInfo contains any computed information, /// or whether it's all SCEVCouldNotCompute values. @@ -1590,7 +1728,7 @@ class ScalarEvolution { /// If we allowed SCEV predicates to be generated when populating this /// vector, this information can contain them and therefore a /// SCEVPredicate argument should be added to getExact. - const SCEV *getExact( + SCEVUse getExact( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr) const; @@ -1599,7 +1737,7 @@ class ScalarEvolution { /// this block before this number of iterations, but may exit via another /// block. If \p Predicates is null the function returns CouldNotCompute if /// predicates are required, otherwise it fills in the required predicates. 
- const SCEV *getExact( + SCEVUse getExact( const BasicBlock *ExitingBlock, ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr) const { if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates)) @@ -1609,12 +1747,12 @@ class ScalarEvolution { } /// Get the constant max backedge taken count for the loop. - const SCEV *getConstantMax( + SCEVUse getConstantMax( ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr) const; /// Get the constant max backedge taken count for the particular loop exit. - const SCEV *getConstantMax( + SCEVUse getConstantMax( const BasicBlock *ExitingBlock, ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr) const { if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates)) @@ -1624,12 +1762,12 @@ class ScalarEvolution { } /// Get the symbolic max backedge taken count for the loop. - const SCEV *getSymbolicMax( + SCEVUse getSymbolicMax( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr); /// Get the symbolic max backedge taken count for the particular loop exit. - const SCEV *getSymbolicMax( + SCEVUse getSymbolicMax( const BasicBlock *ExitingBlock, ScalarEvolution *SE, SmallVectorImpl *Predicates = nullptr) const { if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates)) @@ -1652,7 +1790,7 @@ class ScalarEvolution { DenseMap PredicatedBackedgeTakenCounts; /// Loops whose backedge taken counts directly use this non-constant SCEV. - DenseMap, 4>> + DenseMap, 4>> BECountUsers; /// This map contains entries for all of the PHI instructions that we @@ -1664,16 +1802,16 @@ class ScalarEvolution { /// This map contains entries for all the expressions that we attempt to /// compute getSCEVAtScope information for, which can be expensive in /// extreme cases. - DenseMap, 2>> + DenseMap, 2>> ValuesAtScopes; /// Reverse map for invalidation purposes: Stores of which SCEV and which /// loop this is the value-at-scope of. - DenseMap, 2>> + DenseMap, 2>> ValuesAtScopesUsers; /// Memoized computeLoopDisposition results. - DenseMap, 2>> LoopDispositions; @@ -1701,33 +1839,33 @@ class ScalarEvolution { } /// Compute a LoopDisposition value. - LoopDisposition computeLoopDisposition(const SCEV *S, const Loop *L); + LoopDisposition computeLoopDisposition(SCEVUse S, const Loop *L); /// Memoized computeBlockDisposition results. DenseMap< - const SCEV *, + SCEVUse, SmallVector, 2>> BlockDispositions; /// Compute a BlockDisposition value. - BlockDisposition computeBlockDisposition(const SCEV *S, const BasicBlock *BB); + BlockDisposition computeBlockDisposition(SCEVUse S, const BasicBlock *BB); /// Stores all SCEV that use a given SCEV as its direct operand. - DenseMap > SCEVUsers; + DenseMap> SCEVUsers; /// Memoized results from getRange - DenseMap UnsignedRanges; + DenseMap UnsignedRanges; /// Memoized results from getRange - DenseMap SignedRanges; + DenseMap SignedRanges; /// Used to parameterize getRange enum RangeSignHint { HINT_RANGE_UNSIGNED, HINT_RANGE_SIGNED }; /// Set the memoized range for the given SCEV. - const ConstantRange &setRange(const SCEV *S, RangeSignHint Hint, + const ConstantRange &setRange(SCEVUse S, RangeSignHint Hint, ConstantRange CR) { - DenseMap &Cache = + DenseMap &Cache = Hint == HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; auto Pair = Cache.insert_or_assign(S, std::move(CR)); @@ -1737,29 +1875,29 @@ class ScalarEvolution { /// Determine the range for a particular SCEV. /// NOTE: This returns a reference to an entry in a cache. It must be /// copied if its needed for longer. 
- const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint, + const ConstantRange &getRangeRef(SCEVUse S, RangeSignHint Hint, unsigned Depth = 0); /// Determine the range for a particular SCEV, but evaluates ranges for /// operands iteratively first. - const ConstantRange &getRangeRefIter(const SCEV *S, RangeSignHint Hint); + const ConstantRange &getRangeRefIter(SCEVUse S, RangeSignHint Hint); /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}. /// Helper for \c getRange. - ConstantRange getRangeForAffineAR(const SCEV *Start, const SCEV *Step, + ConstantRange getRangeForAffineAR(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount); /// Determines the range for the affine non-self-wrapping SCEVAddRecExpr {\p /// Start,+,\p Step}. ConstantRange getRangeForAffineNoSelfWrappingAR(const SCEVAddRecExpr *AddRec, - const SCEV *MaxBECount, + SCEVUse MaxBECount, unsigned BitWidth, RangeSignHint SignHint); /// Try to compute a range for the affine SCEVAddRecExpr {\p Start,+,\p /// Step} by "factoring out" a ternary expression from the add recurrence. /// Helper called by \c getRange. - ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Step, + ConstantRange getRangeViaFactoring(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount); /// If the unknown expression U corresponds to a simple recurrence, return @@ -1770,59 +1908,58 @@ class ScalarEvolution { /// We know that there is no SCEV for the specified value. Analyze the /// expression recursively. - const SCEV *createSCEV(Value *V); + SCEVUse createSCEV(Value *V); /// We know that there is no SCEV for the specified value. Create a new SCEV /// for \p V iteratively. - const SCEV *createSCEVIter(Value *V); + SCEVUse createSCEVIter(Value *V); /// Collect operands of \p V for which SCEV expressions should be constructed /// first. Returns a SCEV directly if it can be constructed trivially for \p /// V. - const SCEV *getOperandsToCreate(Value *V, SmallVectorImpl &Ops); + SCEVUse getOperandsToCreate(Value *V, SmallVectorImpl &Ops); /// Returns SCEV for the first operand of a phi if all phi operands have /// identical opcodes and operands. const SCEV *createNodeForPHIWithIdenticalOperands(PHINode *PN); /// Provide the special handling we need to analyze PHI SCEVs. - const SCEV *createNodeForPHI(PHINode *PN); + SCEVUse createNodeForPHI(PHINode *PN); /// Helper function called from createNodeForPHI. - const SCEV *createAddRecFromPHI(PHINode *PN); + SCEVUse createAddRecFromPHI(PHINode *PN); /// A helper function for createAddRecFromPHI to handle simple cases. - const SCEV *createSimpleAffineAddRec(PHINode *PN, Value *BEValueV, - Value *StartValueV); + SCEVUse createSimpleAffineAddRec(PHINode *PN, Value *BEValueV, + Value *StartValueV); /// Helper function called from createNodeForPHI. - const SCEV *createNodeFromSelectLikePHI(PHINode *PN); + SCEVUse createNodeFromSelectLikePHI(PHINode *PN); /// Provide special handling for a select-like instruction (currently this /// is either a select instruction or a phi node). \p Ty is the type of the /// instruction being processed, that is assumed equivalent to /// "Cond ? TrueVal : FalseVal". - std::optional + std::optional createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, ICmpInst *Cond, Value *TrueVal, Value *FalseVal); /// See if we can model this select-like instruction via umin_seq expression. 
- const SCEV *createNodeForSelectOrPHIViaUMinSeq(Value *I, Value *Cond, - Value *TrueVal, - Value *FalseVal); + SCEVUse createNodeForSelectOrPHIViaUMinSeq(Value *I, Value *Cond, + Value *TrueVal, Value *FalseVal); /// Given a value \p V, which is a select-like instruction (currently this is /// either a select instruction or a phi node), which is assumed equivalent to /// Cond ? TrueVal : FalseVal /// see if we can model it as a SCEV expression. - const SCEV *createNodeForSelectOrPHI(Value *V, Value *Cond, Value *TrueVal, - Value *FalseVal); + SCEVUse createNodeForSelectOrPHI(Value *V, Value *Cond, Value *TrueVal, + Value *FalseVal); /// Provide the special handling we need to analyze GEP SCEVs. - const SCEV *createNodeForGEP(GEPOperator *GEP); + SCEVUse createNodeForGEP(GEPOperator *GEP); /// Implementation code for getSCEVAtScope; called at most once for each /// SCEV+Loop pair. - const SCEV *computeSCEVAtScope(const SCEV *S, const Loop *L); + SCEVUse computeSCEVAtScope(SCEVUse S, const Loop *L); /// Return the BackedgeTakenInfo for the given loop, lazily computing new /// values if the loop hasn't been analyzed yet. The returned result is @@ -1904,8 +2041,7 @@ class ScalarEvolution { /// return more precise results in some cases and is preferred when caller /// has a materialized ICmp. ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - bool IsSubExpr, + SCEVUse LHS, SCEVUse RHS, bool IsSubExpr, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will @@ -1931,20 +2067,20 @@ class ScalarEvolution { /// of the loop until we get the exit condition gets a value of ExitWhen /// (true or false). If we cannot evaluate the exit count of the loop, /// return CouldNotCompute. - const SCEV *computeExitCountExhaustively(const Loop *L, Value *Cond, - bool ExitWhen); + SCEVUse computeExitCountExhaustively(const Loop *L, Value *Cond, + bool ExitWhen); /// Return the number of times an exit condition comparing the specified /// value to zero will execute. If not computable, return CouldNotCompute. /// If AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. - ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr, + ExitLimit howFarToZero(SCEVUse V, const Loop *L, bool IsSubExpr, bool AllowPredicates = false); /// Return the number of times an exit condition checking the specified /// value for nonzero will execute. If not computable, return /// CouldNotCompute. - ExitLimit howFarToNonZero(const SCEV *V, const Loop *L); + ExitLimit howFarToNonZero(SCEVUse V, const Loop *L); /// Return the number of times an exit condition containing the specified /// less-than comparison will execute. If not computable, return @@ -1958,11 +2094,11 @@ class ScalarEvolution { /// /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. 
- ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, + ExitLimit howManyLessThans(SCEVUse LHS, SCEVUse RHS, const Loop *L, bool isSigned, bool ControlsOnlyExit, bool AllowPredicates = false); - ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, + ExitLimit howManyGreaterThans(SCEVUse LHS, SCEVUse RHS, const Loop *L, bool isSigned, bool IsSubExpr, bool AllowPredicates = false); @@ -1976,7 +2112,7 @@ class ScalarEvolution { /// whenever the given FoundCondValue value evaluates to true in given /// Context. If Context is nullptr, then the found predicate is true /// everywhere. LHS and FoundLHS may have different type width. - bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + bool isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, const Value *FoundCondValue, bool Inverse, const Instruction *Context = nullptr); @@ -1984,65 +2120,60 @@ class ScalarEvolution { /// whenever the given FoundCondValue value evaluates to true in given /// Context. If Context is nullptr, then the found predicate is true /// everywhere. LHS and FoundLHS must have same type width. - bool isImpliedCondBalancedTypes(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, - ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, const SCEV *FoundRHS, + bool isImpliedCondBalancedTypes(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, ICmpInst::Predicate FoundPred, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is /// true in given Context. If Context is nullptr, then the found predicate is /// true everywhere. - bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, - const SCEV *FoundRHS, - const Instruction *Context = nullptr); + bool isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + ICmpInst::Predicate FoundPred, SCEVUse FoundLHS, + SCEVUse FoundRHS, const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true in given Context. If Context is nullptr, then the found predicate is /// true everywhere. - bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS, + bool isImpliedCondOperands(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. Here LHS is an operation that includes FoundLHS as one of its /// arguments. - bool isImpliedViaOperations(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, + bool isImpliedViaOperations(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, SCEVUse FoundRHS, unsigned Depth = 0); /// Test whether the condition described by Pred, LHS, and RHS is true. /// Use only simple non-recursive types of checks, such as range analysis etc. 
- bool isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + bool isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. - bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. Utility function used by isImpliedCondOperands. Tries to get /// cases like "X `sgt` 0 => X - 1 `sgt` -1". - bool isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, + bool isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, - const SCEV *FoundRHS); + SCEVUse FoundLHS, SCEVUse FoundRHS); /// Return true if the condition denoted by \p LHS \p Pred \p RHS is implied /// by a call to @llvm.experimental.guard in \p BB. bool isImpliedViaGuard(const BasicBlock *BB, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + SCEVUse LHS, SCEVUse RHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is @@ -2050,10 +2181,9 @@ class ScalarEvolution { /// /// This routine tries to rule out certain kinds of integer overflow, and /// then tries to reason about arithmetic properties of the predicates. - bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is @@ -2062,9 +2192,8 @@ class ScalarEvolution { /// This routine tries to weaken the known condition basing on fact that /// FoundLHS is an AddRec. bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS, + SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI); /// Test whether the condition described by Pred, LHS, and RHS is true @@ -2074,19 +2203,17 @@ class ScalarEvolution { /// This routine tries to figure out predicate for Phis which are SCEVUnknown /// if it is true for every possible incoming value from their respective /// basic blocks. - bool isImpliedViaMerge(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, - unsigned Depth); + bool isImpliedViaMerge(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, unsigned Depth); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. /// /// This routine tries to reason about shifts. 
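// --------------------------------------------------------------------------
// Standalone worked example of the range-based implication mentioned above
// ("X `sgt` 0 => X - 1 `sgt` -1"), using ConstantRange directly rather than
// the private ScalarEvolution helpers; the function name is made up.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
#include <cassert>
using namespace llvm;

static void impliedCondViaRangesExample() {
  unsigned BW = 32;
  // Exact range of X satisfying "X s> 0": [1, SIGNED_MIN).
  ConstantRange X =
      ConstantRange::makeExactICmpRegion(CmpInst::ICMP_SGT, APInt::getZero(BW));
  // Range of X - 1: [0, SIGNED_MAX).
  ConstantRange XMinus1 = X.subtract(APInt(BW, 1));
  // Every element of XMinus1 is s> -1, so the implication holds.
  assert(XMinus1.icmp(CmpInst::ICMP_SGT,
                      ConstantRange(APInt::getAllOnes(BW))));
}
// --------------------------------------------------------------------------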
- bool isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// If we know that the specified Phi is in the header of its containing /// loop, we know the loop executes a constant number of times, and the PHI @@ -2096,50 +2223,50 @@ class ScalarEvolution { /// Test if the given expression is known to satisfy the condition described /// by Pred and the known constant ranges of LHS and RHS. - bool isKnownPredicateViaConstantRanges(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + bool isKnownPredicateViaConstantRanges(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to prove the condition described by "LHS Pred RHS" by ruling out /// integer overflow. /// /// For instance, this will return true for "A s< (A + C)" if C is /// positive. - bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to split Pred LHS RHS into logical conjunctions (and's) and try to /// prove them individually. - bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to match the Expr as "(L + R)". - bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, + bool splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R, SCEV::NoWrapFlags &Flags); /// Forget predicated/non-predicated backedge taken counts for the given loop. void forgetBackedgeTakenCounts(const Loop *L, bool Predicated); /// Drop memoized information for all \p SCEVs. - void forgetMemoizedResults(ArrayRef SCEVs); + void forgetMemoizedResults(ArrayRef SCEVs); /// Helper for forgetMemoizedResults. - void forgetMemoizedResultsImpl(const SCEV *S); + void forgetMemoizedResultsImpl(SCEVUse S); /// Iterate over instructions in \p Worklist and their users. Erase entries /// from ValueExprMap and collect SCEV expressions in \p ToForget void visitAndClearUsers(SmallVectorImpl &Worklist, SmallPtrSetImpl &Visited, - SmallVectorImpl &ToForget); + SmallVectorImpl &ToForget); /// Erase Value from ValueExprMap and ExprValueMap. void eraseValueFromMap(Value *V); /// Insert V to S mapping into ValueExprMap and ExprValueMap. - void insertValueToMap(Value *V, const SCEV *S); + void insertValueToMap(Value *V, SCEVUse S); /// Return false iff given SCEV contains a SCEVUnknown with NULL value- /// pointer. - bool checkValidity(const SCEV *S) const; + bool checkValidity(SCEVUse S) const; /// Return true if `ExtendOpTy`({`Start`,+,`Step`}) can be proved to be /// equal to {`ExtendOpTy`(`Start`),+,`ExtendOpTy`(`Step`)}. This is @@ -2147,8 +2274,7 @@ class ScalarEvolution { /// {`Start`,+,`Step`} if `ExtendOpTy` is `SCEVSignExtendExpr` /// (resp. `SCEVZeroExtendExpr`). template - bool proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step, - const Loop *L); + bool proveNoWrapByVaryingStart(SCEVUse Start, SCEVUse Step, const Loop *L); /// Try to prove NSW or NUW on \p AR relying on ConstantRange manipulation. SCEV::NoWrapFlags proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR); @@ -2174,17 +2300,17 @@ class ScalarEvolution { /// 'S'. Specifically, return the first instruction in said bounding scope. 
   /// Return nullptr if the scope is trivial (function entry).
   /// (See scope definition rules associated with flag discussion above)
-  const Instruction *getNonTrivialDefiningScopeBound(const SCEV *S);
+  const Instruction *getNonTrivialDefiningScopeBound(SCEVUse S);
 
   /// Return a scope which provides an upper bound on the defining scope for
   /// a SCEV with the operands in Ops. The outparam Precise is set if the
   /// bound found is a precise bound (i.e. must be the defining scope.)
-  const Instruction *getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
+  const Instruction *getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
                                            bool &Precise);
 
   /// Wrapper around the above for cases which don't care if the bound
   /// is precise.
-  const Instruction *getDefiningScopeBound(ArrayRef<const SCEV *> Ops);
+  const Instruction *getDefiningScopeBound(ArrayRef<SCEVUse> Ops);
 
   /// Given two instructions in the same function, return true if we can
   /// prove B must execute given A executes.
@@ -2231,7 +2357,7 @@ class ScalarEvolution {
   /// If the analysis is not successful, a mapping from the \p SymbolicPHI to
   /// itself (with no predicates) is recorded, and a nullptr with an empty
   /// predicates vector is returned as a pair.
-  std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+  std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>>
   createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI);
 
   /// Compute the maximum backedge count based on the range of values
@@ -2243,47 +2369,44 @@ class ScalarEvolution {
   /// * the induction variable is assumed not to overflow (i.e. either it
   ///   actually doesn't, or we'd have to immediately execute UB)
   /// We *don't* assert these preconditions so please be careful.
-  const SCEV *computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride,
-                                     const SCEV *End, unsigned BitWidth,
-                                     bool IsSigned);
+  SCEVUse computeMaxBECountForLT(SCEVUse Start, SCEVUse Stride, SCEVUse End,
+                                 unsigned BitWidth, bool IsSigned);
 
   /// Verify if a linear IV with positive stride can overflow when in a
   /// less-than comparison, knowing the invariant term of the comparison,
   /// the stride.
-  bool canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned);
+  bool canIVOverflowOnLT(SCEVUse RHS, SCEVUse Stride, bool IsSigned);
 
   /// Verify if a linear IV with negative stride can overflow when in a
   /// greater-than comparison, knowing the invariant term of the comparison,
   /// the stride.
-  bool canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned);
+  bool canIVOverflowOnGT(SCEVUse RHS, SCEVUse Stride, bool IsSigned);
 
   /// Get add expr already created or create a new one.
-  const SCEV *getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
-                                 SCEV::NoWrapFlags Flags);
+  SCEVUse getOrCreateAddExpr(ArrayRef<SCEVUse> Ops, SCEV::NoWrapFlags Flags);
 
   /// Get mul expr already created or create a new one.
-  const SCEV *getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
-                                 SCEV::NoWrapFlags Flags);
+  SCEVUse getOrCreateMulExpr(ArrayRef<SCEVUse> Ops, SCEV::NoWrapFlags Flags);
 
   // Get addrec expr already created or create a new one.
-  const SCEV *getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
-                                    const Loop *L, SCEV::NoWrapFlags Flags);
+  SCEVUse getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops, const Loop *L,
                                 SCEV::NoWrapFlags Flags);
 
   /// Return x if \p Val is f(x) where f is a 1-1 function.
-  const SCEV *stripInjectiveFunctions(const SCEV *Val) const;
+  SCEVUse stripInjectiveFunctions(SCEVUse Val) const;
 
   /// Find all of the loops transitively used in \p S, and fill \p LoopsUsed.
   /// A loop is considered "used" by an expression if it contains
   /// an add rec on said loop.
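// --------------------------------------------------------------------------
// Illustrative sketch, not from the patch: the "used loop" notion above can
// be phrased with SCEVExprContains (declared in ScalarEvolutionExpressions.h):
// a loop L is used by S if some sub-expression is an add recurrence on L.
// `usesLoop` is a hypothetical helper.
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static bool usesLoop(const SCEV *S, const Loop *L) {
  return SCEVExprContains(S, [L](const SCEV *X) {
    const auto *AR = dyn_cast<SCEVAddRecExpr>(X);
    return AR && AR->getLoop() == L; // traversal stops once this fires
  });
}
// --------------------------------------------------------------------------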
-  void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
+  void getUsedLoops(SCEVUse S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
 
   /// Try to match the pattern generated by getURemExpr(A, B). If successful,
   /// Assign A and B to LHS and RHS, respectively.
-  bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
+  bool matchURem(SCEVUse Expr, SCEVUse &LHS, SCEVUse &RHS);
 
   /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
   /// `UniqueSCEVs`. Return if found, else nullptr.
-  SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
+  SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<SCEVUse> Ops);
 
   /// Get reachable blocks in this function, making limited use of SCEV
   /// reasoning about conditions.
@@ -2292,8 +2415,7 @@ class ScalarEvolution {
 
   /// Return the given SCEV expression with a new set of operands.
   /// This preserves the original nowrap flags.
-  const SCEV *getWithOperands(const SCEV *S,
-                              SmallVectorImpl<const SCEV *> &NewOps);
+  SCEVUse getWithOperands(SCEVUse S, SmallVectorImpl<SCEVUse> &NewOps);
 
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;
@@ -2305,7 +2427,7 @@ class ScalarEvolution {
   /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression
   /// they can be rewritten into under certain predicates.
   DenseMap<std::pair<const SCEVUnknown *, const Loop *>,
-           std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+           std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>>
       PredicatedSCEVRewrites;
 
   /// Set of AddRecs for which proving NUW via an induction has already been
@@ -2397,10 +2519,10 @@ class PredicatedScalarEvolution {
   /// predicate. The order of transformations applied on the expression of V
   /// returned by ScalarEvolution is guaranteed to be preserved, even when
   /// adding new predicates.
-  const SCEV *getSCEV(Value *V);
+  SCEVUse getSCEV(Value *V);
 
   /// Get the (predicated) backedge count for the analyzed loop.
-  const SCEV *getBackedgeTakenCount();
+  SCEVUse getBackedgeTakenCount();
 
   /// Get the (predicated) symbolic max backedge count for the analyzed loop.
   const SCEV *getSymbolicMaxBackedgeTakenCount();
@@ -2447,14 +2569,14 @@ class PredicatedScalarEvolution {
 
   /// Holds a SCEV and the version number of the SCEV predicate used to
   /// perform the rewrite of the expression.
-  using RewriteEntry = std::pair<unsigned, const SCEV *>;
+  using RewriteEntry = std::pair<unsigned, SCEVUse>;
 
   /// Maps a SCEV to the rewrite result of that SCEV at a certain version
   /// number. If this number doesn't match the current Generation, we will
   /// need to do a rewrite. To preserve the transformation order of previous
   /// rewrites, we will rewrite the previous result instead of the original
   /// SCEV.
-  DenseMap<const SCEV *, RewriteEntry> RewriteMap;
+  DenseMap<SCEVUse, RewriteEntry> RewriteMap;
 
   /// Records what NoWrap flags we've added to a Value *.
   ValueMap<Value *, SCEV::NoWrapFlags> FlagsMap;
@@ -2476,10 +2598,11 @@ class PredicatedScalarEvolution {
   unsigned Generation = 0;
 
   /// The backedge taken count.
-  const SCEV *BackedgeCount = nullptr;
+  SCEVUse BackedgeCount = nullptr;
 
   /// The symbolic backedge taken count.
-  const SCEV *SymbolicMaxBackedgeCount = nullptr;
+
+  SCEVUse SymbolicMaxBackedgeCount = nullptr;
 
   /// The constant max trip count for the loop.
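// --------------------------------------------------------------------------
// Standalone sketch of the versioning scheme behind RewriteMap above (all
// names hypothetical, and simplified: stale entries are simply recomputed,
// whereas PredicatedScalarEvolution rewrites the previous result to preserve
// transformation order). Bumping the generation invalidates every cached
// entry in O(1) without walking the map.
#include <unordered_map>
#include <utility>

struct VersionedCache {
  unsigned Generation = 0;
  // Key -> (generation the value was computed in, cached value).
  std::unordered_map<const void *, std::pair<unsigned, int>> Map;

  void invalidateAll() { ++Generation; } // lazy, constant-time invalidation

  bool lookup(const void *Key, int &Out) const {
    auto It = Map.find(Key);
    if (It == Map.end() || It->second.first != Generation)
      return false; // entries stamped with an older generation are stale
    Out = It->second.second;
    return true;
  }

  void insert(const void *Key, int Value) { Map[Key] = {Generation, Value}; }
};
// --------------------------------------------------------------------------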
std::optional SmallConstantMaxTripCount; diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 6eb1aca1cf76a..ff592829f22ef 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -90,11 +90,12 @@ class SCEVVScale : public SCEV { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; -inline unsigned short computeExpressionSize(ArrayRef Args) { +inline unsigned short computeExpressionSize(ArrayRef Args) { APInt Size(16, 1); - for (const auto *Arg : Args) + for (const auto Arg : Args) Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); return (unsigned short)Size.getZExtValue(); } @@ -102,19 +103,19 @@ inline unsigned short computeExpressionSize(ArrayRef Args) { /// This is the base class for unary cast operator classes. class SCEVCastExpr : public SCEV { protected: - const SCEV *Op; + SCEVUse Op; Type *Ty; - SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, + SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty); public: - const SCEV *getOperand() const { return Op; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand() const { return Op; } + SCEVUse getOperand(unsigned i) const { assert(i == 0 && "Operand index out of range!"); return Op; } - ArrayRef operands() const { return Op; } + ArrayRef operands() const { return Op; } size_t getNumOperands() const { return 1; } Type *getType() const { return Ty; } @@ -123,6 +124,7 @@ class SCEVCastExpr : public SCEV { return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a cast from a pointer to a pointer-sized integer @@ -130,18 +132,19 @@ class SCEVCastExpr : public SCEV { class SCEVPtrToIntExpr : public SCEVCastExpr { friend class ScalarEvolution; - SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); + SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op, Type *ITy); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This is the base class for unary integral cast operator classes. 
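// --------------------------------------------------------------------------
// Small standalone demo of the saturating arithmetic computeExpressionSize
// (above) relies on: sizes are clamped into 16 bits, so deeply nested
// expressions saturate at 0xFFFF instead of wrapping around.
#include "llvm/ADT/APInt.h"
#include <cassert>
using namespace llvm;

static void expressionSizeSaturates() {
  APInt Size(16, 65000);
  Size = Size.uadd_sat(APInt(16, 1000)); // 65000 + 1000 exceeds 0xFFFF
  assert(Size.getZExtValue() == 65535 && "unsigned add saturates");
}
// --------------------------------------------------------------------------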
class SCEVIntegralCastExpr : public SCEVCastExpr { protected: SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, - const SCEV *op, Type *ty); + SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: @@ -149,6 +152,7 @@ class SCEVIntegralCastExpr : public SCEVCastExpr { return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a truncation of an integer value to a @@ -156,11 +160,12 @@ class SCEVIntegralCastExpr : public SCEVCastExpr { class SCEVTruncateExpr : public SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a zero extension of a small integer value @@ -168,13 +173,14 @@ class SCEVTruncateExpr : public SCEVIntegralCastExpr { class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scZeroExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a sign extension of a small integer value @@ -182,13 +188,14 @@ class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { class SCEVSignExtendExpr : public SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is a base class providing common functionality for @@ -199,25 +206,23 @@ class SCEVNAryExpr : public SCEV { // arrays with its SCEVAllocator, so this class just needs a simple // pointer rather than a more elaborate vector-like data structure. // This also avoids the need for a non-trivial destructor. 
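// --------------------------------------------------------------------------
// Standalone sketch of the operand-storage scheme described above (names
// hypothetical): operands live in an array carved out of a bump allocator,
// so the node stores only a pointer and a count and needs no destructor.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include <algorithm>
using namespace llvm;

struct NAryNode {
  const int *Operands; // not owned; lifetime tied to the allocator
  size_t NumOperands;

  ArrayRef<int> operands() const { return ArrayRef<int>(Operands, NumOperands); }
};

static NAryNode makeNode(BumpPtrAllocator &Alloc, ArrayRef<int> Ops) {
  int *Storage = Alloc.Allocate<int>(Ops.size());
  std::copy(Ops.begin(), Ops.end(), Storage);
  return {Storage, Ops.size()};
}
// --------------------------------------------------------------------------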
- const SCEV *const *Operands; + SCEVUse const *Operands; size_t NumOperands; - SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, SCEVUse const *O, + size_t N) : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), NumOperands(N) {} public: size_t getNumOperands() const { return NumOperands; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand(unsigned i) const { assert(i < NumOperands && "Operand index out of range!"); return Operands[i]; } - ArrayRef operands() const { - return ArrayRef(Operands, NumOperands); - } + ArrayRef operands() const { return ArrayRef(Operands, NumOperands); } NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { return (NoWrapFlags)(SubclassData & Mask); @@ -241,13 +246,14 @@ class SCEVNAryExpr : public SCEV { S->getSCEVType() == scSequentialUMinExpr || S->getSCEVType() == scAddRecExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class for n'ary commutative operators. class SCEVCommutativeExpr : public SCEVNAryExpr { protected: SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVNAryExpr(ID, T, O, N) {} public: @@ -257,6 +263,7 @@ class SCEVCommutativeExpr : public SCEVNAryExpr { S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } /// Set flags for a non-recurrence without clearing previously set flags. void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } @@ -268,11 +275,10 @@ class SCEVAddExpr : public SCEVCommutativeExpr { Type *Ty; - SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVAddExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, scAddExpr, O, N) { - auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { - return Op->getType()->isPointerTy(); - }); + auto *FirstPointerTypedOp = find_if( + operands(), [](SCEVUse Op) { return Op->getType()->isPointerTy(); }); if (FirstPointerTypedOp != operands().end()) Ty = (*FirstPointerTypedOp)->getType(); else @@ -284,13 +290,14 @@ class SCEVAddExpr : public SCEVCommutativeExpr { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node represents multiplication of some number of SCEVs. class SCEVMulExpr : public SCEVCommutativeExpr { friend class ScalarEvolution; - SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVMulExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} public: @@ -298,30 +305,31 @@ class SCEVMulExpr : public SCEVCommutativeExpr { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a binary unsigned division operation. 
class SCEVUDivExpr : public SCEV { friend class ScalarEvolution; - std::array Operands; + std::array Operands; - SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) + SCEVUDivExpr(const FoldingSetNodeIDRef ID, SCEVUse lhs, SCEVUse rhs) : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { Operands[0] = lhs; Operands[1] = rhs; } public: - const SCEV *getLHS() const { return Operands[0]; } - const SCEV *getRHS() const { return Operands[1]; } + SCEVUse getLHS() const { return Operands[0]; } + SCEVUse getRHS() const { return Operands[1]; } size_t getNumOperands() const { return 2; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand(unsigned i) const { assert((i == 0 || i == 1) && "Operand index out of range!"); return i == 0 ? getLHS() : getRHS(); } - ArrayRef operands() const { return Operands; } + ArrayRef operands() const { return Operands; } Type *getType() const { // In most cases the types of LHS and RHS will be the same, but in some @@ -334,6 +342,7 @@ class SCEVUDivExpr : public SCEV { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node represents a polynomial recurrence on the trip count @@ -349,25 +358,24 @@ class SCEVAddRecExpr : public SCEVNAryExpr { const Loop *L; - SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, + SCEVAddRecExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N, const Loop *l) : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} public: Type *getType() const { return getStart()->getType(); } - const SCEV *getStart() const { return Operands[0]; } + SCEVUse getStart() const { return Operands[0]; } const Loop *getLoop() const { return L; } /// Constructs and returns the recurrence indicating how much this /// expression steps by. If this is a polynomial of degree N, it /// returns a chrec of degree N-1. We cannot determine whether /// the step recurrence has self-wraparound. - const SCEV *getStepRecurrence(ScalarEvolution &SE) const { + SCEVUse getStepRecurrence(ScalarEvolution &SE) const { if (isAffine()) return getOperand(1); - return SE.getAddRecExpr( - SmallVector(operands().drop_front()), getLoop(), - FlagAnyWrap); + return SE.getAddRecExpr(SmallVector(operands().drop_front()), + getLoop(), FlagAnyWrap); } /// Return true if this represents an expression A + B*x where A @@ -394,12 +402,12 @@ class SCEVAddRecExpr : public SCEVNAryExpr { /// Return the value of this chain of recurrences at the specified /// iteration number. - const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; + SCEVUse evaluateAtIteration(SCEVUse It, ScalarEvolution &SE) const; /// Return the value of this chain of recurrences at the specified iteration /// number. Takes an explicit list of operands to represent an AddRec. - static const SCEV *evaluateAtIteration(ArrayRef Operands, - const SCEV *It, ScalarEvolution &SE); + static SCEVUse evaluateAtIteration(ArrayRef Operands, SCEVUse It, + ScalarEvolution &SE); /// Return the number of iterations of this loop that produce /// values in the specified constant range. Another way of @@ -407,8 +415,8 @@ class SCEVAddRecExpr : public SCEVNAryExpr { /// where the value is not in the condition, thus computing the /// exit count. If the iteration count can't be computed, an /// instance of SCEVCouldNotCompute is returned. 
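// --------------------------------------------------------------------------
// Illustrative sketch, not from the patch: for an affine recurrence
// {Start,+,Step}, evaluateAtIteration (declared above) folds to
// Start + It * Step, so {0,+,2}<%loop> at iteration 10 evaluates to 20.
// `valueAtIteration` is a hypothetical wrapper; It must fit the AddRec type.
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static const SCEV *valueAtIteration(ScalarEvolution &SE,
                                    const SCEVAddRecExpr *AR, uint64_t It) {
  const SCEV *ItSCEV = SE.getConstant(AR->getType(), It);
  return AR->evaluateAtIteration(ItSCEV, SE); // SCEVUse decays to const SCEV *
}
// --------------------------------------------------------------------------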
- const SCEV *getNumIterationsInRange(const ConstantRange &Range, - ScalarEvolution &SE) const; + SCEVUse getNumIterationsInRange(const ConstantRange &Range, + ScalarEvolution &SE) const; /// Return an expression representing the value of this expression /// one iteration of the loop ahead. @@ -418,6 +426,7 @@ class SCEVAddRecExpr : public SCEVNAryExpr { static bool classof(const SCEV *S) { return S->getSCEVType() == scAddRecExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class min/max selections. @@ -432,7 +441,7 @@ class SCEVMinMaxExpr : public SCEVCommutativeExpr { protected: /// Note: Constructing subclasses via this constructor is allowed SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, T, O, N) { assert(isMinMaxType(T)); // Min and max never overflow @@ -443,6 +452,7 @@ class SCEVMinMaxExpr : public SCEVCommutativeExpr { Type *getType() const { return getOperand(0)->getType(); } static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } static enum SCEVTypes negate(enum SCEVTypes T) { switch (T) { @@ -464,48 +474,52 @@ class SCEVMinMaxExpr : public SCEVCommutativeExpr { class SCEVSMaxExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVSMaxExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents an unsigned maximum selection. class SCEVUMaxExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVUMaxExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a signed minimum selection. class SCEVSMinExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVSMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents an unsigned minimum selection. 
class SCEVUMinExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVUMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class for sequential/in-order min/max selections. @@ -527,7 +541,7 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { protected: /// Note: Constructing subclasses via this constructor is allowed SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVNAryExpr(ID, T, O, N) { assert(isSequentialMinMaxType(T)); // Min and max never overflow @@ -554,13 +568,14 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { static bool classof(const SCEV *S) { return isSequentialMinMaxType(S->getSCEVType()); } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a sequential/in-order unsigned minimum selection. class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { friend class ScalarEvolution; - SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, + SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} @@ -569,6 +584,7 @@ class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { static bool classof(const SCEV *S) { return S->getSCEVType() == scSequentialUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This means that we are dealing with an entirely unknown SCEV @@ -601,48 +617,56 @@ class SCEVUnknown final : public SCEV, private CallbackVH { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class defines a simple visitor class that may be used for /// various SCEV analysis purposes. 
template struct SCEVVisitor { - RetVal visit(const SCEV *S) { + RetVal visit(SCEVUse S) { switch (S->getSCEVType()) { case scConstant: - return ((SC *)this)->visitConstant((const SCEVConstant *)S); + return ((SC *)this)->visitConstant((const SCEVConstant *)S.getPointer()); case scVScale: - return ((SC *)this)->visitVScale((const SCEVVScale *)S); + return ((SC *)this)->visitVScale((const SCEVVScale *)S.getPointer()); case scPtrToInt: - return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); + return ((SC *)this) + ->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S.getPointer()); case scTruncate: - return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); + return ((SC *)this) + ->visitTruncateExpr((const SCEVTruncateExpr *)S.getPointer()); case scZeroExtend: - return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); + return ((SC *)this) + ->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S.getPointer()); case scSignExtend: - return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); + return ((SC *)this) + ->visitSignExtendExpr((const SCEVSignExtendExpr *)S.getPointer()); case scAddExpr: - return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); + return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S.getPointer()); case scMulExpr: - return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); + return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S.getPointer()); case scUDivExpr: - return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); + return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S.getPointer()); case scAddRecExpr: - return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); + return ((SC *)this) + ->visitAddRecExpr((const SCEVAddRecExpr *)S.getPointer()); case scSMaxExpr: - return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); + return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S.getPointer()); case scUMaxExpr: - return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); + return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S.getPointer()); case scSMinExpr: - return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); + return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S.getPointer()); case scUMinExpr: - return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); + return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S.getPointer()); case scSequentialUMinExpr: return ((SC *)this) - ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); + ->visitSequentialUMinExpr( + (const SCEVSequentialUMinExpr *)S.getPointer()); case scUnknown: - return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); + return ((SC *)this)->visitUnknown((const SCEVUnknown *)S.getPointer()); case scCouldNotCompute: - return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); + return ((SC *)this) + ->visitCouldNotCompute((const SCEVCouldNotCompute *)S.getPointer()); } llvm_unreachable("Unknown SCEV kind!"); } @@ -656,15 +680,15 @@ template struct SCEVVisitor { /// /// Visitor implements: /// // return true to follow this node. -/// bool follow(const SCEV *S); +/// bool follow(SCEVUse S); /// // return true to terminate the search. 
///   bool isDone();
template <typename SV> class SCEVTraversal {
   SV &Visitor;
-  SmallVector<const SCEV *, 8> Worklist;
-  SmallPtrSet<const SCEV *, 8> Visited;
+  SmallVector<SCEVUse, 8> Worklist;
+  SmallPtrSet<SCEVUse, 8> Visited;
 
-  void push(const SCEV *S) {
+  void push(SCEVUse S) {
     if (Visited.insert(S).second && Visitor.follow(S))
       Worklist.push_back(S);
   }
@@ -672,10 +696,10 @@ template <typename SV> class SCEVTraversal {
 public:
   SCEVTraversal(SV &V) : Visitor(V) {}
 
-  void visitAll(const SCEV *Root) {
+  void visitAll(SCEVUse Root) {
     push(Root);
     while (!Worklist.empty() && !Visitor.isDone()) {
-      const SCEV *S = Worklist.pop_back_val();
+      SCEVUse S = Worklist.pop_back_val();
 
       switch (S->getSCEVType()) {
       case scConstant:
@@ -695,7 +719,7 @@ template <typename SV> class SCEVTraversal {
       case scUMinExpr:
       case scSequentialUMinExpr:
       case scAddRecExpr:
-        for (const auto *Op : S->operands()) {
+        for (const auto Op : S->operands()) {
           push(Op);
           if (Visitor.isDone())
             break;
@@ -710,21 +734,20 @@ template <typename SV> class SCEVTraversal {
 };
 
 /// Use SCEVTraversal to visit all nodes in the given expression tree.
-template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
+template <typename SV> void visitAll(SCEVUse Root, SV &Visitor) {
   SCEVTraversal<SV> T(Visitor);
   T.visitAll(Root);
 }
 
 /// Return true if any node in \p Root satisfies the predicate \p Pred.
-template <typename PredTy>
-bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
+template <typename PredTy> bool SCEVExprContains(SCEVUse Root, PredTy Pred) {
   struct FindClosure {
     bool Found = false;
     PredTy Pred;
 
     FindClosure(PredTy Pred) : Pred(Pred) {}
 
-    bool follow(const SCEV *S) {
+    bool follow(SCEVUse S) {
       if (!Pred(S))
         return true;
 
@@ -744,7 +767,7 @@ bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
 /// The result from each visit is cached, so it will return the same
 /// SCEV for the same input.
 template <typename SC>
-class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
+class SCEVRewriteVisitor : public SCEVVisitor<SC, SCEVUse> {
 protected:
   ScalarEvolution &SE;
   // Memoize the result of each visit so that we only compute once for
@@ -752,84 +775,84 @@ class SCEVRewriteVisitor : public SCEVVisitor<SC, SCEVUse> {
   // a SCEV is referenced by multiple SCEVs. Without memoization, this
   // visit algorithm would have exponential time complexity in the worst
   // case, causing the compiler to hang on certain tests.
-  SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
+  SmallDenseMap<SCEVUse, SCEVUse> RewriteResults;
 
 public:
   SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
 
-  const SCEV *visit(const SCEV *S) {
+  SCEVUse visit(SCEVUse S) {
     auto It = RewriteResults.find(S);
     if (It != RewriteResults.end())
       return It->second;
-    auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
+    auto Visited = SCEVVisitor<SC, SCEVUse>::visit(S);
     auto Result = RewriteResults.try_emplace(S, Visited);
     assert(Result.second && "Should insert a new entry");
     return Result.first->second;
   }
 
-  const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
+  SCEVUse visitConstant(const SCEVConstant *Constant) { return Constant; }
 
-  const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
+  SCEVUse visitVScale(const SCEVVScale *VScale) { return VScale; }
 
-  const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
-    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
+  SCEVUse visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
+    SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand());
     return Operand == Expr->getOperand() ?
Expr : SE.getPtrToIntExpr(Operand, Expr->getType()); } - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitTruncateExpr(const SCEVTruncateExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getTruncateExpr(Operand, Expr->getType()); } - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getZeroExtendExpr(Operand, Expr->getType()); } - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getSignExtendExpr(Operand, Expr->getType()); } - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getAddExpr(Operands); } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getMulExpr(Operands); } - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - auto *LHS = ((SC *)this)->visit(Expr->getLHS()); - auto *RHS = ((SC *)this)->visit(Expr->getRHS()); + SCEVUse visitUDivExpr(const SCEVUDivExpr *Expr) { + auto LHS = ((SC *)this)->visit(Expr->getLHS()); + auto RHS = ((SC *)this)->visit(Expr->getRHS()); bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } @@ -838,72 +861,70 @@ class SCEVRewriteVisitor : public SCEVVisitor { Expr->getNoWrapFlags()); } - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - SmallVector Operands; + SCEVUse visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getSMaxExpr(Operands); } - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - SmallVector Operands; + SCEVUse visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? 
               Expr : SE.getUMaxExpr(Operands);
   }
 
-  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
+  SCEVUse visitSMinExpr(const SCEVSMinExpr *Expr) {
+    SmallVector<SCEVUse, 2> Operands;
     bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
+    for (const auto Op : Expr->operands()) {
       Operands.push_back(((SC *)this)->visit(Op));
       Changed |= Op != Operands.back();
     }
     return !Changed ? Expr : SE.getSMinExpr(Operands);
   }
 
-  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
+  SCEVUse visitUMinExpr(const SCEVUMinExpr *Expr) {
+    SmallVector<SCEVUse, 2> Operands;
     bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
+    for (const auto Op : Expr->operands()) {
       Operands.push_back(((SC *)this)->visit(Op));
       Changed |= Op != Operands.back();
     }
     return !Changed ? Expr : SE.getUMinExpr(Operands);
   }
 
-  const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
+  SCEVUse visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
+    SmallVector<SCEVUse, 2> Operands;
     bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
+    for (const auto Op : Expr->operands()) {
      Operands.push_back(((SC *)this)->visit(Op));
       Changed |= Op != Operands.back();
     }
     return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) { return Expr; }
 
-  const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
-    return Expr;
-  }
+  SCEVUse visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
 };
 
 using ValueToValueMap = DenseMap<const Value *, Value *>;
-using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
+using ValueToSCEVMapTy = DenseMap<const Value *, SCEVUse>;
 
 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
 /// the SCEVUnknown components following the Map (Value -> SCEV).
 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
-                             ValueToSCEVMapTy &Map) {
+  static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE,
+                         ValueToSCEVMapTy &Map) {
    SCEVParameterRewriter Rewriter(SE, Map);
     return Rewriter.visit(Scev);
   }
@@ -911,7 +932,7 @@ class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
   SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
       : SCEVRewriteVisitor(SE), Map(M) {}
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
     auto I = Map.find(Expr->getValue());
     if (I == Map.end())
       return Expr;
@@ -922,7 +943,7 @@ class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
   ValueToSCEVMapTy &Map;
 };
 
-using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
+using LoopToScevMapT = DenseMap<const Loop *, SCEVUse>;
 
 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
 /// the Map (Loop -> SCEV) to all AddRecExprs.
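// --------------------------------------------------------------------------
// Usage sketch for SCEVParameterRewriter above (not from the patch): replace
// one SCEVUnknown parameter with a constant and let the rewriter reconstruct
// the surrounding expression. `specializeParameter` is a hypothetical helper
// and assumes Param has integer type.
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Value.h"
using namespace llvm;

static const SCEV *specializeParameter(ScalarEvolution &SE, const SCEV *Expr,
                                       Value *Param, uint64_t Val) {
  ValueToSCEVMapTy Map;
  Map[Param] = SE.getConstant(Param->getType(), Val);
  // Unknowns not present in Map are returned unchanged (see visitUnknown).
  return SCEVParameterRewriter::rewrite(Expr, SE, Map);
}
// --------------------------------------------------------------------------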
@@ -932,15 +953,15 @@ class SCEVLoopAddRecRewriter SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) : SCEVRewriteVisitor(SE), Map(M) {} - static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, - ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse Scev, LoopToScevMapT &Map, + ScalarEvolution &SE) { SCEVLoopAddRecRewriter Rewriter(SE, Map); return Rewriter.visit(Scev); } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - SmallVector Operands; - for (const SCEV *Op : Expr->operands()) + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (SCEVUse Op : Expr->operands()) Operands.push_back(visit(Op)); const Loop *L = Expr->getLoop(); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 21d2ef3c867d7..909d6910402a8 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -13,6 +13,7 @@ #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" namespace llvm { @@ -23,6 +24,10 @@ bool match(const SCEV *S, const Pattern &P) { return P.match(S); } +template bool match(SCEVUse U, const Pattern &P) { + return match(U.getPointer(), P); +} + template struct cst_pred_ty : public Predicate { bool match(const SCEV *S) { assert((isa(S) || !S->getType()->isVectorTy()) && diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index a4a98ea0bae14..ae2573f341d86 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -1250,10 +1250,12 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound); LLVM_DEBUG(dbgs() << ", " << *UpperBound->getType() << "\n"); - const SCEV *AbsDelta = - SE->isKnownNonNegative(Delta) ? Delta : SE->getNegativeSCEV(Delta); - const SCEV *AbsCoeff = - SE->isKnownNonNegative(Coeff) ? Coeff : SE->getNegativeSCEV(Coeff); + const SCEV *AbsDelta = SE->isKnownNonNegative(Delta) + ? Delta + : SE->getNegativeSCEV(Delta).getPointer(); + const SCEV *AbsCoeff = SE->isKnownNonNegative(Coeff) + ? Coeff + : SE->getNegativeSCEV(Coeff).getPointer(); const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff); if (isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product)) { // Distance greater than trip count - no dependence @@ -1791,8 +1793,9 @@ bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff, const SCEV *AbsCoeff = SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; - const SCEV *NewDelta = - SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta; + const SCEV *NewDelta = SE->isKnownNegative(ConstCoeff) + ? SE->getNegativeSCEV(Delta).getPointer() + : Delta; // check that Delta/SrcCoeff < iteration count // really check NewDelta < count*AbsCoeff @@ -1900,8 +1903,9 @@ bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff, const SCEV *AbsCoeff = SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; - const SCEV *NewDelta = - SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta; + const SCEV *NewDelta = SE->isKnownNegative(ConstCoeff) + ? 
SE->getNegativeSCEV(Delta).getPointer() + : Delta; // check that Delta/SrcCoeff < iteration count // really check NewDelta < count*AbsCoeff diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp index 76a78d5229652..b620a9fca7ca8 100644 --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -1520,7 +1520,7 @@ bool InductionDescriptor::isInductionPHI( return false; // Check that the PHI is consecutive. - const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi); + const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi).getPointer(); const SCEVAddRecExpr *AR = dyn_cast(PhiScev); if (!AR) { diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 2897b922f61e4..460f7730ccad8 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -500,7 +500,8 @@ bool IndexedReference::isConsecutive(const Loop &L, const SCEV *&Stride, SE.getNoopOrSignExtend(ElemSize, WiderType)); const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS); - Stride = SE.isKnownNegative(Stride) ? SE.getNegativeSCEV(Stride) : Stride; + Stride = SE.isKnownNegative(Stride) ? SE.getNegativeSCEV(Stride).getPointer() + : Stride; return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize); } diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e18133971f5bf..b497cf59090ac 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -259,6 +259,58 @@ static cl::opt UseContextForNoWrapFlagInference( // SCEV class definitions //===----------------------------------------------------------------------===// +class SCEVDropFlags : public SCEVRewriteVisitor { + using Base = SCEVRewriteVisitor; + +public: + SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} + + static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) { + SCEVDropFlags Rewriter(SE); + return Rewriter.visit(Scev); + } + + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (const auto Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); + } + + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (const auto Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); + } +}; + +const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE, const SCEV *S) { + return SCEVDropFlags::rewrite(S, SE); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEVUse::dump() const { + print(dbgs()); + dbgs() << '\n'; +} +#endif + +void SCEVUse::print(raw_ostream &OS) const { + getPointer()->print(OS); + SCEV::NoWrapFlags Flags = static_cast(getInt()); + if (Flags & SCEV::FlagNUW) + OS << "(u nuw)"; + if (Flags & SCEV::FlagNSW) + OS << "(u nsw)"; +} + //===----------------------------------------------------------------------===// // Implementation of the SCEV class. 
// @@ -280,28 +332,28 @@ void SCEV::print(raw_ostream &OS) const { return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast(this); - const SCEV *Op = PtrToInt->getOperand(); - OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to " + SCEVUse Op = PtrToInt->getOperand(); + OS << "(ptrtoint " << *Op->getType() << " " << Op << " to " << *PtrToInt->getType() << ")"; return; } case scTruncate: { const SCEVTruncateExpr *Trunc = cast(this); - const SCEV *Op = Trunc->getOperand(); - OS << "(trunc " << *Op->getType() << " " << *Op << " to " + SCEVUse Op = Trunc->getOperand(); + OS << "(trunc " << *Op->getType() << " " << Op << " to " << *Trunc->getType() << ")"; return; } case scZeroExtend: { const SCEVZeroExtendExpr *ZExt = cast(this); - const SCEV *Op = ZExt->getOperand(); - OS << "(zext " << *Op->getType() << " " << *Op << " to " - << *ZExt->getType() << ")"; + SCEVUse Op = ZExt->getOperand(); + OS << "(zext " << *Op->getType() << " " << Op << " to " << *ZExt->getType() + << ")"; return; } case scSignExtend: { const SCEVSignExtendExpr *SExt = cast(this); - const SCEV *Op = SExt->getOperand(); + SCEVUse Op = SExt->getOperand(); OS << "(sext " << *Op->getType() << " " << *Op << " to " << *SExt->getType() << ")"; return; @@ -351,8 +403,8 @@ void SCEV::print(raw_ostream &OS) const { } OS << "("; ListSeparator LS(OpStr); - for (const SCEV *Op : NAry->operands()) - OS << LS << *Op; + for (SCEVUse Op : NAry->operands()) + OS << LS << Op; OS << ")"; switch (NAry->getSCEVType()) { case scAddExpr: @@ -417,7 +469,7 @@ Type *SCEV::getType() const { llvm_unreachable("Unknown SCEV kind!"); } -ArrayRef SCEV::operands() const { +ArrayRef SCEV::operands() const { switch (getSCEVType()) { case scConstant: case scVScale: @@ -470,51 +522,53 @@ bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } -const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { +SCEVUse ScalarEvolution::getConstant(ConstantInt *V) { FoldingSetNodeID ID; ID.AddInteger(scConstant); ID.AddPointer(V); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); UniqueSCEVs.InsertNode(S, IP); + S->setCanonical(S); return S; } -const SCEV *ScalarEvolution::getConstant(const APInt &Val) { +SCEVUse ScalarEvolution::getConstant(const APInt &Val) { return getConstant(ConstantInt::get(getContext(), Val)); } -const SCEV * -ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { +SCEVUse ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); return getConstant(ConstantInt::get(ITy, V, isSigned)); } -const SCEV *ScalarEvolution::getVScale(Type *Ty) { +SCEVUse ScalarEvolution::getVScale(Type *Ty) { FoldingSetNodeID ID; ID.AddInteger(scVScale); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); UniqueSCEVs.InsertNode(S, IP); + S->setCanonical(S); return S; } -const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) { - const SCEV *Res = getConstant(Ty, EC.getKnownMinValue()); +SCEVUse ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) { + SCEVUse Res = getConstant(Ty, EC.getKnownMinValue()); if (EC.isScalable()) Res = getMulExpr(Res, 
getVScale(Ty)); return Res; } SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, - const SCEV *op, Type *ty) + SCEVUse op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} -SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, +SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op, Type *ITy) : SCEVCastExpr(ID, scPtrToInt, Op, ITy) { assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() && @@ -522,26 +576,26 @@ SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, } SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, - SCEVTypes SCEVTy, const SCEV *op, + SCEVTypes SCEVTy, SCEVUse op, Type *ty) : SCEVCastExpr(ID, SCEVTy, op, ty) {} -SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty) : SCEVIntegralCastExpr(ID, scTruncate, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate non-integer value!"); } -SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, Type *ty) +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, + Type *ty) : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot zero extend non-integer value!"); } -SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, Type *ty) +SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, + Type *ty) : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot sign extend non-integer value!"); @@ -549,7 +603,7 @@ SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, void SCEVUnknown::deleted() { // Clear this SCEVUnknown from various maps. - SE->forgetMemoizedResults(this); + SE->forgetMemoizedResults(SCEVUse(this)); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); @@ -560,7 +614,7 @@ void SCEVUnknown::deleted() { void SCEVUnknown::allUsesReplacedWith(Value *New) { // Clear this SCEVUnknown from various maps. - SE->forgetMemoizedResults(this); + SE->forgetMemoizedResults(SCEVUse(this)); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); @@ -653,9 +707,10 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, // If the max analysis depth was reached, return std::nullopt, assuming we do // not know if they are equivalent for sure. static std::optional -CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, - const LoopInfo *const LI, const SCEV *LHS, - const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) { +CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, + const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS, + DominatorTree &DT, ScalarEvolution &SE, + unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. if (LHS == RHS) return 0; @@ -738,8 +793,8 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, case scSMinExpr: case scUMinExpr: case scSequentialUMinExpr: { - ArrayRef LOps = LHS->operands(); - ArrayRef ROps = RHS->operands(); + ArrayRef LOps = LHS->operands(); + ArrayRef ROps = RHS->operands(); // Lexicographically compare n-ary-like expressions. 
unsigned LNumOps = LOps.size(), RNumOps = ROps.size(); @@ -747,7 +802,7 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, return (int)LNumOps - (int)RNumOps; for (unsigned i = 0; i != LNumOps; ++i) { - auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, + auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE, Depth + 1); if (X != 0) return X; @@ -771,37 +826,36 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, /// results from this routine. In other words, we don't want the results of /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. -static void GroupByComplexity(SmallVectorImpl &Ops, - LoopInfo *LI, DominatorTree &DT) { +static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI, + DominatorTree &DT, ScalarEvolution &SE) { if (Ops.size() < 2) return; // Noop - EquivalenceClasses EqCacheSCEV; + EquivalenceClasses EqCacheSCEV; // Whether LHS has provably less complexity than RHS. - auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) { - auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT); + auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) { + auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE); return Complexity && *Complexity < 0; }; if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; + SCEVUse &LHS = Ops[0], &RHS = Ops[1]; if (IsLessComplex(RHS, LHS)) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. - llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) { - return IsLessComplex(LHS, RHS); - }); + llvm::stable_sort( + Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); }); // Now that we are sorted by complexity, group elements of the same // complexity. Note that this is, at worst, N^2, but the vector is likely to // be extremely short in practice. Note that we take this approach because we // do not want to depend on the addresses of the objects we are grouping. for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) { - const SCEV *S = Ops[i]; + SCEVUse S = Ops[i]; unsigned Complexity = S->getSCEVType(); // If there are any objects of the same complexity and same value as this @@ -819,8 +873,8 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at /// least HugeExprThreshold nodes). 
-static bool hasHugeExpression(ArrayRef Ops) { - return any_of(Ops, [](const SCEV *S) { +static bool hasHugeExpression(ArrayRef Ops) { + return any_of(Ops, [](SCEVUse S) { return S->getExpressionSize() >= HugeExprThreshold; }); } @@ -836,7 +890,7 @@ static bool hasHugeExpression(ArrayRef Ops) { template static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, - SmallVectorImpl &Ops, FoldT Fold, + SmallVectorImpl &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) { const SCEVConstant *Folded = nullptr; for (unsigned Idx = 0; Idx < Ops.size();) { @@ -861,7 +915,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, if (Folded && IsAbsorber(Folded->getAPInt())) return Folded; - GroupByComplexity(Ops, &LI, DT); + GroupByComplexity(Ops, &LI, DT, SE); if (Folded && !IsIdentity(Folded->getAPInt())) Ops.insert(Ops.begin(), Folded); @@ -873,9 +927,8 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, //===----------------------------------------------------------------------===// /// Compute BC(It, K). The result has width W. Assume, K > 0. -static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, - ScalarEvolution &SE, - Type *ResultTy) { +static SCEVUse BinomialCoefficient(SCEVUse It, unsigned K, ScalarEvolution &SE, + Type *ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); @@ -962,15 +1015,15 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, // Calculate the product, at width T+W IntegerType *CalculationTy = IntegerType::get(SE.getContext(), CalculationBits); - const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + SCEVUse Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { - const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); + SCEVUse S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } // Divide by 2^T - const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + SCEVUse DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); // Truncate the result, and divide by K! / 2^T. @@ -986,21 +1039,20 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// /// where BC(It, k) stands for binomial coefficient. -const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, - ScalarEvolution &SE) const { +SCEVUse SCEVAddRecExpr::evaluateAtIteration(SCEVUse It, + ScalarEvolution &SE) const { return evaluateAtIteration(operands(), It, SE); } -const SCEV * -SCEVAddRecExpr::evaluateAtIteration(ArrayRef Operands, - const SCEV *It, ScalarEvolution &SE) { +SCEVUse SCEVAddRecExpr::evaluateAtIteration(ArrayRef Operands, + SCEVUse It, ScalarEvolution &SE) { assert(Operands.size() > 0); - const SCEV *Result = Operands[0]; + SCEVUse Result = Operands[0]; for (unsigned i = 1, e = Operands.size(); i != e; ++i) { // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. 
- const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType()); + SCEVUse Coeff = BinomialCoefficient(It, i, SE, Result->getType()); if (isa(Coeff)) return Coeff; @@ -1013,8 +1065,7 @@ SCEVAddRecExpr::evaluateAtIteration(ArrayRef Operands, // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, - unsigned Depth) { +SCEVUse ScalarEvolution::getLosslessPtrToIntExpr(SCEVUse Op, unsigned Depth) { assert(Depth <= 1 && "getLosslessPtrToIntExpr() should self-recurse at most once."); @@ -1026,12 +1077,12 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, // What would be an ID for such a SCEV cast expression? FoldingSetNodeID ID; ID.AddInteger(scPtrToInt); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; // Is there already an expression for such a cast? - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // It isn't legal for optimizations to construct new ptrtoint expressions @@ -1065,6 +1116,10 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } @@ -1088,12 +1143,12 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, public: SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) { SCEVPtrToIntSinkingRewriter Rewriter(SE); return Rewriter.visit(Scev); } - const SCEV *visit(const SCEV *S) { + SCEVUse visit(SCEVUse S) { Type *STy = S->getType(); // If the expression is not pointer-typed, just keep it as-is. if (!STy->isPointerTy()) @@ -1102,27 +1157,27 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, return Base::visit(S); } - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { assert(Expr->getType()->isPointerTy() && "Should only reach pointer-typed SCEVUnknown's."); return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1); @@ -1130,25 +1185,24 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, }; // And actually perform the cast sinking. 
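  // Illustrative trace of the rewriter above (names hypothetical, assuming
  // 64-bit pointers): for a pointer-typed SCEV (8 + %p), the cast is
  // distributed over the add rather than applied on top of it:
  //   visit((8 + %p)) -> visitAddExpr
  //   visit(8)        -> returned as-is (already integer-typed)
  //   visit(%p)       -> visitUnknown -> getLosslessPtrToIntExpr(%p, /*Depth=*/1)
  //   result          -> (8 + (ptrtoint ptr %p to i64))
  // so only the SCEVUnknown leaf ever needs a lossless ptrtoint node.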
- const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); + SCEVUse IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); assert(IntOp->getType()->isIntegerTy() && "We must have succeeded in sinking the cast, " "and ending up with an integer-typed expression!"); return IntOp; } -const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) { +SCEVUse ScalarEvolution::getPtrToIntExpr(SCEVUse Op, Type *Ty) { assert(Ty->isIntegerTy() && "Target type must be an integer type!"); - const SCEV *IntOp = getLosslessPtrToIntExpr(Op); + SCEVUse IntOp = getLosslessPtrToIntExpr(Op); if (isa(IntOp)) return IntOp; return getTruncateOrZeroExtend(IntOp, Ty); } -const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getTruncateExpr(SCEVUse Op, Type *Ty, unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); assert(isSCEVable(Ty) && @@ -1158,10 +1212,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, FoldingSetNodeID ID; ID.AddInteger(scTruncate); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) @@ -1185,6 +1240,10 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } @@ -1194,11 +1253,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, // that replace other casts. if (isa(Op) || isa(Op)) { auto *CommOp = cast(Op); - SmallVector Operands; + SmallVector Operands; unsigned numTruncs = 0; for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2; ++i) { - const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1); + SCEVUse S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1); if (!isa(CommOp->getOperand(i)) && isa(S)) numTruncs++; @@ -1214,14 +1273,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, // Although we checked in the beginning that ID is not in the cache, it is // possible that during recursion and different modification ID was inserted // into the cache. So if we find it, just return it. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; } // If the input value is a chrec scev, truncate the chrec's operands. 
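  // For instance (illustrative): truncating the i64 recurrence {0,+,4}<%loop>
  // to i32 truncates each operand, yielding the i32 recurrence {0,+,4}<%loop>.
  // The result is rebuilt with SCEV::FlagAnyWrap below, since truncation does
  // not in general preserve the original recurrence's no-wrap flags.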
if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { - SmallVector Operands; - for (const SCEV *Op : AddRec->operands()) + SmallVector Operands; + for (SCEVUse Op : AddRec->operands()) Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1)); return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } @@ -1238,15 +1297,19 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } // Get the limit of a recurrence such that incrementing by Step cannot cause // signed overflow as long as the value of the recurrence within the // loop does not exceed this limit before incrementing. -static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { +static SCEVUse getSignedOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); if (SE->isKnownPositive(Step)) { *Pred = ICmpInst::ICMP_SLT; @@ -1264,9 +1327,9 @@ static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, // Get the limit of a recurrence such that incrementing by Step cannot cause // unsigned overflow as long as the value of the recurrence within the loop does // not exceed this limit before incrementing. -static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { +static SCEVUse getUnsignedOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); *Pred = ICmpInst::ICMP_ULT; @@ -1277,8 +1340,8 @@ static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, namespace { struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *, - unsigned); + typedef SCEVUse (ScalarEvolution::*GetExtendExprTy)(SCEVUse, Type *, + unsigned); }; // Used to make code generic over signed and unsigned overflow. 
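  // Worked instance of getSignedOverflowLimitForStep above (illustrative):
  // for an i8 step with getSignedRangeMax(Step) == 3, the limit is
  // SINT8_MAX - 3 == 124 with Pred == ICMP_SLT; while the recurrence is
  // slt 124 before an increment, adding at most 3 stays <= 127 and cannot
  // sign-wrap. The unsigned variant is analogous with ICMP_ULT.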
@@ -1289,7 +1352,7 @@ template struct ExtendOpTraits { // // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; // - // static const SCEV *getOverflowLimitForStep(const SCEV *Step, + // static SCEVUse getOverflowLimitForStep(SCEVUse Step, // ICmpInst::Predicate *Pred, // ScalarEvolution *SE); }; @@ -1300,9 +1363,9 @@ struct ExtendOpTraits : public ExtendOpTraitsBase { static const GetExtendExprTy GetExtendExpr; - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { + static SCEVUse getOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { return getSignedOverflowLimitForStep(Step, Pred, SE); } }; @@ -1316,9 +1379,9 @@ struct ExtendOpTraits : public ExtendOpTraitsBase { static const GetExtendExprTy GetExtendExpr; - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { + static SCEVUse getOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { return getUnsignedOverflowLimitForStep(Step, Pred, SE); } }; @@ -1336,14 +1399,14 @@ const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< // expression "Step + sext/zext(PreIncAR)" is congruent with // "sext/zext(PostIncAR)" template -static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, unsigned Depth) { +static SCEVUse getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE, unsigned Depth) { auto WrapType = ExtendOpTraits::WrapType; auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*SE); // Check for a simple looking step prior to loop entry. const SCEVAddExpr *SA = dyn_cast(Start); @@ -1354,7 +1417,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // subtraction is expensive. For this purpose, perform a quick and dirty // difference, by checking for Step in the operand list. Note, that // SA might have repeated ops, like %a + %a + ..., so only remove one. - SmallVector DiffOps(SA->operands()); + SmallVector DiffOps(SA->operands()); for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It) if (*It == Step) { DiffOps.erase(It); @@ -1370,7 +1433,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 1. NSW/NUW flags on the step increment. auto PreStartFlags = ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); - const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); + SCEVUse PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); @@ -1378,7 +1441,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // "S+X does not sign/unsign-overflow". // - const SCEV *BECount = SE->getBackedgeTakenCount(L); + SCEVUse BECount = SE->getBackedgeTakenCount(L); if (PreAR && PreAR->getNoWrapFlags(WrapType) && !isa(BECount) && SE->isKnownPositive(BECount)) return PreStart; @@ -1386,7 +1449,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 2. Direct overflow check on the step operation's expression. 
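  // Shape of check (2), illustratively for i8 operands widened to i16:
  //   OperandExtendedStart = ext(PreStart) + ext(Step)   // wide math
  //   versus                 ext(Start)                  // ext of narrow add
  // If both fold to the same uniqued SCEV, the narrow addition
  // PreStart + Step provably did not wrap, which is exactly the fact needed
  // to treat PreStart as the pre-increment start value.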
unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = + SCEVUse OperandExtendedStart = SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), (SE->*GetExtendExpr)(Step, WideTy, Depth)); if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { @@ -1401,7 +1464,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 3. Loop precondition. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = + SCEVUse OverflowLimit = ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); if (OverflowLimit && @@ -1413,12 +1476,11 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // Get the normalized zero or sign extended expression for this AddRec's Start. template -static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, - unsigned Depth) { +static SCEVUse getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE, unsigned Depth) { auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; - const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE, Depth); + SCEVUse PreStart = getPreStartForExtend(AR, Ty, SE, Depth); if (!PreStart) return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); @@ -1460,8 +1522,7 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T // is `Delta` (defined below). template -bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, - const SCEV *Step, +bool ScalarEvolution::proveNoWrapByVaryingStart(SCEVUse Start, SCEVUse Step, const Loop *L) { auto WrapType = ExtendOpTraits::WrapType; @@ -1476,12 +1537,12 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, APInt StartAI = StartC->getAPInt(); for (unsigned Delta : {-2, -1, 1, 2}) { - const SCEV *PreStart = getConstant(StartAI - Delta); + SCEVUse PreStart = getConstant(StartAI - Delta); FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); - ID.AddPointer(PreStart); - ID.AddPointer(Step); + ID.AddPointer(PreStart.getRawPointer()); + ID.AddPointer(Step.getRawPointer()); ID.AddPointer(L); void *IP = nullptr; const auto *PreAR = @@ -1490,9 +1551,9 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) - const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + SCEVUse DeltaS = getConstant(StartC->getType(), Delta); ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - const SCEV *Limit = ExtendOpTraits::getOverflowLimitForStep( + SCEVUse Limit = ExtendOpTraits::getOverflowLimitForStep( DeltaS, &Pred, this); if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) return true; @@ -1530,7 +1591,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count. 
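  // Worked example (illustrative): with i8 constants, C = 0b10110110 and a
  // step known to be a multiple of 8 (TZ = 3), keeping C's low three bits
  // gives D = 0b00000110. Every term of (C - D) + n * Step is then a
  // multiple of 8 while D <= 7, so adding D can never carry past the bit
  // width: the outer add is no-wrap for free, and ext(C + n * Step) can be
  // split as ext(D) + ext((C - D) + n * Step).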
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const APInt &ConstantStart, - const SCEV *Step) { + SCEVUse Step) { const unsigned BitWidth = ConstantStart.getBitWidth(); const uint32_t TZ = SE.getMinTrailingZeros(Step); if (TZ) @@ -1540,10 +1601,9 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, } static void insertFoldCacheEntry( - const ScalarEvolution::FoldID &ID, const SCEV *S, - DenseMap &FoldCache, - DenseMap> - &FoldCacheUser) { + const ScalarEvolution::FoldID &ID, SCEVUse S, + DenseMap &FoldCache, + DenseMap> &FoldCacheUser) { auto I = FoldCache.insert({ID, S}); if (!I.second) { // Remove FoldCacheUser entry for ID when replacing an existing FoldCache @@ -1561,8 +1621,8 @@ static void insertFoldCacheEntry( FoldCacheUser[S].push_back(ID); } -const SCEV * -ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +SCEVUse ScalarEvolution::getZeroExtendExpr(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1575,14 +1635,14 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (Iter != FoldCache.end()) return Iter->second; - const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth); + SCEVUse S = getZeroExtendExprImpl(Op, Ty, Depth); if (!isa(S)) insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); return S; } -const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getZeroExtendExprImpl(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); @@ -1600,15 +1660,20 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; ID.AddInteger(scZeroExtend); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; if (Depth > MaxCastDepth) { SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } @@ -1616,7 +1681,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { // It's possible the bits taken off by the truncate were all zero bits. If // so, we should be able to simplify this further. 
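  // Illustrative case: Op = (trunc i32 X to i8) being zero-extended back to
  // i32. If getUnsignedRange(X) is, say, [0, 200), the truncate discarded
  // only zero bits, so zext(trunc(X)) folds to X and both casts disappear.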
- const SCEV *X = ST->getOperand(); + SCEVUse X = ST->getOperand(); ConstantRange CR = getUnsignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); @@ -1631,8 +1696,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); @@ -1653,34 +1718,33 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for overflow. // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - const SCEV *CastedMaxBECount = + SCEVUse CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth); - const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend( + SCEVUse RecastedMaxBECount = getTruncateOrZeroExtend( CastedMaxBECount, MaxBECount->getType(), Depth); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. - const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step, - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul, - SCEV::FlagAnyWrap, - Depth + 1), - WideTy, Depth + 1); - const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); - const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse ZMul = + getMulExpr(CastedMaxBECount, Step, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse ZAdd = getZeroExtendExpr( + getAddExpr(Start, ZMul, SCEV::FlagAnyWrap, Depth + 1), WideTy, + Depth + 1); + SCEVUse WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); + SCEVUse WideMaxBECount = + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); + SCEVUse OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getZeroExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. setNoWrapFlags(const_cast(AR), SCEV::FlagNUW); @@ -1737,8 +1801,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // For a negative step, we can extend the operands iff doing so only // traverses values in the range zext([0,UINT_MAX]). 
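  // Worked instance (illustrative): an i8 recurrence with Step == -1 gives
  //   N = UINT8_MAX - getSignedRangeMin(Step) = 255 - (-1) = 0 (mod 2^8),
  // so the guard below requires AR u> 0 on every iteration: a down-counter
  // that is never 0 before decrementing cannot step below 0, and its zext
  // can then be taken operand-wise (with FlagNW cached on the AddRec).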
if (isKnownNegative(Step)) { - const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - - getSignedRangeMin(Step)); + SCEVUse N = getConstant(APInt::getMaxValue(BitWidth) - + getSignedRangeMin(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) { // Cache knowledge of AR NW, which is propagated to this @@ -1761,10 +1825,10 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, const APInt &C = SC->getAPInt(); const APInt &D = extractConstantWithoutWrapping(*this, C, Step); if (D != 0) { - const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); - const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SZExtD, SZExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -1782,8 +1846,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(A % B) --> zext(A) % zext(B) { - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; if (matchURem(Op, LHS, RHS)) return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1), getZeroExtendExpr(RHS, Ty, Depth + 1)); @@ -1799,8 +1863,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (SA->hasNoUnsignedWrap()) { // If the addition does not unsign overflow then we can, by definition, // commute the zero extension with the addition operation. - SmallVector Ops; - for (const auto *Op : SA->operands()) + SmallVector Ops; + for (const auto Op : SA->operands()) Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1); } @@ -1816,10 +1880,10 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (const auto *SC = dyn_cast(SA->getOperand(0))) { const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); if (D != 0) { - const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); - const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SZExtD, SZExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -1832,8 +1896,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (SM->hasNoUnsignedWrap()) { // If the multiply does not unsign overflow then we can, by definition, // commute the zero extension with the multiply operation. 
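  // E.g. (illustrative): zext i8 (X * nuw Y) to i16 == (zext X) * (zext Y)
  // computed in i16; nuw means the narrow product already equals the wide
  // product, so the extension commutes operand-wise, mirroring the add case
  // handled above.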
- SmallVector Ops; - for (const auto *Op : SM->operands()) + SmallVector Ops; + for (const auto Op : SM->operands()) Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1); } @@ -1869,8 +1933,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(umax(x, y)) -> umax(zext(x), zext(y)) if (isa(Op) || isa(Op)) { auto *MinMax = cast(Op); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getZeroExtendExpr(Operand, Ty)); if (isa(MinMax)) return getUMinExpr(Operands); @@ -1880,24 +1944,29 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y)) if (auto *MinMax = dyn_cast(Op)) { assert(isa(MinMax) && "Not supported!"); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getZeroExtendExpr(Operand, Ty)); return getUMinExpr(Operands, /*Sequential*/ true); } // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } -const SCEV * -ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +SCEVUse ScalarEvolution::getSignExtendExpr(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1910,14 +1979,14 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (Iter != FoldCache.end()) return Iter->second; - const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth); + SCEVUse S = getSignExtendExprImpl(Op, Ty, Depth); if (!isa(S)) insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); return S; } -const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getSignExtendExprImpl(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); @@ -1940,16 +2009,21 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; ID.AddInteger(scSignExtend); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; // Limit recursion depth. 
if (Depth > MaxCastDepth) { SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, Op); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } @@ -1957,7 +2031,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { // It's possible the bits taken off by the truncate were all sign bits. If // so, we should be able to simplify this further. - const SCEV *X = ST->getOperand(); + SCEVUse X = ST->getOperand(); ConstantRange CR = getSignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); @@ -1971,8 +2045,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (SA->hasNoSignedWrap()) { // If the addition does not sign overflow then we can, by definition, // commute the sign extension with the addition operation. - SmallVector Ops; - for (const auto *Op : SA->operands()) + SmallVector Ops; + for (const auto Op : SA->operands()) Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1)); return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1); } @@ -1989,10 +2063,10 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (const auto *SC = dyn_cast(SA->getOperand(0))) { const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); if (D != 0) { - const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); - const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SSExtD, SSExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -2005,8 +2079,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); @@ -2027,35 +2101,34 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - const SCEV *CastedMaxBECount = + SCEVUse CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth); - const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend( + SCEVUse RecastedMaxBECount = getTruncateOrZeroExtend( CastedMaxBECount, MaxBECount->getType(), Depth); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. 
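  // Shape of the check that follows (illustrative), computed in the 2x-wide
  // type WideTy:
  //   SAdd               = sext(Start + CastedMaxBECount * Step) // narrow add, then ext
  //   OperandExtendedAdd = sext(Start) + zext(CastedMaxBECount) * sext(Step)
  // If both fold to the same uniqued SCEV, no iteration can sign-wrap, the
  // AddRec is marked FlagNSW, and the sext distributes into
  // {sext(Start),+,sext(Step)}.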
- const SCEV *SMul = getMulExpr(CastedMaxBECount, Step, - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul, - SCEV::FlagAnyWrap, - Depth + 1), - WideTy, Depth + 1); - const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1); - const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse SMul = + getMulExpr(CastedMaxBECount, Step, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse SAdd = getSignExtendExpr( + getAddExpr(Start, SMul, SCEV::FlagAnyWrap, Depth + 1), WideTy, + Depth + 1); + SCEVUse WideStart = getSignExtendExpr(Start, WideTy, Depth + 1); + SCEVUse WideMaxBECount = + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); + SCEVUse OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getSignExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. setNoWrapFlags(const_cast(AR), SCEV::FlagNSW); @@ -2113,10 +2186,10 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, const APInt &C = SC->getAPInt(); const APInt &D = extractConstantWithoutWrapping(*this, C, Step); if (D != 0) { - const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); - const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SSExtD, SSExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -2141,8 +2214,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // sext(smax(x, y)) -> smax(sext(x), sext(y)) if (isa(Op) || isa(Op)) { auto *MinMax = cast(Op); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getSignExtendExpr(Operand, Ty)); if (isa(MinMax)) return getSMinExpr(Operands); @@ -2151,16 +2224,20 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); registerUser(S, { Op }); + if (Op.isCanonical()) + S->setCanonical(S); + else + S->setCanonical(SCEVUse::computeCanonical(*this, S)); return S; } -const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op, - Type *Ty) { +SCEVUse ScalarEvolution::getCastExpr(SCEVTypes Kind, SCEVUse Op, Type *Ty) { switch (Kind) { case scTruncate: return getTruncateExpr(Op, Ty); @@ -2177,8 +2254,7 @@ const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op, /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. 
-const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, - Type *Ty) { +SCEVUse ScalarEvolution::getAnyExtendExpr(SCEVUse Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -2192,26 +2268,26 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, // Peel off a truncate cast. if (const SCEVTruncateExpr *T = dyn_cast(Op)) { - const SCEV *NewOp = T->getOperand(); + SCEVUse NewOp = T->getOperand(); if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) return getAnyExtendExpr(NewOp, Ty); return getTruncateOrNoop(NewOp, Ty); } // Next try a zext cast. If the cast is folded, use it. - const SCEV *ZExt = getZeroExtendExpr(Op, Ty); + SCEVUse ZExt = getZeroExtendExpr(Op, Ty); if (!isa(ZExt)) return ZExt; // Next try a sext cast. If the cast is folded, use it. - const SCEV *SExt = getSignExtendExpr(Op, Ty); + SCEVUse SExt = getSignExtendExpr(Op, Ty); if (!isa(SExt)) return SExt; // Force the cast to be folded into the operands of an addrec. if (const SCEVAddRecExpr *AR = dyn_cast(Op)) { - SmallVector Ops; - for (const SCEV *Op : AR->operands()) + SmallVector Ops; + for (SCEVUse Op : AR->operands()) Ops.push_back(getAnyExtendExpr(Op, Ty)); return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); } @@ -2247,12 +2323,12 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, /// may be exposed. This helps getAddRecExpr short-circuit extra work in /// the common case where no interesting opportunities are present, and /// is also used as a check to avoid infinite recursion. -static bool -CollectAddOperandsWithScales(SmallDenseMap &M, - SmallVectorImpl &NewOps, - APInt &AccumulatedConstant, - ArrayRef Ops, const APInt &Scale, - ScalarEvolution &SE) { +static bool CollectAddOperandsWithScales(SmallDenseMap &M, + SmallVectorImpl &NewOps, + APInt &AccumulatedConstant, + ArrayRef Ops, + const APInt &Scale, + ScalarEvolution &SE) { bool Interesting = false; // Iterate over the add operands. They are sorted, with constants first. @@ -2281,8 +2357,8 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } else { // A multiplication of a constant with some other value. Update // the map. - SmallVector MulOps(drop_begin(Mul->operands())); - const SCEV *Key = SE.getMulExpr(MulOps); + SmallVector MulOps(drop_begin(Mul->operands())); + SCEVUse Key = SE.getMulExpr(MulOps); auto Pair = M.insert({Key, NewScale}); if (Pair.second) { NewOps.push_back(Pair.first->first); @@ -2295,7 +2371,7 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } } else { // An ordinary operand. Update the map. 
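  // Illustrative run of the whole collection: Ops = (7 + 2*X + 2*X) with
  // Scale = 1 yields AccumulatedConstant = 7 and M = { X -> 4 }, since the
  // repeated 2*X entries merge by adding their scales. The merge makes the
  // walk "interesting", so the caller re-generates the sum as 7 + 4*X.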
- std::pair::iterator, bool> Pair = + std::pair::iterator, bool> Pair = M.insert({Ops[i], Scale}); if (Pair.second) { NewOps.push_back(Pair.first->first); @@ -2312,10 +2388,10 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, - const SCEV *LHS, const SCEV *RHS, + SCEVUse LHS, SCEVUse RHS, const Instruction *CtxI) { - const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, - SCEV::NoWrapFlags, unsigned); + SCEVUse (ScalarEvolution::*Operation)(SCEVUse, SCEVUse, SCEV::NoWrapFlags, + unsigned); switch (BinOp) { default: llvm_unreachable("Unsupported binary op"); @@ -2330,7 +2406,7 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, break; } - const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) = + SCEVUse (ScalarEvolution::*Extension)(SCEVUse, Type *, unsigned) = Signed ? &ScalarEvolution::getSignExtendExpr : &ScalarEvolution::getZeroExtendExpr; @@ -2339,11 +2415,11 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, auto *WideTy = IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); - const SCEV *A = (this->*Extension)( + SCEVUse A = (this->*Extension)( (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0); - const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0); - const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0); - const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); + SCEVUse LHSB = (this->*Extension)(LHS, WideTy, 0); + SCEVUse RHSB = (this->*Extension)(RHS, WideTy, 0); + SCEVUse B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); if (A == B) return true; // Can we use context to prove the fact we need? @@ -2408,8 +2484,8 @@ ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp( OBO->getOpcode() != Instruction::Mul) return std::nullopt; - const SCEV *LHS = getSCEV(OBO->getOperand(0)); - const SCEV *RHS = getSCEV(OBO->getOperand(1)); + SCEVUse LHS = getSCEV(OBO->getOperand(0)); + SCEVUse RHS = getSCEV(OBO->getOperand(1)); const Instruction *CtxI = UseContextForNoWrapFlagInference ? dyn_cast(OBO) : nullptr; @@ -2435,10 +2511,10 @@ ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp( // We're trying to construct a SCEV of type `Type' with `Ops' as operands and // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of // can't-overflow flags for the operation if possible. -static SCEV::NoWrapFlags -StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, - const ArrayRef Ops, - SCEV::NoWrapFlags Flags) { +static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, + SCEVTypes Type, + const ArrayRef Ops, + SCEV::NoWrapFlags Flags) { using namespace std::placeholders; using OBO = OverflowingBinaryOperator; @@ -2453,7 +2529,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. 
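  // Rationale (illustrative): if every operand is known non-negative and the
  // add is nsw, the signed sum stays within [0, SINT_MAX], where signed and
  // unsigned arithmetic agree, so the same add can also be marked nuw.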
- auto IsKnownNonNegative = [&](const SCEV *S) { + auto IsKnownNonNegative = [&](SCEVUse S) { return SE->isKnownNonNegative(S); }; @@ -2518,14 +2594,20 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } -bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { +bool ScalarEvolution::isAvailableAtLoopEntry(SCEVUse S, const Loop *L) { return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader()); } +SCEVUse ScalarEvolution::getAddExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags, unsigned Depth) { + SmallVector Ops2(Ops.begin(), Ops.end()); + return getAddExpr(Ops2, Flags, Depth); +} + /// Get a canonical add expression, or something simpler if possible. -const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags OrigFlags, - unsigned Depth) { +SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags OrigFlags, + unsigned Depth) { assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2535,8 +2617,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVAddExpr operand types don't match!"); - unsigned NumPtrs = count_if( - Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); }); + unsigned NumPtrs = + count_if(Ops, [](SCEVUse Op) { return Op->getType()->isPointerTy(); }); assert(NumPtrs <= 1 && "add has at most one pointer operand"); #endif @@ -2551,7 +2633,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, unsigned Idx = isa(Ops[0]) ? 1 : 0; // Delay expensive flag strengthening until necessary. - auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { + auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags); }; @@ -2573,14 +2655,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, Type *Ty = Ops[0]->getType(); bool FoundMatch = false; for (unsigned i = 0, e = Ops.size(); i != e-1; ++i) - if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 + if (Ops[i] == Ops[i + 1]) { // X + Y + Y --> X + Y*2 // Scan ahead to count how many equal operands there are. unsigned Count = 2; while (i+Count != e && Ops[i+Count] == Ops[i]) ++Count; // Merge the values into a multiply. - const SCEV *Scale = getConstant(Ty, Count); - const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1); + SCEVUse Scale = getConstant(Ty, Count); + SCEVUse Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == Count) return Mul; Ops[i] = Mul; @@ -2603,14 +2685,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, if (auto *T = dyn_cast(Ops[Idx])) return T->getOperand()->getType(); if (const auto *Mul = dyn_cast(Ops[Idx])) { - const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1); + const auto LastOp = Mul->getOperand(Mul->getNumOperands() - 1); if (const auto *T = dyn_cast(LastOp)) return T->getOperand()->getType(); } return nullptr; }; if (auto *SrcType = FindTruncSrcType()) { - SmallVector LargeOps; + SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. 
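  // Illustrative target of this fold: rewrite
  //   (trunc i64 X to i32) + (trunc i64 Y to i32)
  // by evaluating X + Y once in i64 and truncating the simplified result,
  // provided every addend is a truncate from i64 (possibly inside a mul) or
  // a constant that extends to i64 losslessly.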
@@ -2624,7 +2706,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } else if (const SCEVConstant *C = dyn_cast(Op)) { LargeOps.push_back(getAnyExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast(Op)) { - SmallVector LargeMulOps; + SmallVector LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast(M->getOperand(j))) { @@ -2649,7 +2731,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, Ty); @@ -2660,8 +2742,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // Check if we have an expression of the form ((X + C1) - C2), where C1 and // C2 can be folded in a way that allows retaining wrapping flags of (X + // C1). - const SCEV *A = Ops[0]; - const SCEV *B = Ops[1]; + SCEVUse A = Ops[0]; + SCEVUse B = Ops[1]; auto *AddExpr = dyn_cast(B); auto *C = dyn_cast(A); if (AddExpr && C && isa(AddExpr->getOperand(0))) { @@ -2688,7 +2770,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } if (PreservedFlags != SCEV::FlagAnyWrap) { - SmallVector NewOps(AddExpr->operands()); + SmallVector NewOps(AddExpr->operands()); NewOps[0] = getConstant(ConstAdd); return getAddExpr(NewOps, PreservedFlags); } @@ -2700,8 +2782,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, const SCEVMulExpr *Mul = dyn_cast(Ops[0]); if (Mul && Mul->getNumOperands() == 2 && Mul->getOperand(0)->isAllOnesValue()) { - const SCEV *X; - const SCEV *Y; + SCEVUse X; + SCEVUse Y; if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) { return getMulExpr(Y, getUDivExpr(X, Y)); } @@ -2746,8 +2828,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // operands multiplied by constant values. if (Idx < Ops.size() && isa(Ops[Idx])) { uint64_t BitWidth = getTypeSizeInBits(Ty); - SmallDenseMap M; - SmallVector NewOps; + SmallDenseMap M; + SmallVector NewOps; APInt AccumulatedConstant(BitWidth, 0); if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops, APInt(BitWidth, 1), *this)) { @@ -2760,8 +2842,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // Some interesting folding opportunity is present, so its worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. - std::map, APIntCompare> MulOpLists; - for (const SCEV *NewOp : NewOps) + std::map, APIntCompare> MulOpLists; + for (SCEVUse NewOp : NewOps) MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. 
Ops.clear(); @@ -2791,25 +2873,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { const SCEVMulExpr *Mul = cast(Ops[Idx]); for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { - const SCEV *MulOpSCEV = Mul->getOperand(MulOp); + SCEVUse MulOpSCEV = Mul->getOperand(MulOp); if (isa(MulOpSCEV)) continue; for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) if (MulOpSCEV == Ops[AddOp]) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) - const SCEV *InnerMul = Mul->getOperand(MulOp == 0); + SCEVUse InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. - SmallVector MulOps( - Mul->operands().take_front(MulOp)); + SmallVector MulOps(Mul->operands().take_front(MulOp)); append_range(MulOps, Mul->operands().drop_front(MulOp + 1)); InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - SmallVector TwoOps = {getOne(Ty), InnerMul}; - const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); - const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, - SCEV::FlagAnyWrap, Depth + 1); + SmallVector TwoOps = {getOne(Ty), InnerMul}; + SCEVUse AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse OuterMul = + getMulExpr(AddOne, MulOpSCEV, SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -2833,25 +2914,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, OMulOp != e; ++OMulOp) if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) - const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); + SCEVUse InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { - SmallVector MulOps( - Mul->operands().take_front(MulOp)); + SmallVector MulOps(Mul->operands().take_front(MulOp)); append_range(MulOps, Mul->operands().drop_front(MulOp+1)); InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); + SCEVUse InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { - SmallVector MulOps( + SmallVector MulOps( OtherMul->operands().take_front(OMulOp)); append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1)); InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - SmallVector TwoOps = {InnerMul1, InnerMul2}; - const SCEV *InnerMulSum = + SmallVector TwoOps = {InnerMul1, InnerMul2}; + SCEVUse InnerMulSum = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); - const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, + SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); @@ -2872,7 +2952,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. 
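  // Illustrative fold this scan enables: with %n invariant in %loop,
  //   (%n + {5,+,1}<%loop>)  -->  {(5 + %n),+,1}<%loop>
  // i.e. loop-invariant addends are absorbed into the recurrence's start, so
  // the add collapses into a single recurrence.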
@@ -2872,7 +2952,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
     // Scan all of the other operands to this add and add them to the vector if
     // they are loop invariant w.r.t. the recurrence.
-    SmallVector<const SCEV *, 8> LIOps;
+    SmallVector<SCEVUse, 8> LIOps;
     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
     const Loop *AddRecLoop = AddRec->getLoop();
     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
@@ -2894,7 +2974,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
       LIOps.push_back(AddRec->getStart());
 
-      SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
+      SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
 
       // It is not in general safe to propagate flags valid on an add within
       // the addrec scope to one outside it.  We must prove that the inner
@@ -2919,7 +2999,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       // outer add and the inner addrec are guaranteed to have no overflow.
       // Always propagate NW.
       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
-      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
+      SCEVUse NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
 
       // If all of the other operands were loop invariant, we are done.
       if (Ops.size() == 1) return NewRec;
@@ -2947,7 +3027,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
              "AddRecExprs are not sorted in reverse dominance order?");
       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
-        SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
+        SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
              ++OtherIdx) {
           const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
@@ -2958,8 +3038,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
               append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
               break;
             }
-            SmallVector<const SCEV *, 2> TwoOps = {
-                AddRecOps[i], OtherAddRec->getOperand(i)};
+            SmallVector<SCEVUse, 2> TwoOps = {AddRecOps[i],
+                                              OtherAddRec->getOperand(i)};
             AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
           }
           Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
@@ -2980,69 +3060,79 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
 }
 
-const SCEV *
-ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
-                                    SCEV::NoWrapFlags Flags) {
+SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
+                                            SCEV::NoWrapFlags Flags) {
   FoldingSetNodeID ID;
   ID.AddInteger(scAddExpr);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   void *IP = nullptr;
   SCEVAddExpr *S =
       static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
   if (!S) {
-    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
     S = new (SCEVAllocator)
         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
     registerUser(S, Ops);
+    if (all_of(Ops, [](SCEVUse U) { return U.isCanonical(); }))
+      S->setCanonical(S);
+    else
+      S->setCanonical(SCEVUse::computeCanonical(*this, S));
   }
   S->setNoWrapFlags(Flags);
   return S;
 }
 
-const SCEV *
-ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
-                                       const Loop *L, SCEV::NoWrapFlags Flags) {
+SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
+                                               const Loop *L,
+                                               SCEV::NoWrapFlags Flags) {
   FoldingSetNodeID ID;
   ID.AddInteger(scAddRecExpr);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   ID.AddPointer(L);
   void *IP = nullptr;
   SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
   if (!S) {
-    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
     S = new (SCEVAllocator)
         SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
     UniqueSCEVs.InsertNode(S, IP);
     LoopUsers[L].push_back(S);
     registerUser(S, Ops);
+    if (all_of(Ops, [](SCEVUse U) { return U.isCanonical(); }))
+      S->setCanonical(S);
+    else
+      S->setCanonical(SCEVUse::computeCanonical(*this, S));
   }
   setNoWrapFlags(S, Flags);
   return S;
 }
 
-const SCEV *
-ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
-                                    SCEV::NoWrapFlags Flags) {
+SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
+                                            SCEV::NoWrapFlags Flags) {
   FoldingSetNodeID ID;
   ID.AddInteger(scMulExpr);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   void *IP = nullptr;
   SCEVMulExpr *S =
     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
   if (!S) {
-    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                         O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
     registerUser(S, Ops);
+    if (all_of(Ops, [](SCEVUse U) { return U.isCanonical(); }))
+      S->setCanonical(S);
+    else
+      S->setCanonical(SCEVUse::computeCanonical(*this, S));
   }
   S->setNoWrapFlags(Flags);
   return S;
 }
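Note: all three getOrCreate* functions above follow the same find-or-insert pattern: hash the operands, probe the FoldingSet, and only on a miss allocate the node and record its canonical form once. A minimal sketch of that pattern with a std::unordered_map standing in for llvm::FoldingSet (hypothetical Expr/ExprFactory types; the real code also interns the operand array in a BumpPtrAllocator):

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Expr {
      std::string Key;               // stands in for FoldingSetNodeID
      const Expr *Canonical = nullptr;
    };

    struct ExprFactory {
      std::unordered_map<std::string, std::unique_ptr<Expr>> Unique;

      // Find-or-insert: repeated requests for the same key return one node,
      // and the canonical form is computed once, at creation time.
      const Expr *getOrCreate(const std::string &Key) {
        auto It = Unique.find(Key);
        if (It != Unique.end())
          return It->second.get();   // cache hit, no new node
        auto New = std::make_unique<Expr>();
        New->Key = Key;
        New->Canonical = New.get();  // stands in for S->setCanonical(S)
        return (Unique[Key] = std::move(New)).get();
      }
    };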
@@ -3082,11 +3172,11 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
 
 /// Determine if any of the operands in this SCEV are a constant or if
 /// any of the add or multiply expressions in this SCEV contain a constant.
-static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+static bool containsConstantInAddMulChain(SCEVUse StartExpr) {
   struct FindConstantInAddMulChain {
     bool FoundConstant = false;
 
-    bool follow(const SCEV *S) {
+    bool follow(SCEVUse S) {
       FoundConstant |= isa<SCEVConstant>(S);
       return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
     }
@@ -3103,9 +3193,16 @@ static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
 }
 
 /// Get a canonical multiply expression, or something simpler if possible.
-const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
-                                        SCEV::NoWrapFlags OrigFlags,
-                                        unsigned Depth) {
+SCEVUse ScalarEvolution::getMulExpr(ArrayRef<SCEVUse> Ops,
+                                    SCEV::NoWrapFlags OrigFlags,
+                                    unsigned Depth) {
+  SmallVector<SCEVUse> Ops2(Ops);
+  return getMulExpr(Ops2, OrigFlags, Depth);
+}
+
+SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
+                                    SCEV::NoWrapFlags OrigFlags,
+                                    unsigned Depth) {
   assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
          "only nuw or nsw allowed");
   assert(!Ops.empty() && "Cannot get empty mul!");
@@ -3127,7 +3224,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
     return Folded;
 
   // Delay expensive flag strengthening until necessary.
-  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
+  auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
   };
 
@@ -3154,10 +3251,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
       // profitable; for example, Add = (C0 + X) * Y + Z.  Maybe the scope of
       // this transformation should be narrowed down.
       if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
-        const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
-                                     SCEV::FlagAnyWrap, Depth + 1);
-        const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
-                                     SCEV::FlagAnyWrap, Depth + 1);
+        SCEVUse LHS = getMulExpr(LHSC, Add->getOperand(0), SCEV::FlagAnyWrap,
+                                 Depth + 1);
+        SCEVUse RHS = getMulExpr(LHSC, Add->getOperand(1), SCEV::FlagAnyWrap,
+                                 Depth + 1);
         return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
       }
 
@@ -3165,11 +3262,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
       // If we have a mul by -1 of an add, try distributing the -1 among the
       // add operands.
       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
-        SmallVector<const SCEV *, 4> NewOps;
+        SmallVector<SCEVUse, 4> NewOps;
         bool AnyFolded = false;
-        for (const SCEV *AddOp : Add->operands()) {
-          const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
-                                       Depth + 1);
+        for (SCEVUse AddOp : Add->operands()) {
+          SCEVUse Mul =
+              getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, Depth + 1);
           if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
           NewOps.push_back(Mul);
         }
@@ -3177,8 +3274,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
           return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
       } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
         // Negation preserves a recurrence's no self-wrap property.
-        SmallVector<const SCEV *, 4> Operands;
-        for (const SCEV *AddRecOp : AddRec->operands())
+        SmallVector<SCEVUse, 4> Operands;
+        for (SCEVUse AddRecOp : AddRec->operands())
           Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
                                         Depth + 1));
         // Let M be the minimum representable signed value. AddRec with nsw
@@ -3235,7 +3332,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
     // Scan all of the other operands to this mul and add them to the vector
     // if they are loop invariant w.r.t. the recurrence.
-    SmallVector<const SCEV *, 8> LIOps;
+    SmallVector<SCEVUse, 8> LIOps;
     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
       if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
@@ -3247,9 +3344,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
     // If we found some loop invariants, fold them into the recurrence.
     if (!LIOps.empty()) {
       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
-      SmallVector<const SCEV *, 4> NewOps;
+      SmallVector<SCEVUse, 4> NewOps;
       NewOps.reserve(AddRec->getNumOperands());
-      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
+      SCEVUse Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
 
       // If both the mul and addrec are nuw, we can preserve nuw.
       // If both the mul and addrec are nsw, we can only preserve nsw if either
@@ -3271,7 +3368,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
         }
       }
 
-      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
+      SCEVUse NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
 
       // If all of the other operands were loop invariant, we are done.
       if (Ops.size() == 1)
         return NewRec;
@@ -3317,10 +3414,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
         bool Overflow = false;
         Type *Ty = AddRec->getType();
         bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
-        SmallVector<const SCEV *, 7> AddRecOps;
+        SmallVector<SCEVUse, 7> AddRecOps;
         for (int x = 0, xe = AddRec->getNumOperands() +
                OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
-          SmallVector<const SCEV *, 7> SumOps;
+          SmallVector<SCEVUse, 7> SumOps;
           for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
             uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
             for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
@@ -3332,9 +3429,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
               else
                 Coeff = Coeff1*Coeff2;
-              const SCEV *CoeffTerm = getConstant(Ty, Coeff);
-              const SCEV *Term1 = AddRec->getOperand(y-z);
-              const SCEV *Term2 = OtherAddRec->getOperand(z);
+              SCEVUse CoeffTerm = getConstant(Ty, Coeff);
+              SCEVUse Term1 = AddRec->getOperand(y - z);
+              SCEVUse Term2 = OtherAddRec->getOperand(z);
               SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                           SCEV::FlagAnyWrap, Depth + 1));
             }
@@ -3344,8 +3441,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
           AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
         }
         if (!Overflow) {
-          const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
-                                                SCEV::FlagAnyWrap);
+          SCEVUse NewAddRec =
+              getAddRecExpr(AddRecOps, AddRec->getLoop(), SCEV::FlagAnyWrap);
           if (Ops.size() == 2) return NewAddRec;
           Ops[Idx] = NewAddRec;
           Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
@@ -3368,8 +3465,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
 }
 
 /// Represents an unsigned remainder expression based on unsigned division.
-const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
-                                         const SCEV *RHS) {
+SCEVUse ScalarEvolution::getURemExpr(SCEVUse LHS, SCEVUse RHS) {
   assert(getEffectiveSCEVType(LHS->getType()) ==
          getEffectiveSCEVType(RHS->getType()) &&
          "SCEVURemExpr operand types don't match!");
@@ -3390,15 +3486,14 @@ const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
   }
 
   // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
-  const SCEV *UDiv = getUDivExpr(LHS, RHS);
-  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
+  SCEVUse UDiv = getUDivExpr(LHS, RHS);
+  SCEVUse Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
   return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
 }
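Note: the fallback above is the standard definition of unsigned remainder. A two-line check of the identity a urem b == a - (a udiv b) * b:

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t A = 123456, B = 789;
      // %x urem %y  ==  %x - (%x udiv %y) * %y, for any nonzero %y.
      assert(A % B == A - (A / B) * B);
      return 0;
    }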
/// Get a canonical unsigned division expression, or something simpler if
/// possible.
-const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
-                                         const SCEV *RHS) {
+SCEVUse ScalarEvolution::getUDivExpr(SCEVUse LHS, SCEVUse RHS) {
   assert(!LHS->getType()->isPointerTy() &&
          "SCEVUDivExpr operand can't be pointer!");
   assert(LHS->getType() == RHS->getType() &&
@@ -3406,10 +3501,10 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
 
   FoldingSetNodeID ID;
   ID.AddInteger(scUDivExpr);
-  ID.AddPointer(LHS);
-  ID.AddPointer(RHS);
+  ID.AddPointer(LHS.getRawPointer());
+  ID.AddPointer(RHS.getRawPointer());
   void *IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+  if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
     return S;
 
   // 0 udiv Y == 0
@@ -3446,8 +3541,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                             getZeroExtendExpr(Step, ExtTy),
                             AR->getLoop(), SCEV::FlagAnyWrap)) {
-          SmallVector<const SCEV *, 4> Operands;
-          for (const SCEV *Op : AR->operands())
+          SmallVector<SCEVUse, 4> Operands;
+          for (SCEVUse Op : AR->operands())
             Operands.push_back(getUDivExpr(Op, RHS));
           return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
         }
@@ -3463,9 +3558,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
           const APInt &StartInt = StartC->getAPInt();
           const APInt &StartRem = StartInt.urem(StepInt);
           if (StartRem != 0) {
-            const SCEV *NewLHS =
-                getAddRecExpr(getConstant(StartInt - StartRem), Step,
-                              AR->getLoop(), SCEV::FlagNW);
+            SCEVUse NewLHS = getAddRecExpr(getConstant(StartInt - StartRem),
+                                           Step, AR->getLoop(), SCEV::FlagNW);
             if (LHS != NewLHS) {
               LHS = NewLHS;
 
@@ -3473,10 +3567,10 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
               // already cached.
               ID.clear();
               ID.AddInteger(scUDivExpr);
-              ID.AddPointer(LHS);
-              ID.AddPointer(RHS);
+              ID.AddPointer(LHS.getRawPointer());
+              ID.AddPointer(RHS.getRawPointer());
               IP = nullptr;
-              if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+              if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
                 return S;
             }
           }
@@ -3484,16 +3578,16 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
       }
       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
-        SmallVector<const SCEV *, 4> Operands;
-        for (const SCEV *Op : M->operands())
+        SmallVector<SCEVUse, 4> Operands;
+        for (SCEVUse Op : M->operands())
           Operands.push_back(getZeroExtendExpr(Op, ExtTy));
         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
           // Find an operand that's safely divisible.
           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
-            const SCEV *Op = M->getOperand(i);
-            const SCEV *Div = getUDivExpr(Op, RHSC);
+            SCEVUse Op = M->getOperand(i);
+            SCEVUse Div = getUDivExpr(Op, RHSC);
             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
-              Operands = SmallVector<const SCEV *, 4>(M->operands());
+              Operands = SmallVector<SCEVUse, 4>(M->operands());
               Operands[i] = Div;
               return getMulExpr(Operands);
             }
@@ -3516,13 +3610,13 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
-    SmallVector<const SCEV *, 4> Operands;
-    for (const SCEV *Op : A->operands())
+    SmallVector<SCEVUse, 4> Operands;
+    for (SCEVUse Op : A->operands())
       Operands.push_back(getZeroExtendExpr(Op, ExtTy));
     if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
       Operands.clear();
       for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
-        const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
+        SCEVUse Op = getUDivExpr(A->getOperand(i), RHS);
         if (isa<SCEVUDivExpr>(Op) ||
             getMulExpr(Op, RHS) != A->getOperand(i))
           break;
@@ -3558,11 +3652,16 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
   // changes). Make sure we get a new one.
   IP = nullptr;
-  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+    return S;
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                              LHS, RHS);
   UniqueSCEVs.InsertNode(S, IP);
   registerUser(S, {LHS, RHS});
+  if (LHS.isCanonical() && RHS.isCanonical())
+    S->setCanonical(S);
+  else
+    S->setCanonical(SCEVUse::computeCanonical(*this, S));
   return S;
 }
 
@@ -3584,8 +3683,7 @@ APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
 /// possible. There is no representation for an exact udiv in SCEV IR, but we
 /// can attempt to remove factors from the LHS and RHS. We can't do this when
 /// it's not exact because the udiv may be clearing bits.
-const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
-                                              const SCEV *RHS) {
+SCEVUse ScalarEvolution::getUDivExactExpr(SCEVUse LHS, SCEVUse RHS) {
   // TODO: we could try to find factors in all sorts of things, but for now we
   // just deal with u/exact (multiply, constant). See SCEVDivision towards the
   // end of this file for inspiration.
@@ -3599,7 +3697,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
     // first element of the mulexpr.
     if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
       if (LHSCst == RHSCst) {
-        SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
+        SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
         return getMulExpr(Operands);
       }
 
@@ -3612,7 +3710,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
             cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
         RHSCst =
             cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
-        SmallVector<const SCEV *, 2> Operands;
+        SmallVector<SCEVUse, 2> Operands;
         Operands.push_back(LHSCst);
         append_range(Operands, Mul->operands().drop_front());
         LHS = getMulExpr(Operands);
@@ -3626,7 +3724,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
 
   for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
     if (Mul->getOperand(i) == RHS) {
-      SmallVector<const SCEV *, 2> Operands;
+      SmallVector<SCEVUse, 2> Operands;
       append_range(Operands, Mul->operands().take_front(i));
       append_range(Operands, Mul->operands().drop_front(i + 1));
       return getMulExpr(Operands);
@@ -3638,10 +3736,9 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
 
 /// Get an add recurrence expression for the specified loop.  Simplify the
 /// expression as much as possible.
-const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
-                                           const Loop *L,
-                                           SCEV::NoWrapFlags Flags) {
-  SmallVector<const SCEV *, 4> Operands;
+SCEVUse ScalarEvolution::getAddRecExpr(SCEVUse Start, SCEVUse Step,
+                                       const Loop *L, SCEV::NoWrapFlags Flags) {
+  SmallVector<SCEVUse, 4> Operands;
   Operands.push_back(Start);
   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
     if (StepChrec->getLoop() == L) {
@@ -3653,11 +3750,16 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
   return getAddRecExpr(Operands, L, Flags);
 }
 
+SCEVUse ScalarEvolution::getAddRecExpr(ArrayRef<SCEVUse> Operands,
+                                       const Loop *L, SCEV::NoWrapFlags Flags) {
+  SmallVector<SCEVUse> Ops2(Operands);
+  return getAddRecExpr(Ops2, L, Flags);
+}
+
 /// Get an add recurrence expression for the specified loop.  Simplify the
 /// expression as much as possible.
-const SCEV *
-ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
-                               const Loop *L, SCEV::NoWrapFlags Flags) {
+SCEVUse ScalarEvolution::getAddRecExpr(SmallVectorImpl<SCEVUse> &Operands,
+                                       const Loop *L, SCEV::NoWrapFlags Flags) {
   if (Operands.size() == 1) return Operands[0];
 #ifndef NDEBUG
   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
@@ -3691,13 +3793,13 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
             ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
             : (!NestedLoop->contains(L) &&
                DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
-      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
+      SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
       Operands[0] = NestedAR->getStart();
       // AddRecs require their operands be loop-invariant with respect to their
       // loops.  Don't perform this transformation if it would break this
       // requirement.
-      bool AllInvariant = all_of(
-          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+      bool AllInvariant =
+          all_of(Operands, [&](SCEVUse Op) { return isLoopInvariant(Op, L); });
 
       if (AllInvariant) {
         // Create a recurrence for the outer loop with the same step size.
@@ -3708,7 +3810,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
             maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
 
-        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+        AllInvariant = all_of(NestedOperands, [&](SCEVUse Op) {
           return isLoopInvariant(Op, NestedLoop);
         });
 
@@ -3732,10 +3834,15 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
   return getOrCreateAddRecExpr(Operands, L, Flags);
 }
 
-const SCEV *
+SCEVUse ScalarEvolution::getGEPExpr(GEPOperator *GEP,
+                                    ArrayRef<SCEVUse> IndexExprs) {
+  return getGEPExpr(GEP, SmallVector<SCEVUse>(IndexExprs));
+}
+
+SCEVUse
 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
-                            const SmallVectorImpl<const SCEV *> &IndexExprs) {
-  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
+                            const SmallVectorImpl<SCEVUse> &IndexExprs) {
+  SCEVUse BaseExpr = getSCEV(GEP->getPointerOperand());
   // getSCEV(Base)->getType() has the same address space as Base->getType()
   // because SCEV::getType() preserves the address space.
   Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
@@ -3759,14 +3866,14 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
 
   Type *CurTy = GEP->getType();
   bool FirstIter = true;
-  SmallVector<const SCEV *, 4> Offsets;
-  for (const SCEV *IndexExpr : IndexExprs) {
+  SmallVector<SCEVUse, 4> Offsets;
+  for (SCEVUse IndexExpr : IndexExprs) {
     // Compute the (potentially symbolic) offset in bytes for this index.
     if (StructType *STy = dyn_cast<StructType>(CurTy)) {
       // For a struct, add the member offset.
       ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
       unsigned FieldNo = Index->getZExtValue();
-      const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
+      SCEVUse FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
       Offsets.push_back(FieldOffset);
 
       // Update CurTy to the type of the field at Index.
@@ -3782,12 +3889,12 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
         CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
       }
       // For an array, add the element offset, explicitly scaled.
-      const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
+      SCEVUse ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
       // Getelementptr indices are signed.
       IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
 
       // Multiply the index by the element size to compute the element offset.
-      const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
+      SCEVUse LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
       Offsets.push_back(LocalOffset);
     }
   }
@@ -3797,36 +3904,42 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
     return BaseExpr;
 
   // Add the offsets together, assuming nsw if inbounds.
-  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
+  SCEVUse Offset = getAddExpr(Offsets, OffsetWrap);
   // Add the base address and the offset. We cannot use the nsw flag, as the
   // base address is unsigned. However, if we know that the offset is
   // non-negative, we can use nuw.
   bool NUW = NW.hasNoUnsignedWrap() ||
             (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
   SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
-  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
+  auto GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
   assert(BaseExpr->getType() == GEPExpr->getType() &&
          "GEP should not change type mid-flight.");
   return GEPExpr;
 }
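Note: getGEPExpr builds the byte offset by adding a fixed member offset for each struct index and an element-size-scaled term for each array index. The same address arithmetic, reduced to plain integers (hand-supplied layout values instead of DataLayout):

    #include <cstdint>

    // Byte offset of base->field[index]: a struct index adds a fixed member
    // offset, an array index adds index * sizeof(element), mirroring the two
    // cases handled above. Getelementptr indices are signed.
    int64_t gepOffset(int64_t FieldOffset, int64_t ElemSize, int64_t Index) {
      return FieldOffset + Index * ElemSize;
    }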
 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
-                                               ArrayRef<const SCEV *> Ops) {
+                                               ArrayRef<SCEVUse> Ops) {
   FoldingSetNodeID ID;
   ID.AddInteger(SCEVType);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   void *IP = nullptr;
   return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
 }
 
-const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
+SCEVUse ScalarEvolution::getAbsExpr(SCEVUse Op, bool IsNSW) {
   SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
   return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
 }
 
-const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
-                                           SmallVectorImpl<const SCEV *> &Ops) {
+SCEVUse ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
+                                       ArrayRef<SCEVUse> Ops) {
+  SmallVector<SCEVUse> Ops2(Ops);
+  return getMinMaxExpr(Kind, Ops2);
+}
+
+SCEVUse ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
+                                       SmallVectorImpl<SCEVUse> &Ops) {
   assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
   assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
   if (Ops.size() == 1) return Ops[0];
@@ -3878,7 +3991,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
     return Folded;
 
   // Check if we have created the same expression before.
-  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
+  if (SCEVUse S = findExistingSCEVInCache(Kind, Ops)) {
     return S;
   }
 
@@ -3936,19 +4049,23 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
   // already have one, otherwise create a new one.
   FoldingSetNodeID ID;
   ID.AddInteger(Kind);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   void *IP = nullptr;
-  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
+  SCEVUse ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
   if (ExistingSCEV)
     return ExistingSCEV;
-  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+  SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
   SCEV *S = new (SCEVAllocator)
       SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
   registerUser(S, Ops);
+  if (all_of(Ops, [](SCEVUse U) { return U.isCanonical(); }))
+    S->setCanonical(S);
+  else
+    S->setCanonical(SCEVUse::computeCanonical(*this, S));
   return S;
 }
 
@@ -3956,14 +4073,14 @@ namespace {
 
 class SCEVSequentialMinMaxDeduplicatingVisitor final
     : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
-                         std::optional<const SCEV *>> {
-  using RetVal = std::optional<const SCEV *>;
+                         std::optional<SCEVUse>> {
+  using RetVal = std::optional<SCEVUse>;
   using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
   ScalarEvolution &SE;
   const SCEVTypes RootKind; // Must be a sequential min/max expression.
   const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
-  SmallPtrSet<const SCEV *, 16> SeenOps;
+  SmallPtrSet<SCEVUse, 16> SeenOps;
 
   bool canRecurseInto(SCEVTypes Kind) const {
     // We can only recurse into the SCEV expression of the same effective type
@@ -3971,7 +4088,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
     return RootKind == Kind || NonSequentialRootKind == Kind;
   };
 
-  RetVal visitAnyMinMaxExpr(const SCEV *S) {
+  RetVal visitAnyMinMaxExpr(SCEVUse S) {
    assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
           "Only for min/max expressions.");
    SCEVTypes Kind = S->getSCEVType();
@@ -3980,7 +4097,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
       return S;
 
     auto *NAry = cast<SCEVNAryExpr>(S);
-    SmallVector<const SCEV *> NewOps;
+    SmallVector<SCEVUse> NewOps;
     bool Changed = visit(Kind, NAry->operands(), NewOps);
 
     if (!Changed)
@@ -3993,7 +4110,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
                      : SE.getMinMaxExpr(Kind, NewOps);
   }
 
-  RetVal visit(const SCEV *S) {
+  RetVal visit(SCEVUse S) {
     // Has the whole operand been seen already?
     if (!SeenOps.insert(S).second)
       return std::nullopt;
@@ -4008,13 +4125,13 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
             SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
                 RootKind)) {}
 
-  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
-                         SmallVectorImpl<const SCEV *> &NewOps) {
+  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
+                         SmallVectorImpl<SCEVUse> &NewOps) {
     bool Changed = false;
-    SmallVector<const SCEV *> Ops;
+    SmallVector<SCEVUse> Ops;
     Ops.reserve(OrigOps.size());
 
-    for (const SCEV *Op : OrigOps) {
+    for (SCEVUse Op : OrigOps) {
       RetVal NewOp = visit(Op);
       if (NewOp != Op)
         Changed = true;
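Note: the deduplicating visitor above keeps only the first occurrence of each operand by keying on SeenOps.insert(S).second. The same idiom with standard containers:

    #include <unordered_set>
    #include <vector>

    // Keep the first occurrence of each operand, in order: umin_seq(x, y, x)
    // collapses to umin_seq(x, y). Mirrors SeenOps.insert(S).second above.
    std::vector<int> deduplicate(const std::vector<int> &Ops) {
      std::unordered_set<int> Seen;
      std::vector<int> Out;
      for (int Op : Ops)
        if (Seen.insert(Op).second)
          Out.push_back(Op);
      return Out;
    }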
@@ -4117,7 +4234,7 @@ struct SCEVPoisonCollector {
   SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
       : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
 
-  bool follow(const SCEV *S) {
+  bool follow(SCEVUse S) {
     if (!LookThroughMaybePoisonBlocking &&
         !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
       return false;
@@ -4133,7 +4250,7 @@ struct SCEVPoisonCollector {
 } // namespace
 
 /// Return true if V is poison given that AssumedPoison is already poison.
-static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
+static bool impliesPoison(SCEVUse AssumedPoison, SCEVUse S) {
   // First collect all SCEVs that might result in AssumedPoison to be poison.
   // We need to look through potentially poison-blocking operations here,
   // because we want to find all SCEVs that *might* result in poison, not only
@@ -4158,7 +4275,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
 }
 
 void ScalarEvolution::getPoisonGeneratingValues(
-    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
+    SmallPtrSetImpl<const Value *> &Result, SCEVUse S) {
   SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
   visitAll(S, PC);
   for (const SCEVUnknown *SU : PC.MaybePoison)
@@ -4166,7 +4283,7 @@ void ScalarEvolution::getPoisonGeneratingValues(
 }
 
 bool ScalarEvolution::canReuseInstruction(
-    const SCEV *S, Instruction *I,
+    SCEVUse S, Instruction *I,
     SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
   // If the instruction cannot be poison, it's always safe to reuse.
   if (programUndefinedIfPoison(I))
@@ -4227,9 +4344,9 @@ bool ScalarEvolution::canReuseInstruction(
   return true;
 }
 
-const SCEV *
+SCEVUse
 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
-                                         SmallVectorImpl<const SCEV *> &Ops) {
+                                         SmallVectorImpl<SCEVUse> &Ops) {
   assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
          "Not a SCEVSequentialMinMaxExpr!");
   assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
@@ -4250,7 +4367,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
   // so we can *NOT* do any kind of sorting of the expressions!
 
   // Check if we have created the same expression before.
-  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
+  if (SCEVUse S = findExistingSCEVInCache(Kind, Ops))
     return S;
 
   // FIXME: there are *some* simplifications that we can do here.
@@ -4284,7 +4401,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
     return getSequentialMinMaxExpr(Kind, Ops);
   }
 
-  const SCEV *SaturationPoint;
+  SCEVUse SaturationPoint;
   ICmpInst::Predicate Pred;
   switch (Kind) {
   case scSequentialUMinExpr:
@@ -4304,7 +4421,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
     if (::impliesPoison(Ops[i], Ops[i - 1]) ||
         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
                                         SaturationPoint)) {
-      SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
+      SmallVector<SCEVUse> SeqOps = {Ops[i - 1], Ops[i]};
       Ops[i - 1] = getMinMaxExpr(
          SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
          SeqOps);
@@ -4323,82 +4440,83 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
   // already have one, otherwise create a new one.
   FoldingSetNodeID ID;
   ID.AddInteger(Kind);
-  for (const SCEV *Op : Ops)
-    ID.AddPointer(Op);
+  for (SCEVUse Op : Ops)
+    ID.AddPointer(Op.getRawPointer());
   void *IP = nullptr;
-  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
+  SCEVUse ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
   if (ExistingSCEV)
     return ExistingSCEV;
 
-  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+  SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
   SCEV *S = new (SCEVAllocator)
       SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
   registerUser(S, Ops);
+  if (all_of(Ops, [](SCEVUse U) { return U.isCanonical(); }))
+    S->setCanonical(S);
+  else
+    S->setCanonical(SCEVUse::computeCanonical(*this, S));
   return S;
 }
 
-const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
-  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+SCEVUse ScalarEvolution::getSMaxExpr(SCEVUse LHS, SCEVUse RHS) {
+  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
   return getSMaxExpr(Ops);
 }
 
-const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
+SCEVUse ScalarEvolution::getSMaxExpr(SmallVectorImpl<SCEVUse> &Ops) {
   return getMinMaxExpr(scSMaxExpr, Ops);
 }
 
-const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
-  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+SCEVUse ScalarEvolution::getUMaxExpr(SCEVUse LHS, SCEVUse RHS) {
+  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
   return getUMaxExpr(Ops);
 }
 
-const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
+SCEVUse ScalarEvolution::getUMaxExpr(SmallVectorImpl<SCEVUse> &Ops) {
   return getMinMaxExpr(scUMaxExpr, Ops);
 }
 
-const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
-                                         const SCEV *RHS) {
-  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+SCEVUse ScalarEvolution::getSMinExpr(SCEVUse LHS, SCEVUse RHS) {
+  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
  return getSMinExpr(Ops);
 }
 
-const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+SCEVUse ScalarEvolution::getSMinExpr(SmallVectorImpl<SCEVUse> &Ops) {
   return getMinMaxExpr(scSMinExpr, Ops);
 }
 
-const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
-                                         bool Sequential) {
-  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+SCEVUse ScalarEvolution::getUMinExpr(SCEVUse LHS, SCEVUse RHS,
+                                     bool Sequential) {
  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
  return getUMinExpr(Ops, Sequential);
 }
 
-const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
-                                         bool Sequential) {
+SCEVUse ScalarEvolution::getUMinExpr(SmallVectorImpl<SCEVUse> &Ops,
+                                     bool Sequential) {
   return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
                     : getMinMaxExpr(scUMinExpr, Ops);
 }
 
-const SCEV *
-ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
-  const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
+SCEVUse ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
+  SCEVUse Res = getConstant(IntTy, Size.getKnownMinValue());
   if (Size.isScalable())
     Res = getMulExpr(Res, getVScale(IntTy));
   return Res;
 }
 
-const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
+SCEVUse ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
   return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
 }
 
-const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
+SCEVUse ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
   return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
 }
 
-const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
-                                             StructType *STy,
-                                             unsigned FieldNo) {
+SCEVUse ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy,
                                         unsigned FieldNo) {
   // We can bypass creating a target-independent constant expression and then
   // folding it back into a ConstantInt. This is just a compile-time
   // optimization.
@@ -4408,7 +4526,7 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
     return getConstant(IntTy, SL->getElementOffset(FieldNo));
 }
 
-const SCEV *ScalarEvolution::getUnknown(Value *V) {
+SCEVUse ScalarEvolution::getUnknown(Value *V) {
   // Don't attempt to do anything other than create a SCEVUnknown object
   // here.  createSCEV only calls getUnknown after checking for all other
   // interesting possibilities, and any other code that calls getUnknown
@@ -4427,6 +4545,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
                                             FirstUnknown);
   FirstUnknown = cast<SCEVUnknown>(S);
   UniqueSCEVs.InsertNode(S, IP);
+  S->setCanonical(S);
   return S;
 }
 
@@ -4470,8 +4589,7 @@ Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
   return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
 }
 
-bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
-                                                        const SCEV *B) {
+bool ScalarEvolution::instructionCouldExistWithOperands(SCEVUse A, SCEVUse B) {
   /// For a valid use point to exist, the defining scope of one operand
   /// must dominate the other.
   bool PreciseA, PreciseB;
@@ -4484,12 +4602,10 @@ bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
       DT.dominates(ScopeB, ScopeA);
 }
 
-const SCEV *ScalarEvolution::getCouldNotCompute() {
-  return CouldNotCompute.get();
-}
+SCEVUse ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); }
 
-bool ScalarEvolution::checkValidity(const SCEV *S) const {
-  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+bool ScalarEvolution::checkValidity(SCEVUse S) const {
+  bool ContainsNulls = SCEVExprContains(S, [](SCEVUse S) {
    auto *SU = dyn_cast<SCEVUnknown>(S);
    return SU && SU->getValue() == nullptr;
  });
@@ -4497,20 +4613,20 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const {
   return !ContainsNulls;
 }
 
-bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
+bool ScalarEvolution::containsAddRecurrence(SCEVUse S) {
   HasRecMapType::iterator I = HasRecMap.find(S);
   if (I != HasRecMap.end())
     return I->second;
 
   bool FoundAddRec =
-      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
+      SCEVExprContains(S, [](SCEVUse S) { return isa<SCEVAddRecExpr>(S); });
   HasRecMap.insert({S, FoundAddRec});
   return FoundAddRec;
 }
 
 /// Return the ValueOffsetPair set for \p S.
/// \p S can be represented
/// by the value and offset from any ValueOffsetPair in the set.
-ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
+ArrayRef<Value *> ScalarEvolution::getSCEVValues(SCEVUse S) {
   ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
   if (SI == ExprValueMap.end())
     return {};
@@ -4531,7 +4647,7 @@ void ScalarEvolution::eraseValueFromMap(Value *V) {
   }
 }
 
-void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
+void ScalarEvolution::insertValueToMap(Value *V, SCEVUse S) {
   // A recursive query may have already computed the SCEV. It should be
   // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
   // inferred nowrap flags.
@@ -4544,20 +4660,20 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
 
 /// Return an existing SCEV if it exists, otherwise analyze the expression and
 /// create a new one.
-const SCEV *ScalarEvolution::getSCEV(Value *V) {
+SCEVUse ScalarEvolution::getSCEV(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
-  if (const SCEV *S = getExistingSCEV(V))
+  if (SCEVUse S = getExistingSCEV(V))
    return S;
  return createSCEVIter(V);
 }
 
-const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
+SCEVUse ScalarEvolution::getExistingSCEV(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
   if (I != ValueExprMap.end()) {
-    const SCEV *S = I->second;
+    SCEVUse S = I->second;
    assert(checkValidity(S) &&
           "existing SCEV has not been properly invalidated");
    return S;
@@ -4566,8 +4682,7 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
 }
 
 /// Return a SCEV corresponding to -V = -1*V
-const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
-                                             SCEV::NoWrapFlags Flags) {
+SCEVUse ScalarEvolution::getNegativeSCEV(SCEVUse V, SCEV::NoWrapFlags Flags) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
@@ -4578,7 +4693,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
 }
 
 /// If Expr computes ~A, return A else return nullptr
-static const SCEV *MatchNotExpr(const SCEV *Expr) {
+static SCEVUse MatchNotExpr(SCEVUse Expr) {
   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
   if (!Add || Add->getNumOperands() != 2 ||
      !Add->getOperand(0)->isAllOnesValue())
@@ -4593,7 +4708,7 @@ static const SCEV *MatchNotExpr(const SCEV *Expr) {
 }
 
 /// Return a SCEV corresponding to ~V = -1-V
-const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
+SCEVUse ScalarEvolution::getNotSCEV(SCEVUse V) {
   assert(!V->getType()->isPointerTy() && "Can't negate pointer");
 
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
@@ -4603,17 +4718,17 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
   if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
     auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
-      SmallVector<const SCEV *, 2> MatchedOperands;
-      for (const SCEV *Operand : MME->operands()) {
-        const SCEV *Matched = MatchNotExpr(Operand);
+      SmallVector<SCEVUse, 2> MatchedOperands;
+      for (SCEVUse Operand : MME->operands()) {
+        SCEVUse Matched = MatchNotExpr(Operand);
         if (!Matched)
-          return (const SCEV *)nullptr;
+          return (SCEVUse) nullptr;
         MatchedOperands.push_back(Matched);
       }
       return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
                            MatchedOperands);
    };
-    if (const SCEV *Replaced = MatchMinMaxNegation(MME))
+    if (SCEVUse Replaced = MatchMinMaxNegation(MME))
      return Replaced;
   }
 
@@ -4622,12 +4737,12 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   return getMinusSCEV(getMinusOne(Ty), V);
 }
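Note: getNotSCEV and MatchNotExpr rely on the two's-complement identity ~x == -1 - x, which is why a "not" appears in SCEV as an add of an all-ones constant. A quick check over a few sample values:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int32_t X : {0, 1, -7, 12345, INT32_MIN + 1})
        assert(~X == -1 - X); // ~V = -1 - V in two's complement
      return 0;
    }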
-const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
+SCEVUse ScalarEvolution::removePointerBase(SCEVUse P) {
   assert(P->getType()->isPointerTy());
 
   if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
     // The base of an AddRec is the first operand.
-    SmallVector<const SCEV *> Ops{AddRec->operands()};
+    SmallVector<SCEVUse> Ops{AddRec->operands()};
     Ops[0] = removePointerBase(Ops[0]);
     // Don't try to transfer nowrap flags for now. We could in some cases
     // (for example, if pointer operand of the AddRec is a SCEVUnknown).
@@ -4635,9 +4750,9 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
   }
   if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
     // The base of an Add is the pointer operand.
-    SmallVector<const SCEV *> Ops{Add->operands()};
-    const SCEV **PtrOp = nullptr;
-    for (const SCEV *&AddOp : Ops) {
+    SmallVector<SCEVUse> Ops{Add->operands()};
+    SCEVUse *PtrOp = nullptr;
+    for (SCEVUse &AddOp : Ops) {
       if (AddOp->getType()->isPointerTy()) {
         assert(!PtrOp && "Cannot have multiple pointer ops");
         PtrOp = &AddOp;
@@ -4652,9 +4767,8 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
   return getZero(P->getType());
 }
 
-const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
-                                          SCEV::NoWrapFlags Flags,
-                                          unsigned Depth) {
+SCEVUse ScalarEvolution::getMinusSCEV(SCEVUse LHS, SCEVUse RHS,
+                                      SCEV::NoWrapFlags Flags, unsigned Depth) {
   // Fast path: X - X --> 0.
   if (LHS == RHS)
    return getZero(LHS->getType());
@@ -4702,8 +4816,8 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
   return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
 }
 
-const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
-                                                     unsigned Depth) {
+SCEVUse ScalarEvolution::getTruncateOrZeroExtend(SCEVUse V, Type *Ty,
+                                                 unsigned Depth) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
@@ -4714,8 +4828,8 @@ const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
   return getZeroExtendExpr(V, Ty, Depth);
 }
 
-const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
-                                                     unsigned Depth) {
+SCEVUse ScalarEvolution::getTruncateOrSignExtend(SCEVUse V, Type *Ty,
+                                                 unsigned Depth) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
@@ -4726,8 +4840,7 @@ const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
   return getSignExtendExpr(V, Ty, Depth);
 }
 
-const SCEV *
-ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
+SCEVUse ScalarEvolution::getNoopOrZeroExtend(SCEVUse V, Type *Ty) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or zero extend with non-integer arguments!");
@@ -4738,8 +4851,7 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
   return getZeroExtendExpr(V, Ty);
 }
 
-const SCEV *
-ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
+SCEVUse ScalarEvolution::getNoopOrSignExtend(SCEVUse V, Type *Ty) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or sign extend with non-integer arguments!");
@@ -4750,8 +4862,7 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
   return getSignExtendExpr(V, Ty);
 }
 
-const SCEV *
-ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
+SCEVUse ScalarEvolution::getNoopOrAnyExtend(SCEVUse V, Type *Ty) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
          "Cannot noop or any extend with non-integer arguments!");
@@ -4762,8 +4873,7 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
   return getAnyExtendExpr(V, Ty);
 }
 
-const SCEV *
-ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
+SCEVUse ScalarEvolution::getTruncateOrNoop(SCEVUse V, Type *Ty) {
   Type *SrcTy = V->getType();
   assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
          "Cannot truncate or noop with non-integer arguments!");
@@ -4774,10 +4884,9 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
   return getTruncateExpr(V, Ty);
 }
 
-const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
-                                                        const SCEV *RHS) {
-  const SCEV *PromotedLHS = LHS;
-  const SCEV *PromotedRHS = RHS;
+SCEVUse ScalarEvolution::getUMaxFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS) {
+  SCEVUse PromotedLHS = LHS;
+  SCEVUse PromotedRHS = RHS;
 
   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
@@ -4787,15 +4896,20 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
   return getUMaxExpr(PromotedLHS, PromotedRHS);
 }
 
-const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
-                                                        const SCEV *RHS,
-                                                        bool Sequential) {
-  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+SCEVUse ScalarEvolution::getUMinFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS,
+                                                    bool Sequential) {
+  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
   return getUMinFromMismatchedTypes(Ops, Sequential);
 }
 
-const SCEV *
-ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
+SCEVUse ScalarEvolution::getUMinFromMismatchedTypes(ArrayRef<SCEVUse> Ops,
+                                                    bool Sequential) {
+  SmallVector<SCEVUse> Ops2(Ops);
+  return getUMinFromMismatchedTypes(Ops2, Sequential);
+}
+
+SCEVUse
+ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<SCEVUse> &Ops,
                                             bool Sequential) {
   assert(!Ops.empty() && "At least one operand must be!");
   // Trivial case.
@@ -4804,7 +4918,7 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
 
   // Find the max type first.
   Type *MaxType = nullptr;
-  for (const auto *S : Ops)
+  for (const auto S : Ops)
    if (MaxType)
      MaxType = getWiderType(MaxType, S->getType());
    else
@@ -4812,15 +4926,15 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
   assert(MaxType && "Failed to find maximum type!");
 
   // Extend all ops to max type.
-  SmallVector<const SCEV *, 2> PromotedOps;
-  for (const auto *S : Ops)
+  SmallVector<SCEVUse, 2> PromotedOps;
+  for (const auto S : Ops)
    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
 
   // Generate umin.
   return getUMinExpr(PromotedOps, Sequential);
 }
 
-const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
+SCEVUse ScalarEvolution::getPointerBase(SCEVUse V) {
   // A pointer operand may evaluate to a nonpointer expression, such as null.
   if (!V->getType()->isPointerTy())
    return V;
@@ -4829,8 +4943,8 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
     if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
       V = AddRec->getStart();
     } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
-      const SCEV *PtrOp = nullptr;
-      for (const SCEV *AddOp : Add->operands()) {
+      SCEVUse PtrOp = nullptr;
+      for (SCEVUse AddOp : Add->operands()) {
        if (AddOp->getType()->isPointerTy()) {
          assert(!PtrOp && "Cannot have multiple pointer ops");
          PtrOp = AddOp;
@@ -4864,10 +4978,10 @@ namespace {
 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
-                             bool IgnoreOtherLoops = true) {
+  static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE,
+                         bool IgnoreOtherLoops = true) {
     SCEVInitRewriter Rewriter(L, SE);
-    const SCEV *Result = Rewriter.visit(S);
+    SCEVUse Result = Rewriter.visit(S);
     if (Rewriter.hasSeenLoopVariantSCEVUnknown())
       return SE.getCouldNotCompute();
     return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
@@ -4875,13 +4989,13 @@ class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
                : Result;
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
     if (!SE.isLoopInvariant(Expr, L))
       SeenLoopVariantSCEVUnknown = true;
     return Expr;
   }
 
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+  SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) {
     // Only re-write AddRecExprs for this loop.
     if (Expr->getLoop() == L)
       return Expr->getStart();
@@ -4908,21 +5022,21 @@ class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
+  static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) {
     SCEVPostIncRewriter Rewriter(L, SE);
-    const SCEV *Result = Rewriter.visit(S);
+    SCEVUse Result = Rewriter.visit(S);
     return Rewriter.hasSeenLoopVariantSCEVUnknown()
        ? SE.getCouldNotCompute()
        : Result;
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
     if (!SE.isLoopInvariant(Expr, L))
       SeenLoopVariantSCEVUnknown = true;
     return Expr;
   }
 
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+  SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) {
     // Only re-write AddRecExprs for this loop.
     if (Expr->getLoop() == L)
       return Expr->getPostIncExpr(SE);
@@ -4949,8 +5063,7 @@ class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
 class SCEVBackedgeConditionFolder
     : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
 public:
-  static const SCEV *rewrite(const SCEV *S, const Loop *L,
-                             ScalarEvolution &SE) {
+  static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) {
     bool IsPosBECond = false;
     Value *BECond = nullptr;
     if (BasicBlock *Latch = L->getLoopLatch()) {
@@ -4968,8 +5081,8 @@ class SCEVBackedgeConditionFolder
     return Rewriter.visit(S);
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-    const SCEV *Result = Expr;
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
+    SCEVUse Result = Expr;
     bool InvariantF = SE.isLoopInvariant(Expr, L);
 
     if (!InvariantF) {
@@ -4977,7 +5090,7 @@ class SCEVBackedgeConditionFolder
       switch (I->getOpcode()) {
       case Instruction::Select: {
         SelectInst *SI = cast<SelectInst>(I);
-        std::optional<const SCEV *> Res =
+        std::optional<SCEVUse> Res =
            compareWithBackedgeCondition(SI->getCondition());
        if (Res) {
          bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
@@ -4986,7 +5099,7 @@ class SCEVBackedgeConditionFolder
         break;
       }
       default: {
-        std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
+        std::optional<SCEVUse> Res = compareWithBackedgeCondition(I);
        if (Res)
          Result = *Res;
        break;
@@ -5002,7 +5115,7 @@ class SCEVBackedgeConditionFolder
       : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
         IsPositiveBECond(IsPosBECond) {}
 
-  std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
+  std::optional<SCEVUse> compareWithBackedgeCondition(Value *IC);
 
   const Loop *L;
   /// Loop back condition.
@@ -5011,7 +5124,7 @@ class SCEVBackedgeConditionFolder
   bool IsPositiveBECond;
 };
 
-std::optional<const SCEV *>
+std::optional<SCEVUse>
 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
 
   // If value matches the backedge condition for loop latch,
@@ -5025,21 +5138,20 @@ SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
 
 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *S, const Loop *L,
-                             ScalarEvolution &SE) {
+  static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) {
     SCEVShiftRewriter Rewriter(L, SE);
-    const SCEV *Result = Rewriter.visit(S);
+    SCEVUse Result = Rewriter.visit(S);
     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
     // Only allow AddRecExprs for this loop.
     if (!SE.isLoopInvariant(Expr, L))
       Valid = false;
     return Expr;
   }
 
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+  SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) {
     if (Expr->getLoop() == L && Expr->isAffine())
       return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
     Valid = false;
@@ -5068,7 +5180,7 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
 
   SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
   if (!AR->hasNoSelfWrap()) {
-    const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
+    SCEVUse BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
     if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
       ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
       const APInt &BECountAP = BECountMax->getAPInt();
@@ -5116,7 +5228,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   if (!SignedWrapViaInductionTried.insert(AR).second)
     return Result;
 
-  const SCEV *Step = AR->getStepRecurrence(*this);
+  SCEVUse Step = AR->getStepRecurrence(*this);
   const Loop *L = AR->getLoop();
 
   // Check whether the backedge-taken count is SCEVCouldNotCompute.
@@ -5127,7 +5239,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // in infinite recursion. In the latter case, the analysis code will
   // cope with a conservative value, and it will take care to purge
   // that value once it has finished.
-  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+  SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L);
 
   // Normally, in the cases we can prove no-overflow via a
   // backedge guarding condition, we can also compute a backedge
@@ -5146,8 +5258,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // start value and the backedge is guarded by a comparison with the post-inc
   // value, the addrec is safe.
   ICmpInst::Predicate Pred;
-  const SCEV *OverflowLimit =
-      getSignedOverflowLimitForStep(Step, &Pred, this);
+  SCEVUse OverflowLimit = getSignedOverflowLimitForStep(Step, &Pred, this);
   if (OverflowLimit &&
       (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
        isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
@@ -5169,7 +5280,7 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   if (!UnsignedWrapViaInductionTried.insert(AR).second)
     return Result;
 
-  const SCEV *Step = AR->getStepRecurrence(*this);
+  SCEVUse Step = AR->getStepRecurrence(*this);
   unsigned BitWidth = getTypeSizeInBits(AR->getType());
   const Loop *L = AR->getLoop();
 
@@ -5181,7 +5292,7 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // in infinite recursion. In the latter case, the analysis code will
   // cope with a conservative value, and it will take care to purge
   // that value once it has finished.
-  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+  SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L);
 
   // Normally, in the cases we can prove no-overflow via a
   // backedge guarding condition, we can also compute a backedge
@@ -5200,8 +5311,8 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // start value and the backedge is guarded by a comparison with the post-inc
   // value, the addrec is safe.
   if (isKnownPositive(Step)) {
-    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
-                                getUnsignedRangeMax(Step));
+    SCEVUse N =
+        getConstant(APInt::getMinValue(BitWidth) - getUnsignedRangeMax(Step));
     if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
         isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
       Result = setFlags(Result, SCEV::FlagNUW);
@@ -5350,7 +5461,7 @@ static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
 /// we return the type of the truncation operation, and indicate whether the
 /// truncated type should be treated as signed/unsigned by setting
 /// \p Signed to true/false, respectively.
-static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
+static Type *isSimpleCastedPHI(SCEVUse Op, const SCEVUnknown *SymbolicPHI,
                                bool &Signed, ScalarEvolution &SE) {
   // The case where Op == SymbolicPHI (that is, with no type conversions on
   // the way) is handled by the regular add recurrence creating logic and
@@ -5379,7 +5490,7 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
   if (!Trunc)
     return nullptr;
-  const SCEV *X = Trunc->getOperand();
+  SCEVUse X = Trunc->getOperand();
   if (X != SymbolicPHI)
     return nullptr;
   Signed = SExt != nullptr;
@@ -5448,8 +5559,9 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
 // which correspond to a phi->trunc->add->sext/zext->phi update chain.
 //
 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
-std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
-ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
+std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCastsImpl(
+    const SCEVUnknown *SymbolicPHI) {
   SmallVector<const SCEVPredicate *, 3> Predicates;
 
   // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
@@ -5482,7 +5594,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
   if (!BEValueV || !StartValueV)
     return std::nullopt;
 
-  const SCEV *BEValue = getSCEV(BEValueV);
+  SCEVUse BEValue = getSCEV(BEValueV);
 
   // If the value coming around the backedge is an add with the symbolic
   // value we just inserted, possibly with casts that we can ignore under
@@ -5508,11 +5620,11 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
     return std::nullopt;
 
   // Create an add with everything but the specified operand.
-  SmallVector<const SCEV *, 8> Ops;
+  SmallVector<SCEVUse, 8> Ops;
   for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
     if (i != FoundIndex)
       Ops.push_back(Add->getOperand(i));
-  const SCEV *Accum = getAddExpr(Ops);
+  SCEVUse Accum = getAddExpr(Ops);
 
   // The runtime checks will not be valid if the step amount is
   // varying inside the loop.
@@ -5570,8 +5682,8 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
   //
   // Create a truncated addrec for which we will add a no overflow check (P1).
- const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(getTruncateExpr(StartVal, TruncTy), getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); @@ -5598,11 +5710,10 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) // for each of StartVal and Accum - auto getExtendedExpr = [&](const SCEV *Expr, - bool CreateSignExtend) -> const SCEV * { + auto getExtendedExpr = [&](SCEVUse Expr, bool CreateSignExtend) -> SCEVUse { assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); - const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); - const SCEV *ExtendedExpr = + SCEVUse TruncatedExpr = getTruncateExpr(Expr, TruncTy); + SCEVUse ExtendedExpr = CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType()) : getZeroExtendExpr(TruncatedExpr, Expr->getType()); return ExtendedExpr; @@ -5613,13 +5724,12 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // = getExtendedExpr(Expr) // Determine whether the predicate P: Expr == ExtendedExpr // is known to be false at compile time - auto PredIsKnownFalse = [&](const SCEV *Expr, - const SCEV *ExtendedExpr) -> bool { + auto PredIsKnownFalse = [&](SCEVUse Expr, SCEVUse ExtendedExpr) -> bool { return Expr != ExtendedExpr && isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr); }; - const SCEV *StartExtended = getExtendedExpr(StartVal, Signed); + SCEVUse StartExtended = getExtendedExpr(StartVal, Signed); if (PredIsKnownFalse(StartVal, StartExtended)) { LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";); return std::nullopt; @@ -5627,14 +5737,13 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // The Step is always Signed (because the overflow checks are either // NSSW or NUSW) - const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); + SCEVUse AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); if (PredIsKnownFalse(Accum, AccumExtended)) { LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";); return std::nullopt; } - auto AppendPredicate = [&](const SCEV *Expr, - const SCEV *ExtendedExpr) -> void { + auto AppendPredicate = [&](SCEVUse Expr, SCEVUse ExtendedExpr) -> void { if (Expr != ExtendedExpr && !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) { const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr); @@ -5650,16 +5759,16 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // which the casts had been folded away. The caller can rewrite SymbolicPHI // into NewAR if it will also add the runtime overflow checks specified in // Predicates. - auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap); + auto NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap); - std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite = + std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>> PredRewrite = std::make_pair(NewAR, Predicates); // Remember the result of the analysis for this SCEV at this location. PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite; return PredRewrite; } -std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> +std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>> ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { auto *PN = cast<PHINode>(SymbolicPHI->getValue()); const Loop *L = isIntegerLoopHeaderPHI(PN, LI); @@ -5669,7 +5778,7 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { // Check to see if we already analyzed this PHI.
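The P2/P3 predicates constructed above require the start and accumulator to survive a truncate-then-extend round trip. On constants, the property getExtendedExpr tests reduces to the following APInt check (a sketch; the sext case is analogous):

  #include <llvm/ADT/APInt.h>
  #include <cassert>
  using namespace llvm;

  // "Expr == zext(trunc(Expr))" holds exactly when the value fits the
  // narrow type unsigned -- the fact the runtime checks must guarantee.
  static bool zextRoundTrips(const APInt &V, unsigned NarrowBits) {
    return V.trunc(NarrowBits).zext(V.getBitWidth()) == V;
  }

  int main() {
    assert(zextRoundTrips(APInt(32, 200), 8));  // 200 fits in u8
    assert(!zextRoundTrips(APInt(32, 300), 8)); // 300 does not
    return 0;
  }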
auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L}); if (I != PredicatedSCEVRewrites.end()) { - std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite = + std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>> Rewrite = I->second; // Analysis was done before and failed to create an AddRec: if (Rewrite.first == SymbolicPHI) @@ -5681,8 +5790,8 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { return Rewrite; } - std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> - Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); + std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>> + Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); // Record in the cache that the analysis failed if (!Rewrite) { @@ -5705,7 +5814,7 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( if (AR1 == AR2) return true; - auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { + auto areExprsEqual = [&](SCEVUse Expr1, SCEVUse Expr2) -> bool { if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) return false; @@ -5724,9 +5833,8 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)). /// If it fails, createAddRecFromPHI will use a more general, but slow, /// technique for finding the AddRec expression. -const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, - Value *BEValueV, - Value *StartValueV) { +SCEVUse ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, Value *BEValueV, + Value *StartValueV) { const Loop *L = LI.getLoopFor(PN->getParent()); assert(L && L->getHeader() == PN->getParent()); assert(BEValueV && StartValueV); @@ -5738,7 +5846,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, if (BO->Opcode != Instruction::Add) return nullptr; - const SCEV *Accum = nullptr; + SCEVUse Accum = nullptr; if (BO->LHS == PN && L->isLoopInvariant(BO->RHS)) Accum = getSCEV(BO->RHS); else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS)) @@ -5753,8 +5861,8 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, if (BO->IsNSW) Flags = setFlags(Flags, SCEV::FlagNSW); - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); insertValueToMap(PN, PHISCEV); if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { @@ -5776,7 +5884,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, return PHISCEV; } -const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { +SCEVUse ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) return nullptr; @@ -5809,16 +5917,16 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // First, try to find AddRec expression without creating a fictitious symbolic // value for PN. - if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) + if (auto S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) return S; // Handle PHI node value symbolically. - const SCEV *SymbolicName = getUnknown(PN); + SCEVUse SymbolicName = getUnknown(PN); insertValueToMap(PN, SymbolicName); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. - const SCEV *BEValue = getSCEV(BEValueV); + SCEVUse BEValue = getSCEV(BEValueV); // NOTE: If BEValue is loop invariant, we know that the PHI node just // has a special value for the first iteration of the loop.
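For reference, the {Start,+,Step}<L> addrec that createSimpleAffineAddRec produces for PN = PHI(Start, PN + Step) takes the value Start + k*Step (in the modular arithmetic SCEV uses) on iteration k. A toy evaluator mirroring the affine case of evaluateAtIteration, illustrative only:

  #include <cassert>
  #include <cstdint>

  // Affine addrec {Start,+,Step} evaluated at iteration K, modulo 2^8.
  static uint8_t evalAffine(uint8_t Start, uint8_t Step, uint8_t K) {
    return uint8_t(Start + K * Step);
  }

  int main() {
    // i = PHI(5, i + 3) is {5,+,3}: value 5 + 3k on iteration k.
    assert(evalAffine(5, 3, 0) == 5);
    assert(evalAffine(5, 3, 4) == 17);
    return 0;
  }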
@@ -5838,12 +5946,12 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { if (FoundIndex != Add->getNumOperands()) { // Create an add with everything but the specified operand. - SmallVector Ops; + SmallVector Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i), L, *this)); - const SCEV *Accum = getAddExpr(Ops); + SCEVUse Accum = getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. @@ -5879,8 +5987,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // for instance. } - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the @@ -5976,7 +6084,7 @@ static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, return false; } -const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { +SCEVUse ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { auto IsReachable = [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { @@ -6044,7 +6152,7 @@ ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) { return SCEVExprsIdentical ? CommonSCEV : nullptr; } -const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { +SCEVUse ScalarEvolution::createNodeForPHI(PHINode *PN) { if (const SCEV *S = createAddRecFromPHI(PN)) return S; @@ -6065,10 +6173,10 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { return getUnknown(PN); } -bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, +bool SCEVMinMaxExprContains(SCEVUse Root, SCEVUse OperandToFind, SCEVTypes RootKind) { struct FindClosure { - const SCEV *OperandToFind; + SCEVUse OperandToFind; const SCEVTypes RootKind; // Must be a sequential min/max expression. const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind. @@ -6081,13 +6189,13 @@ bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, scZeroExtend == Kind; }; - FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind) + FindClosure(SCEVUse OperandToFind, SCEVTypes RootKind) : OperandToFind(OperandToFind), RootKind(RootKind), NonSequentialRootKind( SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( RootKind)) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { Found = S == OperandToFind; return !isDone() && canRecurseInto(S->getSCEVType()); @@ -6101,7 +6209,7 @@ bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, return FC.Found; } -std::optional +std::optional ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, ICmpInst *Cond, Value *TrueVal, @@ -6127,10 +6235,10 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // a > b ? 
b+x : a+x -> min(a, b)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) { bool Signed = ICI->isSigned(); - const SCEV *LA = getSCEV(TrueVal); - const SCEV *RA = getSCEV(FalseVal); - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + SCEVUse LA = getSCEV(TrueVal); + SCEVUse RA = getSCEV(FalseVal); + SCEVUse LS = getSCEV(LHS); + SCEVUse RS = getSCEV(RHS); if (LA->getType()->isPointerTy()) { // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA. // Need to make sure we can't produce weird expressions involving @@ -6140,7 +6248,7 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, if (LA == RS && RA == LS) return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS); } - auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * { + auto CoerceOperand = [&](SCEVUse Op) -> SCEVUse { if (Op->getType()->isPointerTy()) { Op = getLosslessPtrToIntExpr(Op); if (isa(Op)) @@ -6156,8 +6264,8 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, RS = CoerceOperand(RS); if (isa(LS) || isa(RS)) break; - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, RS); + SCEVUse LDiff = getMinusSCEV(LA, LS); + SCEVUse RDiff = getMinusSCEV(RA, RS); if (LDiff == RDiff) return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS), LDiff); @@ -6176,11 +6284,11 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) && isa(RHS) && cast(RHS)->isZero()) { - const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty); - const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y - const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y - const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x - const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y + SCEVUse X = getNoopOrZeroExtend(getSCEV(LHS), Ty); + SCEVUse TrueValExpr = getSCEV(TrueVal); // C+y + SCEVUse FalseValExpr = getSCEV(FalseVal); // x+y + SCEVUse Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x + SCEVUse C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y if (isa(C) && cast(C)->getAPInt().ule(1)) return getAddExpr(getUMaxExpr(X, C), Y); } @@ -6190,11 +6298,11 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // -> umin_seq(x, umin (..., umin_seq(...), ...)) if (isa(RHS) && cast(RHS)->isZero() && isa(TrueVal) && cast(TrueVal)->isZero()) { - const SCEV *X = getSCEV(LHS); + SCEVUse X = getSCEV(LHS); while (auto *ZExt = dyn_cast(X)) X = ZExt->getOperand(); if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) { - const SCEV *FalseValExpr = getSCEV(FalseVal); + SCEVUse FalseValExpr = getSCEV(FalseVal); if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr)) return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr, /*Sequential=*/true); @@ -6208,9 +6316,10 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, return std::nullopt; } -static std::optional -createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, - const SCEV *TrueExpr, const SCEV *FalseExpr) { +static std::optional createNodeForSelectViaUMinSeq(ScalarEvolution *SE, + SCEVUse CondExpr, + SCEVUse TrueExpr, + SCEVUse FalseExpr) { assert(CondExpr->getType()->isIntegerTy(1) && TrueExpr->getType() == FalseExpr->getType() && TrueExpr->getType()->isIntegerTy(1) && @@ -6228,7 +6337,7 @@ createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, if (!isa(TrueExpr) && 
!isa(FalseExpr)) return std::nullopt; - const SCEV *X, *C; + SCEVUse X, C; if (isa(TrueExpr)) { CondExpr = SE->getNotSCEV(CondExpr); X = FalseExpr; @@ -6241,20 +6350,23 @@ createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, /*Sequential=*/true)); } -static std::optional -createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal, - Value *FalseVal) { +static std::optional createNodeForSelectViaUMinSeq(ScalarEvolution *SE, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { if (!isa(TrueVal) && !isa(FalseVal)) return std::nullopt; - const auto *SECond = SE->getSCEV(Cond); - const auto *SETrue = SE->getSCEV(TrueVal); - const auto *SEFalse = SE->getSCEV(FalseVal); + const auto SECond = SE->getSCEV(Cond); + const auto SETrue = SE->getSCEV(TrueVal); + const auto SEFalse = SE->getSCEV(FalseVal); return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse); } -const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( - Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) { +SCEVUse ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(Value *V, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?"); assert(TrueVal->getType() == FalseVal->getType() && V->getType() == TrueVal->getType() && @@ -6264,16 +6376,16 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( if (!V->getType()->isIntegerTy(1)) return getUnknown(V); - if (std::optional S = + if (std::optional S = createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal)) return *S; return getUnknown(V); } -const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, - Value *TrueVal, - Value *FalseVal) { +SCEVUse ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, + Value *TrueVal, + Value *FalseVal) { // Handle "constant" branch or select. This can occur for instance when a // loop pass transforms an inner loop and moves on to process the outer loop. if (auto *CI = dyn_cast(Cond)) @@ -6281,7 +6393,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, if (auto *I = dyn_cast(V)) { if (auto *ICI = dyn_cast(Cond)) { - if (std::optional S = + if (std::optional S = createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI, TrueVal, FalseVal)) return *S; @@ -6293,17 +6405,17 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, /// Expand GEP instructions into add and multiply operations. This allows them /// to be analyzed by regular SCEV code. -const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { +SCEVUse ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { assert(GEP->getSourceElementType()->isSized() && "GEP source element type must be sized"); - SmallVector IndexExprs; + SmallVector IndexExprs; for (Value *Index : GEP->indices()) IndexExprs.push_back(getSCEV(Index)); return getGEPExpr(GEP, IndexExprs); } -APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { +APInt ScalarEvolution::getConstantMultipleImpl(SCEVUse S) { uint64_t BitWidth = getTypeSizeInBits(S->getType()); auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) { return TrailingZeros >= BitWidth @@ -6348,7 +6460,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { if (M->hasNoUnsignedWrap()) { // The result is the product of all operand results. 
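The select folds in the hunks above are plain integer identities, checkable exhaustively on narrow types. The sketch below verifies the max-plus-offset rewrite (the LDiff == RDiff case), the x == 0 ? C+y : x+y fold with its C u<= 1 side condition, and the i1 select/umin_seq correspondence, where the sequential form preserves short-circuiting so the second operand cannot inject poison once the first decides the result (illustrative, not from the patch):

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  int main() {
    // a > b ? a+x : b+x == max(a, b) + x (smax or umax per the compare).
    for (int a = -4; a <= 4; ++a)
      for (int b = -4; b <= 4; ++b)
        assert((a > b ? a + 7 : b + 7) == std::max(a, b) + 7);

    // x == 0 ? C+y : x+y == umax(x, C) + y, provided C u<= 1.
    for (uint32_t C = 0; C <= 1; ++C)
      for (uint32_t x = 0; x < 256; ++x)
        assert(uint8_t((x == 0 ? C : x) + 42) ==
               uint8_t(std::max(x, C) + 42));

    // i1 selects as (sequential) umin: select(C, X, false) == C & X.
    for (int C = 0; C <= 1; ++C)
      for (int X = 0; X <= 1; ++X) {
        assert((C ? X : 0) == (C & X));
        assert((C ? 0 : X) == (!C & X));
      }
    return 0;
  }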
APInt Res = getConstantMultiple(M->getOperand(0)); - for (const SCEV *Operand : M->operands().drop_front()) + for (SCEVUse Operand : M->operands().drop_front()) Res = Res * getConstantMultiple(Operand); return Res; } @@ -6356,7 +6468,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { // If there are no wrap guarantees, find the trailing zeros, which is the // sum of trailing zeros for all its operands. uint32_t TZ = 0; - for (const SCEV *Operand : M->operands()) + for (SCEVUse Operand : M->operands()) TZ += getMinTrailingZeros(Operand); return GetShiftedByZeros(TZ); } @@ -6367,7 +6479,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { return GetGCDMultiple(N); // Find the trailing zeros, which is the minimum of its operands. uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); - for (const SCEV *Operand : N->operands().drop_front()) + for (SCEVUse Operand : N->operands().drop_front()) TZ = std::min(TZ, getMinTrailingZeros(Operand)); return GetShiftedByZeros(TZ); } @@ -6391,7 +6503,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { llvm_unreachable("Unknown SCEV kind!"); } -APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { +APInt ScalarEvolution::getConstantMultiple(SCEVUse S) { auto I = ConstantMultipleCache.find(S); if (I != ConstantMultipleCache.end()) return I->second; @@ -6402,12 +6514,12 @@ APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { return InsertPair.first->second; } -APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) { +APInt ScalarEvolution::getNonZeroConstantMultiple(SCEVUse S) { APInt Multiple = getConstantMultiple(S); return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple; } -uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { +uint32_t ScalarEvolution::getMinTrailingZeros(SCEVUse S) { return std::min(getConstantMultiple(S).countTrailingZeros(), (unsigned)getTypeSizeInBits(S->getType())); } @@ -6558,17 +6670,17 @@ getRangeForUnknownRecurrence(const SCEVUnknown *U) { } const ConstantRange & -ScalarEvolution::getRangeRefIter(const SCEV *S, +ScalarEvolution::getRangeRefIter(SCEVUse S, ScalarEvolution::RangeSignHint SignHint) { - DenseMap<const SCEV *, ConstantRange> &Cache = + DenseMap<SCEVUse, ConstantRange> &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; - SmallVector<const SCEV *> WorkList; - SmallPtrSet<const SCEV *, 8> Seen; + SmallVector<SCEVUse> WorkList; + SmallPtrSet<SCEVUse, 8> Seen; // Add Expr to the worklist, if Expr is either an N-ary expression or a // SCEVUnknown PHI node. - auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) { + auto AddToWorklist = [&WorkList, &Seen, &Cache](SCEVUse Expr) { if (!Seen.insert(Expr).second) return; if (Cache.contains(Expr)) @@ -6603,11 +6715,11 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, // Build worklist by queuing operands of N-ary expressions and phi nodes. for (unsigned I = 0; I != WorkList.size(); ++I) { - const SCEV *P = WorkList[I]; + SCEVUse P = WorkList[I]; auto *UnknownS = dyn_cast<SCEVUnknown>(P); // If it is not a `SCEVUnknown`, just recurse into operands. if (!UnknownS) { - for (const SCEV *Op : P->operands()) + for (SCEVUse Op : P->operands()) AddToWorklist(Op); continue; } @@ -6624,7 +6736,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, // Use getRangeRef to compute ranges for items in the worklist in reverse // order. This will force ranges for earlier operands to be computed before // their users in most cases.
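The Mul fallback above (no NUW) sums the operands' trailing zeros, which is sound because a product of values divisible by 2^i and 2^j is divisible by 2^(i+j). A quick standalone check using the same countTrailingZeros the hunk relies on (illustrative):

  #include <llvm/ADT/APInt.h>
  #include <cassert>
  using namespace llvm;

  int main() {
    APInt A(32, 0b101000), B(32, 0b1100); // 3 and 2 trailing zeros
    APInt P = A * B;                      // 40 * 12 = 480 = 0b111100000
    assert(P.countTrailingZeros() >=
           A.countTrailingZeros() + B.countTrailingZeros());
    return 0;
  }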
- for (const SCEV *P : reverse(drop_begin(WorkList))) { + for (SCEVUse P : reverse(drop_begin(WorkList))) { getRangeRef(P, SignHint); if (auto *UnknownS = dyn_cast<SCEVUnknown>(P)) @@ -6639,9 +6751,10 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. -const ConstantRange &ScalarEvolution::getRangeRef( - const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) { - DenseMap<const SCEV *, ConstantRange> &Cache = +const ConstantRange & +ScalarEvolution::getRangeRef(SCEVUse S, ScalarEvolution::RangeSignHint SignHint, + unsigned Depth) { + DenseMap<SCEVUse, ConstantRange> &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; ConstantRange::PreferredRangeType RangeType = @@ -6649,7 +6762,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( : ConstantRange::Signed; // See if we've computed this range already. - DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); + DenseMap<SCEVUse, ConstantRange>::iterator I = Cache.find(S); if (I != Cache.end()) return I->second; @@ -6784,8 +6897,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( // TODO: non-affine addrec if (AddRec->isAffine()) { - const SCEV *MaxBEScev = - getConstantMaxBackedgeTakenCount(AddRec->getLoop()); + SCEVUse MaxBEScev = getConstantMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(MaxBEScev)) { APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt(); @@ -6812,7 +6924,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( // Now try symbolic BE count and more powerful methods. if (UseExpensiveRangeSharpening) { - const SCEV *SymbolicMaxBECount = + SCEVUse SymbolicMaxBECount = getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth && @@ -7041,8 +7153,7 @@ static ConstantRange getRangeForAffineARHelper(APInt Step, return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper)); } -ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, - const SCEV *Step, +ConstantRange ScalarEvolution::getRangeForAffineAR(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount) { assert(getTypeSizeInBits(Start->getType()) == getTypeSizeInBits(Step->getType()) && @@ -7071,13 +7182,13 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, } ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( - const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth, + const SCEVAddRecExpr *AddRec, SCEVUse MaxBECount, unsigned BitWidth, ScalarEvolution::RangeSignHint SignHint) { assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n"); assert(AddRec->hasNoSelfWrap() && "This only works for non-self-wrapping AddRecs!"); const bool IsSigned = SignHint == HINT_RANGE_SIGNED; - const SCEV *Step = AddRec->getStepRecurrence(*this); + SCEVUse Step = AddRec->getStepRecurrence(*this); // Only deal with constant step to save compile time.
if (!isa(Step)) return ConstantRange::getFull(BitWidth); @@ -7090,9 +7201,9 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( getTypeSizeInBits(AddRec->getType())) return ConstantRange::getFull(BitWidth); MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType()); - const SCEV *RangeWidth = getMinusOne(AddRec->getType()); - const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); - const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); + SCEVUse RangeWidth = getMinusOne(AddRec->getType()); + SCEVUse StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); + SCEVUse MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount, MaxItersWithoutWrap)) return ConstantRange::getFull(BitWidth); @@ -7101,7 +7212,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; ICmpInst::Predicate GEPred = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); + SCEVUse End = AddRec->evaluateAtIteration(MaxBECount, *this); // We know that there is no self-wrap. Let's take Start and End values and // look at all intermediate values V1, V2, ..., Vn that IndVar takes during @@ -7115,7 +7226,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that // knowledge, let's try to prove that we are dealing with Case 1. It is so if // Start <= End and step is positive, or Start >= End and step is negative. - const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); + SCEVUse Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); ConstantRange StartRange = getRangeRef(Start, SignHint); ConstantRange EndRange = getRangeRef(End, SignHint); ConstantRange RangeBetween = StartRange.unionWith(EndRange); @@ -7138,8 +7249,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( return ConstantRange::getFull(BitWidth); } -ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, - const SCEV *Step, +ConstantRange ScalarEvolution::getRangeViaFactoring(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount) { // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) @@ -7154,8 +7264,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, APInt TrueValue; APInt FalseValue; - explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, - const SCEV *S) { + explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, SCEVUse S) { std::optional CastOp; APInt Offset(BitWidth, 0); @@ -7244,10 +7353,10 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, // FIXME: without the explicit `this` receiver below, MSVC errors out with // C2352 and C2512 (otherwise it isn't needed). 
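The no-self-wrap sharpening above compares the max backedge-taken count against MaxItersWithoutWrap = (-1) /u |Step|: within that many iterations the recurrence cannot cover the full 2^BW range, so every intermediate value stays bracketed by Start and End. A numeric sanity check of the bound (illustrative only):

  #include <cassert>
  #include <cstdint>

  int main() {
    const uint32_t RangeWidth = 255; // i8 "-1", i.e. the full range of i8
    const uint32_t StepAbs = 10;
    const uint32_t MaxIters = RangeWidth / StepAbs; // udiv: 25
    // Total distance travelled in MaxIters steps stays below 2^8:
    assert(MaxIters * StepAbs <= RangeWidth); // 250 <= 255
    return 0;
  }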
- const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); - const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); - const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); - const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + SCEVUse TrueStart = this->getConstant(StartPattern.TrueValue); + SCEVUse TrueStep = this->getConstant(StepPattern.TrueValue); + SCEVUse FalseStart = this->getConstant(StartPattern.FalseValue); + SCEVUse FalseStep = this->getConstant(StepPattern.FalseValue); ConstantRange TrueRange = this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount); @@ -7273,8 +7382,7 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; } -const Instruction * -ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) { +const Instruction *ScalarEvolution::getNonTrivialDefiningScopeBound(SCEVUse S) { if (auto *AddRec = dyn_cast(S)) return &*AddRec->getLoop()->getHeader()->begin(); if (auto *U = dyn_cast(S)) @@ -7283,14 +7391,13 @@ ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) { return nullptr; } -const Instruction * -ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, - bool &Precise) { +const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, + bool &Precise) { Precise = true; // Do a bounded search of the def relation of the requested SCEVs. - SmallSet Visited; - SmallVector Worklist; - auto pushOp = [&](const SCEV *S) { + SmallSet Visited; + SmallVector Worklist; + auto pushOp = [&](SCEVUse S) { if (!Visited.insert(S).second) return; // Threshold of 30 here is arbitrary. @@ -7301,17 +7408,17 @@ ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, Worklist.push_back(S); }; - for (const auto *S : Ops) + for (const auto S : Ops) pushOp(S); const Instruction *Bound = nullptr; while (!Worklist.empty()) { - auto *S = Worklist.pop_back_val(); + auto S = Worklist.pop_back_val(); if (auto *DefI = getNonTrivialDefiningScopeBound(S)) { if (!Bound || DT.dominates(Bound, DefI)) Bound = DefI; } else { - for (const auto *Op : S->operands()) + for (const auto Op : S->operands()) pushOp(Op); } } @@ -7319,7 +7426,7 @@ ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, } const Instruction * -ScalarEvolution::getDefiningScopeBound(ArrayRef Ops) { +ScalarEvolution::getDefiningScopeBound(ArrayRef Ops) { bool Discard; return getDefiningScopeBound(Ops, Discard); } @@ -7374,7 +7481,7 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { // executed every time we enter that scope. When the bounding scope is a // loop (the common case), this is equivalent to proving I executes on every // iteration of that loop. - SmallVector SCEVOps; + SmallVector SCEVOps; for (const Use &Op : I->operands()) { // I could be an extractvalue from a call to an overflow intrinsic. // TODO: We can do better here in some cases. @@ -7470,7 +7577,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } -const SCEV *ScalarEvolution::createSCEVIter(Value *V) { +SCEVUse ScalarEvolution::createSCEVIter(Value *V) { // Worklist item with a Value and a bool indicating whether all operands have // been visited already. 
using PointerTy = PointerIntPair; @@ -7486,7 +7593,7 @@ const SCEV *ScalarEvolution::createSCEVIter(Value *V) { continue; SmallVector Ops; - const SCEV *CreatedSCEV = nullptr; + SCEVUse CreatedSCEV = nullptr; // If all operands have been visited already, create the SCEV. if (E.getInt()) { CreatedSCEV = createSCEV(CurV); @@ -7511,8 +7618,8 @@ const SCEV *ScalarEvolution::createSCEVIter(Value *V) { return getExistingSCEV(V); } -const SCEV * -ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getOperandsToCreate(Value *V, + SmallVectorImpl &Ops) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -7698,7 +7805,7 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl &Ops) { return nullptr; } -const SCEV *ScalarEvolution::createSCEV(Value *V) { +SCEVUse ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -7716,8 +7823,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { else if (!isa(V)) return getUnknown(V); - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; Operator *U = cast(V); if (auto BO = @@ -7730,10 +7837,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // because it leads to N-1 getAddExpr calls for N ultimate operands. // Instead, gather up all the operands and make a single getAddExpr call. // LLVM IR canonical form means we need only traverse the left operands. - SmallVector AddOps; + SmallVector AddOps; do { if (BO->Op) { - if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + if (auto OpSCEV = getExistingSCEV(BO->Op)) { AddOps.push_back(OpSCEV); break; } @@ -7745,10 +7852,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // since the flags are only known to apply to this particular // addition - they may not apply to other additions that can be // formed with operands from AddOps. - const SCEV *RHS = getSCEV(BO->RHS); + SCEVUse RHS = getSCEV(BO->RHS); SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); if (Flags != SCEV::FlagAnyWrap) { - const SCEV *LHS = getSCEV(BO->LHS); + SCEVUse LHS = getSCEV(BO->LHS); if (BO->Opcode == Instruction::Sub) AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); else @@ -7776,10 +7883,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } case Instruction::Mul: { - SmallVector MulOps; + SmallVector MulOps; do { if (BO->Op) { - if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + if (auto OpSCEV = getExistingSCEV(BO->Op)) { MulOps.push_back(OpSCEV); break; } @@ -7845,19 +7952,19 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { - const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); - const SCEV *LHS = getSCEV(BO->LHS); - const SCEV *ShiftedLHS = nullptr; + SCEVUse MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + SCEVUse LHS = getSCEV(BO->LHS); + SCEVUse ShiftedLHS = nullptr; if (auto *LHSMul = dyn_cast(LHS)) { if (auto *OpC = dyn_cast(LHSMul->getOperand(0))) { // For an expression like (x * 8) & 8, simplify the multiply. 
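The Add gathering above leans on LLVM IR canonicalization: chained adds nest on the left, so ((a + b) + c) + d can be flattened by walking only left operands and collecting each right-hand side, then issuing a single getAddExpr instead of N-1 nested ones. A toy version of that walk, where the Node type is a hypothetical stand-in for the IR (sketch only; the real code also folds Sub and no-wrap flags into AddOps):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  struct Node {                 // hypothetical stand-in for chained IR adds
    bool IsAdd = false;
    int64_t Leaf = 0;
    const Node *L = nullptr, *R = nullptr;
  };

  static int64_t sumLeftSpine(const Node *N) {
    std::vector<int64_t> Ops;   // plays the role of AddOps
    while (N->IsAdd) {          // descend the left spine only
      Ops.push_back(N->R->Leaf);
      N = N->L;
    }
    Ops.push_back(N->Leaf);
    int64_t S = 0;
    for (int64_t V : Ops)
      S += V;                   // one combining step, like a single getAddExpr
    return S;
  }

  int main() {
    Node A{false, 1}, B{false, 2}, C{false, 3};
    Node AB{true, 0, &A, &B}, ABC{true, 0, &AB, &C}; // (1 + 2) + 3
    assert(sumLeftSpine(&ABC) == 6);
    return 0;
  }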
unsigned MulZeros = OpC->getAPInt().countr_zero(); unsigned GCD = std::min(MulZeros, TZ); APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); - SmallVector MulOps; + SmallVector MulOps; MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); append_range(MulOps, LHSMul->operands().drop_front()); - auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + auto NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); } } @@ -7905,7 +8012,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (const SCEVZeroExtendExpr *Z = dyn_cast(getSCEV(BO->LHS))) { Type *UTy = BO->LHS->getType(); - const SCEV *Z0 = Z->getOperand(); + SCEVUse Z0 = Z->getOperand(); Type *Z0Ty = Z0->getType(); unsigned Z0TySize = getTypeSizeInBits(Z0Ty); @@ -7981,9 +8088,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); Operator *L = dyn_cast(BO->LHS); - const SCEV *AddTruncateExpr = nullptr; + SCEVUse AddTruncateExpr = nullptr; ConstantInt *ShlAmtCI = nullptr; - const SCEV *AddConstant = nullptr; + SCEVUse AddConstant = nullptr; if (L && L->getOpcode() == Instruction::Add) { // X = Shl A, n @@ -7995,7 +8102,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { ConstantInt *AddOperandCI = dyn_cast(L->getOperand(1)); if (LShift && LShift->getOpcode() == Instruction::Shl) { if (AddOperandCI) { - const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0)); + SCEVUse ShlOp0SCEV = getSCEV(LShift->getOperand(0)); ShlAmtCI = dyn_cast(LShift->getOperand(1)); // since we truncate to TruncTy, the AddConstant should be of the // same type, so create a new Constant with type same as TruncTy. @@ -8013,7 +8120,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Y = AShr X, m // Both n and m are constant. - const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + SCEVUse ShlOp0SCEV = getSCEV(L->getOperand(0)); ShlAmtCI = dyn_cast(L->getOperand(1)); AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy); } @@ -8034,8 +8141,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) { APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, ShlAmtCI->getZExtValue() - AShrAmt); - const SCEV *CompositeExpr = - getMulExpr(AddTruncateExpr, getConstant(Mul)); + SCEVUse CompositeExpr = getMulExpr(AddTruncateExpr, getConstant(Mul)); if (L->getOpcode() != Instruction::Shl) CompositeExpr = getAddExpr(CompositeExpr, AddConstant); @@ -8065,8 +8171,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // but by that point the NSW information has potentially been lost. if (BO->Opcode == Instruction::Sub && BO->IsNSW) { Type *Ty = U->getType(); - auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); - auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); + auto V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); + auto V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); return getMinusSCEV(V1, V2, SCEV::FlagNSW); } } @@ -8080,11 +8186,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::PtrToInt: { // Pointer to integer cast is straight-forward, so do model it. - const SCEV *Op = getSCEV(U->getOperand(0)); + SCEVUse Op = getSCEV(U->getOperand(0)); Type *DstIntTy = U->getType(); // But only if effective SCEV (integer) type is wide enough to represent // all possible pointer values. 
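The And handling above rewrites a mask that clears low bits into a udiv/mul pair. Stripped of the truncate/zext plumbing and the (x * 8) & 8 multiply simplification, the core identity it relies on is the following (illustrative, checked exhaustively for a 3-bit mask):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (uint32_t x = 0; x < 1024; ++x)
      assert((x & ~7u) == (x / 8) * 8); // clear 3 low bits == udiv, then mul
    return 0;
  }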
- const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy); + SCEVUse IntOp = getPtrToIntExpr(Op, DstIntTy); if (isa(IntOp)) return getUnknown(V); return IntOp; @@ -8145,15 +8251,15 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { RHS = getSCEV(II->getArgOperand(1)); return getSMinExpr(LHS, RHS); case Intrinsic::usub_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedY = getUMinExpr(X, Y); + SCEVUse X = getSCEV(II->getArgOperand(0)); + SCEVUse Y = getSCEV(II->getArgOperand(1)); + SCEVUse ClampedY = getUMinExpr(X, Y); return getMinusSCEV(X, ClampedY, SCEV::FlagNUW); } case Intrinsic::uadd_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y)); + SCEVUse X = getSCEV(II->getArgOperand(0)); + SCEVUse Y = getSCEV(II->getArgOperand(1)); + SCEVUse ClampedX = getUMinExpr(X, getNotSCEV(Y)); return getAddExpr(ClampedX, Y, SCEV::FlagNUW); } case Intrinsic::start_loop_iterations: @@ -8178,7 +8284,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { +SCEVUse ScalarEvolution::getTripCountFromExitCount(SCEVUse ExitCount) { if (isa(ExitCount)) return getCouldNotCompute(); @@ -8189,9 +8295,9 @@ const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { return getTripCountFromExitCount(ExitCount, EvalTy, nullptr); } -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, - Type *EvalTy, - const Loop *L) { +SCEVUse ScalarEvolution::getTripCountFromExitCount(SCEVUse ExitCount, + Type *EvalTy, + const Loop *L) { if (isa(ExitCount)) return getCouldNotCompute(); @@ -8252,7 +8358,7 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L, unsigned ScalarEvolution::getSmallConstantMaxTripCount( const Loop *L, SmallVectorImpl *Predicates) { - const auto *MaxExitCount = + SCEVUse MaxExitCount = Predicates ? 
getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates) : getConstantMaxBackedgeTakenCount(L); return getConstantTripCount(dyn_cast(MaxExitCount)); @@ -8273,12 +8379,12 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { } unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, - const SCEV *ExitCount) { + SCEVUse ExitCount) { if (ExitCount == getCouldNotCompute()) return 1; // Get the trip count - const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); + SCEVUse TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); APInt Multiple = getNonZeroConstantMultiple(TCExpr); // If a trip multiple is huge (>=2^32), the trip count is still divisible by @@ -8306,13 +8412,13 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); - const SCEV *ExitCount = getExitCount(L, ExitingBlock); + SCEVUse ExitCount = getExitCount(L, ExitingBlock); return getSmallConstantTripMultiple(L, ExitCount); } -const SCEV *ScalarEvolution::getExitCount(const Loop *L, - const BasicBlock *ExitingBlock, - ExitCountKind Kind) { +SCEVUse ScalarEvolution::getExitCount(const Loop *L, + const BasicBlock *ExitingBlock, + ExitCountKind Kind) { switch (Kind) { case Exact: return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); @@ -8324,7 +8430,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L, llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedExitCount( +SCEVUse ScalarEvolution::getPredicatedExitCount( const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl *Predicates, ExitCountKind Kind) { switch (Kind) { @@ -8341,13 +8447,13 @@ const SCEV *ScalarEvolution::getPredicatedExitCount( llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); } -const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, - ExitCountKind Kind) { +SCEVUse ScalarEvolution::getBackedgeTakenCount(const Loop *L, + ExitCountKind Kind) { switch (Kind) { case Exact: return getBackedgeTakenInfo(L).getExact(L, this); @@ -8359,12 +8465,12 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds); } -const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds); } @@ -8426,7 +8532,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // only done to produce more precise results. if (Result.hasAnyInfo()) { // Invalidate any expression using an addrec in this loop. 
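Two quick checks for the arithmetic in the hunks above (illustrative, not from the patch): the usub_sat/uadd_sat lowerings are clamp identities, exhaustively verifiable on u8, and getTripCountFromExitCount adds one in a wide-enough type precisely because an i8 exit count of 255 means 256 executions:

  #include <cassert>
  #include <cstdint>

  int main() {
    for (uint32_t X = 0; X < 256; ++X)
      for (uint32_t Y = 0; Y < 256; ++Y) {
        uint8_t x = X, y = Y;
        uint8_t USub = x > y ? x - y : 0;                  // usub_sat(x, y)
        uint8_t UAdd = X + Y > 255 ? 255 : uint8_t(X + Y); // uadd_sat(x, y)
        uint8_t MinXY = x < y ? x : y;                     // umin(x, y)
        assert(USub == uint8_t(x - MinXY));    // x - umin(x, y), NUW
        uint8_t NotY = ~y;
        uint8_t MinXNotY = x < NotY ? x : NotY;            // umin(x, ~y)
        assert(UAdd == uint8_t(MinXNotY + y)); // umin(x, ~y) + y, NUW
      }
    uint8_t ExitCount = 255;                // backedge taken 255 times
    assert(uint32_t(ExitCount) + 1 == 256); // widen before adding one
    return 0;
  }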
- SmallVector ToForget; + SmallVector ToForget; auto LoopUsersIt = LoopUsers.find(L); if (LoopUsersIt != LoopUsers.end()) append_range(ToForget, LoopUsersIt->second); @@ -8473,7 +8579,7 @@ void ScalarEvolution::forgetAllLoops() { void ScalarEvolution::visitAndClearUsers( SmallVectorImpl &Worklist, SmallPtrSetImpl &Visited, - SmallVectorImpl &ToForget) { + SmallVectorImpl &ToForget) { while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); if (!isSCEVable(I->getType()) && !isa(I)) @@ -8496,7 +8602,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallVector LoopWorklist(1, L); SmallVector Worklist; SmallPtrSet Visited; - SmallVector ToForget; + SmallVector ToForget; // Iterate over all the loops and sub-loops to drop SCEV information. while (!LoopWorklist.empty()) { @@ -8509,7 +8615,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { // Drop information about predicated SCEV rewrites for this loop. for (auto I = PredicatedSCEVRewrites.begin(); I != PredicatedSCEVRewrites.end();) { - std::pair Entry = I->first; + std::pair Entry = I->first; if (Entry.second == CurrL) PredicatedSCEVRewrites.erase(I++); else @@ -8545,7 +8651,7 @@ void ScalarEvolution::forgetValue(Value *V) { // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; SmallPtrSet Visited; - SmallVector ToForget; + SmallVector ToForget; Worklist.push_back(I); Visited.insert(I); visitAndClearUsers(Worklist, Visited, ToForget); @@ -8561,14 +8667,14 @@ void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) { // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an // extra predecessor is added, this is no longer valid. Find all Unknowns and // AddRecs defined in the loop and invalidate any SCEV's making use of them. - if (const SCEV *S = getExistingSCEV(V)) { + if (SCEVUse S = getExistingSCEV(V)) { struct InvalidationRootCollector { Loop *L; - SmallVector Roots; + SmallVector Roots; InvalidationRootCollector(Loop *L) : L(L) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { if (auto *SU = dyn_cast(S)) { if (auto *I = dyn_cast(SU->getValue())) if (L->contains(I)) @@ -8605,7 +8711,7 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { if (!isSCEVable(V->getType())) return; - const SCEV *S = getExistingSCEV(V); + SCEVUse S = getExistingSCEV(V); if (!S) return; @@ -8613,17 +8719,17 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { // S's users may change if S's disposition changes (i.e. a user may change to // loop-invariant, if S changes to loop invariant), so also invalidate // dispositions of S's users recursively. - SmallVector Worklist = {S}; - SmallPtrSet Seen = {S}; + SmallVector Worklist = {S}; + SmallPtrSet Seen = {S}; while (!Worklist.empty()) { - const SCEV *Curr = Worklist.pop_back_val(); + SCEVUse Curr = Worklist.pop_back_val(); bool LoopDispoRemoved = LoopDispositions.erase(Curr); bool BlockDispoRemoved = BlockDispositions.erase(Curr); if (!LoopDispoRemoved && !BlockDispoRemoved) continue; auto Users = SCEVUsers.find(Curr); if (Users != SCEVUsers.end()) - for (const auto *User : Users->second) + for (const auto User : Users->second) if (Seen.insert(User).second) Worklist.push_back(User); } @@ -8635,7 +8741,7 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { /// is never skipped. This is a valid assumption as long as the loop exits via /// that test. 
For precise results, it is the caller's responsibility to specify /// the relevant loop exiting block using getExact(ExitingBlock, SE). -const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getExact( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Preds) const { // If any exits were not computable, the loop is not computable. @@ -8649,9 +8755,9 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( // All exiting blocks we have gathered dominate loop's latch, so exact trip // count is simply a minimum out of all these calculated exit counts. - SmallVector Ops; + SmallVector Ops; for (const auto &ENT : ExitNotTaken) { - const SCEV *BECount = ENT.ExactNotTaken; + SCEVUse BECount = ENT.ExactNotTaken; assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!"); assert(SE->DT.dominates(ENT.ExitingBlock, Latch) && "We should only have known counts for exiting blocks that dominate " @@ -8690,7 +8796,7 @@ ScalarEvolution::BackedgeTakenInfo::getExitNotTaken( } /// getConstantMax - Get the constant max backedge taken count for the loop. -const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getConstantMax( ScalarEvolution *SE, SmallVectorImpl *Predicates) const { if (!getConstantMax()) @@ -8709,7 +8815,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( return getConstantMax(); } -const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Predicates) { if (!SymbolicMax) { @@ -8750,13 +8856,11 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); } -ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) - : ExitLimit(E, E, E, false) {} +ScalarEvolution::ExitLimit::ExitLimit(SCEVUse E) : ExitLimit(E, E, E, false) {} ScalarEvolution::ExitLimit::ExitLimit( - const SCEV *E, const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - ArrayRef> PredLists) + SCEVUse E, SCEVUse ConstantMaxNotTaken, SCEVUse SymbolicMaxNotTaken, + bool MaxOrZero, ArrayRef> PredLists) : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken), SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) { // If we prove the max count is zero, so is the symbolic bound. This happens @@ -8795,9 +8899,8 @@ ScalarEvolution::ExitLimit::ExitLimit( "Max backedge count should be int"); } -ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, - const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, +ScalarEvolution::ExitLimit::ExitLimit(SCEVUse E, SCEVUse ConstantMaxNotTaken, + SCEVUse SymbolicMaxNotTaken, bool MaxOrZero, ArrayRef PredList) : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero, @@ -8807,7 +8910,7 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( ArrayRef ExitCounts, - bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero) + bool IsComplete, SCEVUse ConstantMax, bool MaxOrZero) : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) { using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; @@ -8838,8 +8941,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, SmallVector ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. 
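As the getExact comment above notes, the exact loop-level count is a minimum over the computable per-exit counts: the loop leaves at whichever exiting condition trips first. Simulated directly on concrete bounds (illustrative):

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  int main() {
    const uint64_t A = 12, B = 9;
    uint64_t Taken = 0;
    for (uint64_t i = 0; i < A && i < B; ++i)
      ++Taken;                       // count iterations of a two-exit loop
    assert(Taken == std::min(A, B)); // 9: the umin of the per-exit counts
    return 0;
  }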
- const SCEV *MustExitMaxBECount = nullptr; - const SCEV *MayExitMaxBECount = nullptr; + SCEVUse MustExitMaxBECount = nullptr; + SCEVUse MayExitMaxBECount = nullptr; bool MustExitMaxOrZero = false; bool IsOnlyExit = ExitingBlocks.size() == 1; @@ -8909,8 +9012,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, } } } - const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : - (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); + SCEVUse MaxBECount = + MustExitMaxBECount + ? MustExitMaxBECount + : (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); // The loop backedge will be taken the maximum or zero times if there's // a single exit that must be taken the maximum or zero times. bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); @@ -9072,7 +9177,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( NWR.getEquivalentICmp(Pred, NewRHSC, Offset); if (!ExitIfTrue) Pred = ICmpInst::getInversePredicate(Pred); - auto *LHS = getSCEV(WO->getLHS()); + auto LHS = getSCEV(WO->getLHS()); if (Offset != 0) LHS = getAddExpr(LHS, getConstant(Offset)); auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), @@ -9117,9 +9222,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( if (isa(Op0)) return Op0 == NeutralElement ? EL1 : EL0; - const SCEV *BECount = getCouldNotCompute(); - const SCEV *ConstantMaxBECount = getCouldNotCompute(); - const SCEV *SymbolicMaxBECount = getCouldNotCompute(); + SCEVUse BECount = getCouldNotCompute(); + SCEVUse ConstantMaxBECount = getCouldNotCompute(); + SCEVUse SymbolicMaxBECount = getCouldNotCompute(); if (EitherMayExit) { bool UseSequentialUMin = !isa(ExitCond); // Both conditions must be same for the loop to continue executing. @@ -9177,16 +9282,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( Pred = ExitCond->getInversePredicate(); const ICmpInst::Predicate OriginalPred = Pred; - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); - const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); + SCEVUse LHS = getSCEV(ExitCond->getOperand(0)); + SCEVUse RHS = getSCEV(ExitCond->getOperand(1)); ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; - auto *ExhaustiveCount = - computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + auto ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); if (!isa(ExhaustiveCount)) return ExhaustiveCount; @@ -9195,7 +9299,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ExitCond->getOperand(1), L, OriginalPred); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( - const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const Loop *L, ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, bool ControlsOnlyExit, bool AllowPredicates) { // Try to evaluate any dependencies out of the loop. @@ -9224,7 +9328,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ConstantRange CompRange = ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); - const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); + SCEVUse Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; } @@ -9237,7 +9341,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( // because if it did, we'd have an infinite (undefined) loop. // TODO: We can peel off any functions which are invertible *in L*. 
Loop // invariant terms are effectively constants for our purposes here. - auto *InnerLHS = LHS; + auto InnerLHS = LHS; if (auto *ZExt = dyn_cast(LHS)) InnerLHS = ZExt->getOperand(); if (const SCEVAddRecExpr *AR = dyn_cast(InnerLHS); @@ -9246,7 +9350,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( /*OrNegative=*/true)) { auto Flags = AR->getNoWrapFlags(); Flags = setFlags(Flags, SCEV::FlagNW); - SmallVector Operands{AR->operands()}; + SmallVector Operands{AR->operands()}; Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); setNoWrapFlags(const_cast(AR), Flags); } @@ -9264,7 +9368,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( isKnownPositive(AR->getStepRecurrence(*this))) { auto Flags = AR->getNoWrapFlags(); Flags = setFlags(Flags, WrapType); - SmallVector Operands{AR->operands()}; + SmallVector Operands{AR->operands()}; Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); setNoWrapFlags(const_cast(AR), Flags); } @@ -9379,8 +9483,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, assert(L->contains(Switch->getDefaultDest()) && "Default case must not exit the loop!"); - const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); - const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); + SCEVUse LHS = getSCEVAtScope(Switch->getCondition(), L); + SCEVUse RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit); @@ -9393,8 +9497,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE) { - const SCEV *InVal = SE.getConstant(C); - const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); + SCEVUse InVal = SE.getConstant(C); + SCEVUse Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); @@ -9535,7 +9639,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( if (Result->isZeroValue()) { unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *UpperBound = + SCEVUse UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false); } @@ -9785,9 +9889,9 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, } } -const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, - Value *Cond, - bool ExitWhen) { +SCEVUse ScalarEvolution::computeExitCountExhaustively(const Loop *L, + Value *Cond, + bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); if (!PN) return getCouldNotCompute(); @@ -9852,9 +9956,8 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, return getCouldNotCompute(); } -const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { - SmallVector, 2> &Values = - ValuesAtScopes[V]; +SCEVUse ScalarEvolution::getSCEVAtScope(SCEVUse V, const Loop *L) { + SmallVector, 2> &Values = ValuesAtScopes[V]; // Check to see if we've folded this expression at this loop before. for (auto &LS : Values) if (LS.first == L) @@ -9863,7 +9966,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { Values.emplace_back(L, nullptr); // Otherwise compute it. 
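computeShiftCompareExitLimit above can return BitWidth as the constant max bound because a repeatedly shifted value reaches its fixed point (0 for lshr; 0 or -1 for ashr) within BitWidth iterations. Checked for lshr (sketch):

  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t V = 0xDEADBEEF;
    unsigned Steps = 0;
    while (V != 0) { // a loop exiting on a repeatedly-shifted value
      V >>= 1;
      ++Steps;
    }
    assert(Steps <= 32); // lshr hits its fixed point within BitWidth steps
    return 0;
  }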
- const SCEV *C = computeSCEVAtScope(V, L); + SCEVUse C = computeSCEVAtScope(V, L); for (auto &LS : reverse(ValuesAtScopes[V])) if (LS.first == L) { LS.second = C; @@ -9878,7 +9981,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { /// will return Constants for objects which aren't represented by a /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// Returns NULL if the SCEV isn't representable as a Constant. -static Constant *BuildConstantFromSCEV(const SCEV *V) { +static Constant *BuildConstantFromSCEV(SCEVUse V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: @@ -9904,7 +10007,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { case scAddExpr: { const SCEVAddExpr *SA = cast(V); Constant *C = nullptr; - for (const SCEV *Op : SA->operands()) { + for (SCEVUse Op : SA->operands()) { Constant *OpC = BuildConstantFromSCEV(Op); if (!OpC) return nullptr; @@ -9939,9 +10042,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { llvm_unreachable("Unknown SCEV kind!"); } -const SCEV * -ScalarEvolution::getWithOperands(const SCEV *S, - SmallVectorImpl &NewOps) { +SCEVUse ScalarEvolution::getWithOperands(SCEVUse S, + SmallVectorImpl &NewOps) { switch (S->getSCEVType()) { case scTruncate: case scZeroExtend: @@ -9975,7 +10077,7 @@ ScalarEvolution::getWithOperands(const SCEV *S, llvm_unreachable("Unknown SCEV kind!"); } -const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { +SCEVUse ScalarEvolution::computeSCEVAtScope(SCEVUse V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: case scVScale: @@ -9988,21 +10090,21 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); + SCEVUse OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); if (OpAtScope == AddRec->getOperand(i)) continue; // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); append_range(NewOps, AddRec->operands().take_front(i)); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); - const SCEV *FoldedRec = getAddRecExpr( - NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); + SCEVUse FoldedRec = getAddRecExpr(NewOps, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); AddRec = dyn_cast(FoldedRec); // The addrec may be folded to a nonrecurrence, for example, if the // induction variable is multiplied by zero after constant folding. Go @@ -10017,7 +10119,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (!AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. 
-const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
+SCEVUse ScalarEvolution::computeSCEVAtScope(SCEVUse V, const Loop *L) {
   switch (V->getSCEVType()) {
   case scConstant:
   case scVScale:
@@ -9988,21 +10090,21 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
     // Avoid performing the look-up in the common case where the specified
     // expression has no loop-variant portions.
     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
-      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
+      SCEVUse OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
       if (OpAtScope == AddRec->getOperand(i))
         continue;

       // Okay, at least one of these operands is loop variant but might be
       // foldable.  Build a new instance of the folded commutative expression.
-      SmallVector<const SCEV *, 8> NewOps;
+      SmallVector<SCEVUse, 8> NewOps;
       NewOps.reserve(AddRec->getNumOperands());
       append_range(NewOps, AddRec->operands().take_front(i));
       NewOps.push_back(OpAtScope);
       for (++i; i != e; ++i)
         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));

-      const SCEV *FoldedRec = getAddRecExpr(
-          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
+      SCEVUse FoldedRec = getAddRecExpr(NewOps, AddRec->getLoop(),
+                                        AddRec->getNoWrapFlags(SCEV::FlagNW));
       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
       // The addrec may be folded to a nonrecurrence, for example, if the
       // induction variable is multiplied by zero after constant folding. Go
@@ -10017,7 +10119,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
     if (!AddRec->getLoop()->contains(L)) {
       // To evaluate this recurrence, we need to know how many times the AddRec
       // loop iterates.  Compute this now.
-      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
+      SCEVUse BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
       if (BackedgeTakenCount == getCouldNotCompute())
         return AddRec;
@@ -10039,15 +10141,15 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
   case scUMinExpr:
   case scSMinExpr:
   case scSequentialUMinExpr: {
-    ArrayRef<const SCEV *> Ops = V->operands();
+    ArrayRef<SCEVUse> Ops = V->operands();
     // Avoid performing the look-up in the common case where the specified
     // expression has no loop-variant portions.
     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
-      const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
+      SCEVUse OpAtScope = getSCEVAtScope(Ops[i], L);
       if (OpAtScope != Ops[i]) {
         // Okay, at least one of these operands is loop variant but might be
         // foldable.  Build a new instance of the folded commutative expression.
-        SmallVector<const SCEV *, 8> NewOps;
+        SmallVector<SCEVUse, 8> NewOps;
         NewOps.reserve(Ops.size());
         append_range(NewOps, Ops.take_front(i));
         NewOps.push_back(OpAtScope);
@@ -10080,7 +10182,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
       // to see if the loop that contains it has a known backedge-taken
       // count.  If so, we may be able to force computation of the exit
       // value.
-      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
+      SCEVUse BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
       // This trivial case can show up in some degenerate cases where
       // the incoming IR has not yet been fully simplified.
       if (BackedgeTakenCount->isZero()) {
@@ -10145,8 +10247,8 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
         if (!isSCEVable(Op->getType()))
           return V;

-        const SCEV *OrigV = getSCEV(Op);
-        const SCEV *OpV = getSCEVAtScope(OrigV, L);
+        SCEVUse OrigV = getSCEV(Op);
+        SCEVUse OpV = getSCEVAtScope(OrigV, L);
         MadeImprovement |= OrigV != OpV;

         Constant *C = BuildConstantFromSCEV(OpV);
@@ -10174,11 +10276,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
   llvm_unreachable("Unknown SCEV type!");
 }

-const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
+SCEVUse ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
   return getSCEVAtScope(getSCEV(V), L);
 }
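[computeSCEVAtScope's commutative-expression case above only rebuilds an expression once the first operand folds differently at the given scope. A standalone sketch of that copy-on-first-change loop, with hypothetical int operands in place of SCEVUse:]

#include <cassert>
#include <vector>

static int foldAtScope(int Op) { return Op < 0 ? -Op : Op; } // placeholder

static std::vector<int> rewrite(const std::vector<int> &Ops) {
  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
    int OpAtScope = foldAtScope(Ops[i]);
    if (OpAtScope == Ops[i])
      continue;                   // common case: nothing changed so far
    std::vector<int> NewOps(Ops.begin(), Ops.begin() + i);
    NewOps.push_back(OpAtScope);  // first operand that folded differently
    for (++i; i != e; ++i)        // fold the remaining tail unconditionally
      NewOps.push_back(foldAtScope(Ops[i]));
    return NewOps;
  }
  return Ops;                     // unchanged: reuse the original operands
}

int main() {
  assert(rewrite({1, 2, 3}) == (std::vector<int>{1, 2, 3}));
  assert(rewrite({1, -2, -3}) == (std::vector<int>{1, 2, 3}));
}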
-const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
+SCEVUse ScalarEvolution::stripInjectiveFunctions(SCEVUse S) const {
   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
     return stripInjectiveFunctions(ZExt->getOperand());
   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
@@ -10195,7 +10297,7 @@
 ///
 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
 static const SCEV *
-SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
+SolveLinEquationWithOverflow(const APInt &A, SCEVUse B,
                              SmallVectorImpl<const SCEVPredicate *> *Predicates,
                              ScalarEvolution &SE) {
@@ -10244,7 +10346,7 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
   //   I * (B / D) mod (N / D)
   // To simplify the computation, we factor out the divide by D:
   //   (I * B mod N) / D
-  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
+  SCEVUse D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
   return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
 }
@@ -10521,7 +10623,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
   return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
 }

-ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
+ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(SCEVUse V,
                                                          const Loop *L,
                                                          bool ControlsOnlyExit,
                                                          bool AllowPredicates) {
@@ -10535,7 +10637,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // If the value is a constant
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     // If the value is already zero, the branch will execute zero times.
-    if (C->getValue()->isZero()) return C;
+    if (C->getValue()->isZero())
+      return SCEVUse(C);
     return getCouldNotCompute();  // Otherwise it will loop infinitely.
   }
@@ -10599,7 +10702,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!CountDown && !isKnownNonNegative(StepWLG))
     return getCouldNotCompute();

-  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
+  const SCEV *Distance =
+      CountDown ? Start : getNegativeSCEV(Start).getPointer();
   // Handle unitary steps, which cannot wraparound.
   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
   //   N = Distance (as unsigned)
@@ -10615,9 +10719,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
     // Explicitly handling this here is necessary because getUnsignedRange
     // isn't context-sensitive; it doesn't know that we only care about the
     // range inside the loop.
-    const SCEV *Zero = getZero(Distance->getType());
-    const SCEV *One = getOne(Distance->getType());
-    const SCEV *DistancePlusOne = getAddExpr(Distance, One);
+    SCEVUse Zero = getZero(Distance->getType());
+    SCEVUse One = getOne(Distance->getType());
+    SCEVUse DistancePlusOne = getAddExpr(Distance, One);
     if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
       // If Distance + 1 doesn't overflow, we can compute the maximum distance
       // as "unsigned_max(Distance + 1) - 1".
@@ -10642,16 +10746,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
     if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
       return getCouldNotCompute();

-    const SCEV *Exact =
-        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-    const SCEV *ConstantMax = getCouldNotCompute();
+    SCEVUse Exact = getUDivExpr(
+        Distance, CountDown ? getNegativeSCEV(Step).getPointer() : Step);
+    SCEVUse ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
       APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
-    const SCEV *SymbolicMax =
-        isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
+    SCEVUse SymbolicMax = isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
     return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
   }
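[SolveLinEquationWithOverflow, shown above, solves A*X == B (mod 2^BW) by factoring A into an odd part times a power of two and multiplying by the odd part's modular inverse. A standalone, runnable check of that arithmetic at BW = 32; inverseOdd and solveLin are hypothetical helpers, and __builtin_ctz assumes GCC or Clang:]

#include <cassert>
#include <cstdint>

static uint32_t inverseOdd(uint32_t a) { // a must be odd; Newton iteration
  uint32_t x = a;                        // doubles the valid bits each step
  for (int i = 0; i < 5; ++i)
    x *= 2 - a * x;
  return x;
}

static bool solveLin(uint32_t A, uint32_t B, uint32_t &X) {
  if (A == 0)
    return false;
  unsigned TZ = __builtin_ctz(A);
  if (B != 0 && unsigned(__builtin_ctz(B)) < TZ)
    return false;                 // B must be divisible by gcd(A, 2^32)
  // Factor A = D * A' with D = 2^TZ and A' odd; then X = A'^-1 * (B / D).
  X = inverseOdd(A >> TZ) * (B >> TZ);
  if (TZ)
    X &= UINT32_MAX >> TZ;        // reduce modulo 2^(32 - TZ)
  assert(A * X == B);             // holds modulo 2^32 by construction
  return true;
}

int main() {
  uint32_t X;
  assert(solveLin(12, 36, X) && X == 3);
  assert(!solveLin(4, 2, X));     // 4*X == 2 (mod 2^32) has no solution
}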
@@ -10659,21 +10762,21 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
   if (!StepC || StepC->getValue()->isZero())
     return getCouldNotCompute();
-  const SCEV *E = SolveLinEquationWithOverflow(
+  SCEVUse E = SolveLinEquationWithOverflow(
       StepC->getAPInt(), getNegativeSCEV(Start),
       AllowPredicates ? &Predicates : nullptr, *this);
-  const SCEV *M = E;
+  SCEVUse M = E;
   if (E != getCouldNotCompute()) {
     APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
-  auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
+  auto S = isa<SCEVCouldNotCompute>(E) ? M : E;
   return ExitLimit(E, M, S, false, Predicates);
 }

-ScalarEvolution::ExitLimit
-ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
+ScalarEvolution::ExitLimit ScalarEvolution::howFarToNonZero(SCEVUse V,
+                                                            const Loop *L) {
   // Loops that look like: while (X == 0) are very strange indeed.  We don't
   // handle them yet except for the trivial case.  This could be expanded in the
   // future as needed.
@@ -10713,9 +10816,10 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
 /// expressions are equal, however for the purposes of looking for a condition
 /// guarding a loop, it can be useful to be a little more general, since a
 /// front-end may have replicated the controlling expression.
-static bool HasSameValue(const SCEV *A, const SCEV *B) {
-  // Quick check to see if they are the same SCEV.
-  if (A == B) return true;
+static bool HasSameValue(SCEVUse A, SCEVUse B) {
+  // Quick check to see if they are the same SCEV, ignoring use-specific flags.
+  if (A.getPointer() == B.getPointer())
+    return true;

   auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
     // Not all instructions that are "identical" compute the same value.  For
@@ -10737,7 +10841,7 @@ static bool HasSameValue(const SCEV *A, const SCEV *B)
   return false;
 }

-static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
+static bool MatchBinarySub(SCEVUse S, SCEVUse &LHS, SCEVUse &RHS) {
   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
   if (!Add || Add->getNumOperands() != 2)
     return false;
@@ -10757,7 +10861,7 @@ static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
 }

 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
-                                           const SCEV *&LHS, const SCEV *&RHS,
+                                           SCEVUse &LHS, SCEVUse &RHS,
                                            unsigned Depth) {
   bool Changed = false;
   // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
@@ -10941,23 +11045,23 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
   return Changed;
 }

-bool ScalarEvolution::isKnownNegative(const SCEV *S) {
+bool ScalarEvolution::isKnownNegative(SCEVUse S) {
   return getSignedRangeMax(S).isNegative();
 }

-bool ScalarEvolution::isKnownPositive(const SCEV *S) {
+bool ScalarEvolution::isKnownPositive(SCEVUse S) {
   return getSignedRangeMin(S).isStrictlyPositive();
 }

-bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
+bool ScalarEvolution::isKnownNonNegative(SCEVUse S) {
   return !getSignedRangeMin(S).isNegative();
 }

-bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
+bool ScalarEvolution::isKnownNonPositive(SCEVUse S) {
   return !getSignedRangeMax(S).isStrictlyPositive();
 }

-bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
+bool ScalarEvolution::isKnownNonZero(SCEVUse S) {
   // Query push down for cases where the unsigned range is
   // less than sufficient.
   if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
@@ -10985,20 +11089,20 @@ bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
   return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
 }

-std::pair<const SCEV *, const SCEV *>
-ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+std::pair<SCEVUse, SCEVUse>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, SCEVUse S) {
   // Compute SCEV on entry of loop L.
-  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+  SCEVUse Start = SCEVInitRewriter::rewrite(S, L, *this);
   if (Start == getCouldNotCompute())
     return { Start, Start };
   // Compute post increment SCEV for loop L.
-  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+  SCEVUse PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
   assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
   return { Start, PostInc };
 }

-bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
-                                          const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                          SCEVUse RHS) {
   // First collect all loops.
   SmallPtrSet<const Loop *, 8> LoopsUsed;
   getUsedLoops(LHS, LoopsUsed);
@@ -11047,8 +11151,8 @@ bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
   return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
 }
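[The isKnownNegative/isKnownPositive family above reduces to querying the extreme points of the computed signed range. A standalone sketch with a hypothetical closed-interval Range type (not llvm::ConstantRange):]

#include <cassert>
#include <cstdint>

struct Range { int64_t Min, Max; }; // hypothetical closed interval [Min, Max]

static bool isKnownNegative(Range R) { return R.Max < 0; }    // max < 0
static bool isKnownPositive(Range R) { return R.Min > 0; }    // min > 0
static bool isKnownNonNegative(Range R) { return R.Min >= 0; }

int main() {
  Range R{-7, -1};
  assert(isKnownNegative(R) && !isKnownPositive(R) && !isKnownNonNegative(R));
  assert(isKnownNonNegative(Range{0, 5}) && !isKnownPositive(Range{0, 5}));
}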
-bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
-                                       const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                       SCEVUse RHS) {
   // Canonicalize the inputs first.
   (void)SimplifyICmpOperands(Pred, LHS, RHS);
@@ -11063,8 +11167,8 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
 }

 std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
-                                                       const SCEV *LHS,
-                                                       const SCEV *RHS) {
+                                                       SCEVUse LHS,
+                                                       SCEVUse RHS) {
   if (isKnownPredicate(Pred, LHS, RHS))
     return true;
   if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
@@ -11072,17 +11176,16 @@ std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
   return std::nullopt;
 }

-bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
-                                         const SCEV *LHS, const SCEV *RHS,
-                                         const Instruction *CtxI) {
+bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                         SCEVUse RHS, const Instruction *CtxI) {
   // TODO: Analyze guards and assumes from Context's block.
   return isKnownPredicate(Pred, LHS, RHS) ||
          isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
 }

 std::optional<bool>
-ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
-                                     const SCEV *RHS, const Instruction *CtxI) {
+ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                     SCEVUse RHS, const Instruction *CtxI) {
   std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
   if (KnownWithoutContext)
     return KnownWithoutContext;
@@ -11098,7 +11201,7 @@ ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                               const SCEVAddRecExpr *LHS,
-                                              const SCEV *RHS) {
+                                              SCEVUse RHS) {
   const Loop *L = LHS->getLoop();
   return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
          isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
@@ -11156,7 +11259,7 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
     if (!LHS->hasNoSignedWrap())
       return std::nullopt;

-    const SCEV *Step = LHS->getStepRecurrence(*this);
+    SCEVUse Step = LHS->getStepRecurrence(*this);

     if (isKnownNonNegative(Step))
       return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
@@ -11169,7 +11272,7 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS,
+                                           SCEVUse LHS, SCEVUse RHS,
                                            const Loop *L,
                                            const Instruction *CtxI) {
   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
@@ -11255,8 +11358,8 @@ ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
-    const Instruction *CtxI, const SCEV *MaxIter) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, const Loop *L,
+    const Instruction *CtxI, SCEVUse MaxIter) {
   if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
           Pred, LHS, RHS, L, CtxI, MaxIter))
     return LIP;
@@ -11266,7 +11369,7 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
   // work, try the following trick: if a predicate is invariant for X, it
   // is also invariant for umin(X, ...). So try to find something that works
   // among subexpressions of MaxIter expressed as umin.
-  for (auto *Op : UMin->operands())
+  for (auto Op : UMin->operands())
     if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
             Pred, LHS, RHS, L, CtxI, Op))
       return LIP;
@@ -11275,8 +11378,8 @@
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
-    const Instruction *CtxI, const SCEV *MaxIter) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, const Loop *L,
+    const Instruction *CtxI, SCEVUse MaxIter) {
   // Try to prove the following set of facts:
   // - The predicate is monotonic in the iteration space.
   // - If the check does not fail on the 1st iteration:
@@ -11303,9 +11406,9 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
     return std::nullopt;

   // TODO: Support steps other than +/- 1.
-  const SCEV *Step = AR->getStepRecurrence(*this);
-  auto *One = getOne(Step->getType());
-  auto *MinusOne = getNegativeSCEV(One);
+  SCEVUse Step = AR->getStepRecurrence(*this);
+  auto One = getOne(Step->getType());
+  auto MinusOne = getNegativeSCEV(One);
   if (Step != One && Step != MinusOne)
     return std::nullopt;
@@ -11316,7 +11419,7 @@
     return std::nullopt;

   // Value of IV on suggested last iteration.
-  const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
+  SCEVUse Last = AR->evaluateAtIteration(MaxIter, *this);
   // Does it still meet the requirement?
   if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
     return std::nullopt;
@@ -11329,7 +11432,7 @@
       CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
   if (Step == MinusOne)
     NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
-  const SCEV *Start = AR->getStart();
+  SCEVUse Start = AR->getStart();
   if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
     return std::nullopt;
@@ -11338,7 +11441,7 @@
 }

 bool ScalarEvolution::isKnownPredicateViaConstantRanges(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS) {
   if (HasSameValue(LHS, RHS))
     return ICmpInst::isTrueWhenEqual(Pred);
@@ -11364,7 +11467,7 @@
     auto UR = getUnsignedRange(RHS);
     if (CheckRanges(UL, UR))
       return true;
-    auto *Diff = getMinusSCEV(LHS, RHS);
+    auto Diff = getMinusSCEV(LHS, RHS);
     return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
   }
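[isKnownPredicateViaConstantRanges' CheckRanges above succeeds when the two ranges cannot order the wrong way. A standalone sketch for the unsigned-less-than case, with a hypothetical URange interval type:]

#include <cassert>
#include <cstdint>

struct URange { uint64_t Min, Max; }; // hypothetical closed unsigned interval

static bool knownULT(URange L, URange R) { return L.Max < R.Min; }

int main() {
  assert(knownULT({0, 9}, {10, 20}));   // every L value is below every R value
  assert(!knownULT({0, 10}, {10, 20})); // L.Max == R.Min: 10 < 10 can be false
}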
@@ -11380,17 +11483,16 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
 }

 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
-                                                    const SCEV *LHS,
-                                                    const SCEV *RHS) {
+                                                    SCEVUse LHS, SCEVUse RHS) {
   // Match X to (A + C1) and Y to (A + C2), where
   // C1 and C2 are constant integers. If either X or Y are not add expressions,
   // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
   // OutC1 and OutC2.
-  auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
-                                      APInt &OutC1, APInt &OutC2,
+  auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
+                                      APInt &OutC2,
                                       SCEV::NoWrapFlags ExpectedFlags) {
-    const SCEV *XNonConstOp, *XConstOp;
-    const SCEV *YNonConstOp, *YConstOp;
+    SCEVUse XNonConstOp, XConstOp;
+    SCEVUse YNonConstOp, YConstOp;
     SCEV::NoWrapFlags XFlagsPresent;
     SCEV::NoWrapFlags YFlagsPresent;
@@ -11473,8 +11575,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
 }

 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
-                                                   const SCEV *LHS,
-                                                   const SCEV *RHS) {
+                                                   SCEVUse LHS, SCEVUse RHS) {
   if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
     return false;
@@ -11495,8 +11596,8 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
 }

 bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
-                                        ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS) {
+                                        ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS) {
   // No need to even try if we know the module has no guards.
   if (!HasGuards)
     return false;
@@ -11514,10 +11615,9 @@ bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
 /// protected by a conditional between LHS and RHS.  This is used to
 /// eliminate casts.
-bool
-ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
-                                             ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+                                                  ICmpInst::Predicate Pred,
+                                                  SCEVUse LHS, SCEVUse RHS) {
   // Interpret a null as meaning no loop, where there is obviously no guard
   // (interprocedural conditions notwithstanding). Do not bother about
   // unreachable loops.
@@ -11553,15 +11653,15 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   // See if we can exploit a trip count to prove the predicate.
   const auto &BETakenInfo = getBackedgeTakenInfo(L);
-  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
+  SCEVUse LatchBECount = BETakenInfo.getExact(Latch, this);
   if (LatchBECount != getCouldNotCompute()) {
     // We know that Latch branches back to the loop header exactly
     // LatchBECount times.  This means the backedge condition at Latch is
     // equivalent to  "{0,+,1} u< LatchBECount".
     Type *Ty = LatchBECount->getType();
     auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
-    const SCEV *LoopCounter =
-        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
+    SCEVUse LoopCounter =
+        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
     if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
                       LatchBECount))
       return true;
@@ -11622,8 +11722,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
                                                      ICmpInst::Predicate Pred,
-                                                     const SCEV *LHS,
-                                                     const SCEV *RHS) {
+                                                     SCEVUse LHS, SCEVUse RHS) {
   // Do not bother proving facts for unreachable code.
   if (!DT.isReachableFromEntry(BB))
     return true;
@@ -11722,8 +11821,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                                ICmpInst::Predicate Pred,
-                                               const SCEV *LHS,
-                                               const SCEV *RHS) {
+                                               SCEVUse LHS, SCEVUse RHS) {
   // Interpret a null as meaning no loop, where there is obviously no guard
   // (interprocedural conditions notwithstanding).
   if (!L)
@@ -11741,10 +11839,9 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
 }

-bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
-                                    const SCEV *RHS,
-                                    const Value *FoundCondValue, bool Inverse,
-                                    const Instruction *CtxI) {
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                    SCEVUse RHS, const Value *FoundCondValue,
+                                    bool Inverse, const Instruction *CtxI) {
   // False conditions implies anything. Do not bother analyzing it further.
   if (FoundCondValue ==
       ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
@@ -11779,16 +11876,15 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   else
     FoundPred = ICI->getPredicate();

-  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
-  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
+  SCEVUse FoundLHS = getSCEV(ICI->getOperand(0));
+  SCEVUse FoundRHS = getSCEV(ICI->getOperand(1));

   return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
 }

-bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
-                                    const SCEV *RHS,
-                                    ICmpInst::Predicate FoundPred,
-                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                    SCEVUse RHS, ICmpInst::Predicate FoundPred,
+                                    SCEVUse FoundLHS, SCEVUse FoundRHS,
                                     const Instruction *CtxI) {
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
@@ -11801,14 +11897,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
     auto *NarrowType = LHS->getType();
     auto *WideType = FoundLHS->getType();
     auto BitWidth = getTypeSizeInBits(NarrowType);
-    const SCEV *MaxValue = getZeroExtendExpr(
+    SCEVUse MaxValue = getZeroExtendExpr(
         getConstant(APInt::getMaxValue(BitWidth)), WideType);
     if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
                                         MaxValue) &&
         isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
                                         MaxValue)) {
-      const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
-      const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
+      SCEVUse TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
+      SCEVUse TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
       if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
                                      TruncFoundRHS, CtxI))
         return true;
@@ -11840,10 +11936,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                  FoundRHS, CtxI);
 }

-bool ScalarEvolution::isImpliedCondBalancedTypes(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
-    const Instruction *CtxI) {
+bool ScalarEvolution::isImpliedCondBalancedTypes(ICmpInst::Predicate Pred,
+                                                 SCEVUse LHS, SCEVUse RHS,
+                                                 ICmpInst::Predicate FoundPred,
+                                                 SCEVUse FoundLHS,
+                                                 SCEVUse FoundRHS,
+                                                 const Instruction *CtxI) {
   assert(getTypeSizeInBits(LHS->getType()) ==
              getTypeSizeInBits(FoundLHS->getType()) &&
          "Types should be balanced!");
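[The type-balancing step of isImpliedCond above truncates a wide fact to the narrow type once both sides provably fit in the narrow unsigned range. A standalone illustration of why that is sound, in plain integers rather than SCEV:]

#include <cassert>
#include <cstdint>

int main() {
  uint32_t FoundLHS = 200, FoundRHS = 250; // wide (32-bit) fact: 200 u< 250
  const uint32_t NarrowMax = UINT8_MAX;
  if (FoundLHS <= NarrowMax && FoundRHS <= NarrowMax) {
    uint8_t TruncL = uint8_t(FoundLHS), TruncR = uint8_t(FoundRHS);
    assert(TruncL < TruncR); // truncation preserved the unsigned predicate
  }
}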
@@ -11921,8 +12019,8 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
   // Create local copies that we can freely swap and canonicalize our
   // conditions to "le/lt".
   ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
-  const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
-             *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
+  SCEVUse CanonicalLHS = LHS, CanonicalRHS = RHS,
+          CanonicalFoundLHS = FoundLHS, CanonicalFoundRHS = FoundRHS;
   if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
     CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
     CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
@@ -11955,7 +12053,7 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
       (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
     const SCEVConstant *C = nullptr;
-    const SCEV *V = nullptr;
+    SCEVUse V = nullptr;

     if (isa<SCEVConstant>(FoundLHS)) {
       C = cast<SCEVConstant>(FoundLHS);
@@ -12044,8 +12142,7 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
   return false;
 }

-bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
-                                     const SCEV *&L, const SCEV *&R,
+bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
                                      SCEV::NoWrapFlags &Flags) {
   const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
   if (!AE || AE->getNumOperands() != 2)
@@ -12057,8 +12154,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
   return true;
 }

-std::optional<APInt>
-ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
+std::optional<APInt> ScalarEvolution::computeConstantDifference(SCEVUse More,
+                                                                SCEVUse Less) {
   // We avoid subtracting expressions here because this function is usually
   // fairly deep in the call stack (i.e. is called many times).
@@ -12175,8 +12272,8 @@
 }

 bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, SCEVUse FoundLHS,
+    SCEVUse FoundRHS, const Instruction *CtxI) {
   // Try to recognize the following pattern:
   //
   //   FoundRHS = ...
@@ -12220,8 +12317,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
 }

 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    const SCEV *FoundLHS, const SCEV *FoundRHS) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, SCEVUse FoundLHS,
+    SCEVUse FoundRHS) {
   if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
     return false;
@@ -12298,10 +12395,9 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
                             getConstant(FoundRHSLimit));
 }

-bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS,
-                                        const SCEV *FoundLHS,
-                                        const SCEV *FoundRHS, unsigned Depth) {
+bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS, SCEVUse FoundLHS,
+                                        SCEVUse FoundRHS, unsigned Depth) {
   const PHINode *LPhi = nullptr, *RPhi = nullptr;

   auto ClearOnExit = make_scope_exit([&]() {
@@ -12355,7 +12451,7 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
   const BasicBlock *LBB = LPhi->getParent();
   const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);

-  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
+  auto ProvedEasily = [&](SCEVUse S1, SCEVUse S2) {
     return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
            isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS,
                                           FoundRHS) ||
            isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
@@ -12367,8 +12463,8 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
     // the predicate is true for incoming values from this block, then the
     // predicate is also true for the Phis.
     for (const BasicBlock *IncBB : predecessors(LBB)) {
-      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
-      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
+      SCEVUse L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      SCEVUse R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
       if (!ProvedEasily(L, R))
         return false;
     }
@@ -12383,12 +12479,12 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
     auto *RLoop = RAR->getLoop();
     auto *Predecessor = RLoop->getLoopPredecessor();
     assert(Predecessor && "Loop with AddRec with no predecessor?");
-    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
+    SCEVUse L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
     if (!ProvedEasily(L1, RAR->getStart()))
       return false;
     auto *Latch = RLoop->getLoopLatch();
     assert(Latch && "Loop with AddRec with no latch?");
-    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
+    SCEVUse L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
     if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
       return false;
   } else {
@@ -12400,7 +12496,7 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
       // Check that RHS is available in this block.
       if (!dominates(RHS, IncBB))
         return false;
-      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      SCEVUse L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
       // Make sure L does not refer to a value from a potentially previous
       // iteration of a loop.
      if (!properlyDominates(L, LBB))
        return false;
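[computeConstantDifference above deliberately avoids constructing a subtraction; it walks add-expression operands looking for More == Less + C. A standalone sketch of the idea, with a hypothetical flattened AddExpr in place of SCEV add expressions:]

#include <cassert>
#include <cstdint>
#include <optional>

struct AddExpr { int Sym; int64_t Const; }; // hypothetical: value = Sym + Const

static std::optional<int64_t> constantDifference(AddExpr More, AddExpr Less) {
  if (More.Sym != Less.Sym) // bail unless the symbolic parts are identical
    return std::nullopt;
  return More.Const - Less.Const;
}

int main() {
  assert(constantDifference({1, 10}, {1, 4}) == 6);
  assert(!constantDifference({1, 10}, {2, 4})); // different symbolic parts
}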
@@ -12413,10 +12509,9 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
 }

 bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
-                                                    const SCEV *LHS,
-                                                    const SCEV *RHS,
-                                                    const SCEV *FoundLHS,
-                                                    const SCEV *FoundRHS) {
+                                                    SCEVUse LHS, SCEVUse RHS,
+                                                    SCEVUse FoundLHS,
+                                                    SCEVUse FoundRHS) {
   // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue).  First, make
   // sure that we are dealing with same LHS.
   if (RHS == FoundRHS) {
@@ -12436,7 +12531,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
   using namespace PatternMatch;
   if (match(SUFoundRHS->getValue(),
             m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
-    auto *ShifteeS = getSCEV(Shiftee);
+    auto ShifteeS = getSCEV(Shiftee);
     // Prove one of the following:
     // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
     // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
@@ -12455,9 +12550,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
 }

 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
-                                            const SCEV *LHS, const SCEV *RHS,
-                                            const SCEV *FoundLHS,
-                                            const SCEV *FoundRHS,
+                                            SCEVUse LHS, SCEVUse RHS,
+                                            SCEVUse FoundLHS, SCEVUse FoundRHS,
                                             const Instruction *CtxI) {
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
     return true;
@@ -12478,8 +12572,7 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
 /// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
 template <typename MinMaxExprType>
-static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
-                                 const SCEV *Candidate) {
+static bool IsMinMaxConsistingOf(SCEVUse MaybeMinMaxExpr, SCEVUse Candidate) {
   const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
   if (!MinMaxExpr)
     return false;
@@ -12489,7 +12582,7 @@ static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                            ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS) {
+                                           SCEVUse LHS, SCEVUse RHS) {
   // If both sides are affine addrecs for the same loop, with equal
   // steps, and we know the recurrences don't wrap, then we only
   // need to check the predicate on the starting values.
@@ -12522,8 +12615,8 @@ static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
 /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
 /// expression?
 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
-                                        ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS) {
+                                        ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS) {
   switch (Pred) {
   default:
     return false;
@@ -12554,9 +12647,8 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
 }

 bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS,
-                                             const SCEV *FoundLHS,
-                                             const SCEV *FoundRHS,
+                                             SCEVUse LHS, SCEVUse RHS,
+                                             SCEVUse FoundLHS, SCEVUse FoundRHS,
                                              unsigned Depth) {
   assert(getTypeSizeInBits(LHS->getType()) ==
              getTypeSizeInBits(RHS->getType()) &&
@@ -12584,7 +12676,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
   // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
   // use this fact to prove that LHS and RHS are non-negative.
-  const SCEV *MinusOne = getMinusOne(LHS->getType());
+  SCEVUse MinusOne = getMinusOne(LHS->getType());
   if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                             FoundRHS) &&
       isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
@@ -12595,7 +12687,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   if (Pred != ICmpInst::ICMP_SGT)
     return false;

-  auto GetOpFromSExt = [&](const SCEV *S) {
+  auto GetOpFromSExt = [&](SCEVUse S) {
     if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
       return Ext->getOperand();
     // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
@@ -12604,13 +12696,13 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   };

   // Acquire values from extensions.
-  auto *OrigLHS = LHS;
-  auto *OrigFoundLHS = FoundLHS;
+  auto OrigLHS = LHS;
+  auto OrigFoundLHS = FoundLHS;
   LHS = GetOpFromSExt(LHS);
   FoundLHS = GetOpFromSExt(FoundLHS);

   // Whether the SGT predicate can be proved trivially or using the found
   // context.
-  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+  auto IsSGTViaContext = [&](SCEVUse S1, SCEVUse S2) {
     return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
            isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                   FoundRHS, Depth + 1);
@@ -12629,12 +12721,12 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
       if (!LHSAddExpr->hasNoSignedWrap())
         return false;

-      auto *LL = LHSAddExpr->getOperand(0);
-      auto *LR = LHSAddExpr->getOperand(1);
-      auto *MinusOne = getMinusOne(RHS->getType());
+      auto LL = LHSAddExpr->getOperand(0);
+      auto LR = LHSAddExpr->getOperand(1);
+      auto MinusOne = getMinusOne(RHS->getType());

       // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
-      auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+      auto IsSumGreaterThanRHS = [&](SCEVUse S1, SCEVUse S2) {
         return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
       };
       // Try to prove the following rule:
@@ -12664,7 +12756,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
       // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
       // then a SCEV for the numerator already exists and matches with FoundLHS.
-      auto *Numerator = getExistingSCEV(LL);
+      auto Numerator = getExistingSCEV(LL);
       if (!Numerator || Numerator->getType() != FoundLHS->getType())
         return false;
@@ -12685,14 +12777,14 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
       // Given that:
       // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
       auto *WTy = getWiderType(DTy, FRHSTy);
-      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
-      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+      auto DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+      auto FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);

       // Try to prove the following rule:
       // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
       // For example, given that FoundLHS > 2. It means that FoundLHS is at
       // least 3. If we divide it by Denominator < 4, we will have at least 1.
-      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+      auto DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
       if (isKnownNonPositive(RHS) &&
           IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
         return true;
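[The division rule proved above, (FoundRHS > Denominator - 2) && (RHS <= 0) => (FoundLHS / Denominator > RHS) given FoundLHS > FoundRHS, can be checked exhaustively over a small domain. A standalone, runnable verification in plain signed arithmetic:]

#include <cassert>

int main() {
  for (int FoundRHS = -50; FoundRHS < 50; ++FoundRHS)
    for (int D = 2; D < 8; ++D)
      for (int RHS = -50; RHS <= 0; ++RHS)
        if (FoundRHS > D - 2)
          // FoundRHS >= D - 1, so FoundLHS >= D and FoundLHS / D >= 1 > RHS.
          for (int FoundLHS = FoundRHS + 1; FoundLHS < 60; ++FoundLHS)
            assert(FoundLHS / D > RHS);
}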
@@ -12704,8 +12796,8 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
       // 1. If FoundLHS is negative, then the result is 0.
       // 2. If FoundLHS is non-negative, then the result is non-negative.
       // Anyways, the result is non-negative.
-      auto *MinusOne = getMinusOne(WTy);
-      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
+      auto MinusOne = getMinusOne(WTy);
+      auto NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
       if (isKnownNegative(RHS) &&
           IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
         return true;
@@ -12721,8 +12813,8 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   return false;
 }

-static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS) {
+static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS) {
   // zext x u<= sext x, sext x s<= zext x
   switch (Pred) {
   case ICmpInst::ICMP_SGE:
@@ -12753,9 +12845,9 @@ static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
   return false;
 }

-bool
-ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
-                                                 const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
+                                                      SCEVUse LHS,
+                                                      SCEVUse RHS) {
   return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
          isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
          IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
@@ -12763,11 +12855,10 @@
          isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
 }

-bool
-ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS,
-                                             const SCEV *FoundLHS,
-                                             const SCEV *FoundRHS) {
+bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                                  SCEVUse LHS, SCEVUse RHS,
+                                                  SCEVUse FoundLHS,
+                                                  SCEVUse FoundRHS) {
   switch (Pred) {
   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
   case ICmpInst::ICMP_EQ:
@@ -12808,12 +12899,9 @@
   return false;
 }

-bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
-                                                     const SCEV *LHS,
-                                                     const SCEV *RHS,
-                                                     ICmpInst::Predicate FoundPred,
-                                                     const SCEV *FoundLHS,
-                                                     const SCEV *FoundRHS) {
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS,
+    ICmpInst::Predicate FoundPred, SCEVUse FoundLHS, SCEVUse FoundRHS) {
   if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
     // The restriction on `FoundRHS` can be lifted easily -- it exists only to
     // reduce the compile time impact of this optimization.
     return false;
@@ -12841,12 +12929,12 @@
   return LHSRange.icmp(Pred, ConstRHS);
 }

-bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
+bool ScalarEvolution::canIVOverflowOnLT(SCEVUse RHS, SCEVUse Stride,
                                         bool IsSigned) {
   assert(isKnownPositive(Stride) && "Positive stride expected!");

   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
-  const SCEV *One = getOne(Stride->getType());
+  SCEVUse One = getOne(Stride->getType());

   if (IsSigned) {
     APInt MaxRHS = getSignedRangeMax(RHS);
@@ -12865,11 +12953,11 @@
   return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
 }

-bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
+bool ScalarEvolution::canIVOverflowOnGT(SCEVUse RHS, SCEVUse Stride,
                                         bool IsSigned) {
   unsigned BitWidth = getTypeSizeInBits(RHS->getType());
-  const SCEV *One = getOne(Stride->getType());
+  SCEVUse One = getOne(Stride->getType());

   if (IsSigned) {
     APInt MinRHS = getSignedRangeMin(RHS);
@@ -12888,20 +12976,18 @@
   return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
 }

-const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
+SCEVUse ScalarEvolution::getUDivCeilSCEV(SCEVUse N, SCEVUse D) {
   // umin(N, 1) + floor((N - umin(N, 1)) / D)
   // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
   // expression fixes the case of N=0.
-  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
-  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
+  SCEVUse MinNOne = getUMinExpr(N, getOne(N->getType()));
+  SCEVUse NMinusOne = getMinusSCEV(N, MinNOne);
   return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
 }

-const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
-                                                    const SCEV *Stride,
-                                                    const SCEV *End,
-                                                    unsigned BitWidth,
-                                                    bool IsSigned) {
+SCEVUse ScalarEvolution::computeMaxBECountForLT(SCEVUse Start, SCEVUse Stride,
+                                                SCEVUse End, unsigned BitWidth,
+                                                bool IsSigned) {
   // The logic in this function assumes we can represent a positive stride.
   // If we can't, the backedge-taken count must be zero.
   if (IsSigned && BitWidth == 1)
@@ -12947,9 +13033,9 @@
 }

 ScalarEvolution::ExitLimit
-ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
-                                  const Loop *L, bool IsSigned,
-                                  bool ControlsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::howManyLessThans(SCEVUse LHS, SCEVUse RHS, const Loop *L,
+                                  bool IsSigned, bool ControlsOnlyExit,
+                                  bool AllowPredicates) {
   SmallVector<const SCEVPredicate *> Predicates;

   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
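[getUDivCeilSCEV above encodes ceil(N/D) without risking overflow on N + D - 1 inside SCEV. A standalone check that umin(N,1) + (N - umin(N,1))/D matches the usual rounding-up formula, including N == 0:]

#include <cassert>
#include <cstdint>

static uint64_t udivCeil(uint64_t N, uint64_t D) {
  uint64_t MinNOne = N < 1 ? N : 1; // umin(N, 1)
  return MinNOne + (N - MinNOne) / D;
}

int main() {
  for (uint64_t N = 0; N < 100; ++N)
    for (uint64_t D = 1; D < 10; ++D)
      assert(udivCeil(N, D) == (N + D - 1) / D); // reference: round up
}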
@@ -12993,11 +13079,11 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       if (AR->hasNoUnsignedWrap()) {
         // Emulate what getZeroExtendExpr would have done during construction
         // if we'd been able to infer the fact just above at that time.
-        const SCEV *Step = AR->getStepRecurrence(*this);
+        SCEVUse Step = AR->getStepRecurrence(*this);
         Type *Ty = ZExt->getType();
-        auto *S = getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
-            getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
+        auto S = getAddRecExpr(
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
+            getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
         IV = dyn_cast<SCEVAddRecExpr>(S);
       }
     }
@@ -13031,7 +13117,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

-  const SCEV *Stride = IV->getStepRecurrence(*this);
+  SCEVUse Stride = IV->getStepRecurrence(*this);

   bool PositiveStride = isKnownPositive(Stride);
@@ -13097,7 +13183,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     // Note: The (Start - Stride) term is used to get the start' term from
     // (start' + stride,+,stride). Remember that we only care about the
     // result of this expression when stride == 0 at runtime.
-    auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
+    auto StartIfZero = getMinusSCEV(IV->getStart(), Stride);
     return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
   };
   if (!wouldZeroStrideBeUB()) {
@@ -13120,14 +13206,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   // before any possible exit.
   // Note that we have not yet proved RHS invariant (in general).
-  const SCEV *Start = IV->getStart();
+  SCEVUse Start = IV->getStart();

   // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
   // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
   // Use integer-typed versions for actual computation; we can't subtract
   // pointers in general.
-  const SCEV *OrigStart = Start;
-  const SCEV *OrigRHS = RHS;
+  SCEVUse OrigStart = Start;
+  SCEVUse OrigRHS = RHS;
   if (Start->getType()->isPointerTy()) {
     Start = getLosslessPtrToIntExpr(Start);
     if (isa<SCEVCouldNotCompute>(Start))
@@ -13139,8 +13225,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }

-  const SCEV *End = nullptr, *BECount = nullptr,
-             *BECountIfBackedgeTaken = nullptr;
+  SCEVUse End = nullptr, BECount = nullptr, BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
     const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
     if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
@@ -13199,7 +13284,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     // backedge count, as if the backedge is taken at least once
     // max(End,Start) is End and so the result is as above, and if not
     // max(End,Start) is Start so we get a backedge count of zero.
-    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
+    const auto OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
     assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
     assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
     assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
@@ -13251,7 +13336,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     //
     // FIXME: Should isLoopEntryGuardedByCond do this for us?
     auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
-    auto *StartMinusOne =
-        getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
+    auto StartMinusOne =
+        getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
     return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
   };
@@ -13361,7 +13446,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     }
   }

-  const SCEV *ConstantMaxBECount;
+  SCEVUse ConstantMaxBECount;
   bool MaxOrZero = false;
   if (isa<SCEVConstant>(BECount)) {
     ConstantMaxBECount = BECount;
@@ -13381,15 +13466,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       !isa<SCEVCouldNotCompute>(BECount))
     ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

-  const SCEV *SymbolicMaxBECount =
-      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
+  SCEVUse SymbolicMaxBECount =
+      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
   return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
                    Predicates);
 }

-ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
-    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
-    bool ControlsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::ExitLimit
+ScalarEvolution::howManyGreaterThans(SCEVUse LHS, SCEVUse RHS, const Loop *L,
+                                     bool IsSigned, bool ControlsOnlyExit,
+                                     bool AllowPredicates) {
   SmallVector<const SCEVPredicate *> Predicates;
   // We handle only IV > Invariant
   if (!isLoopInvariant(RHS, L))
@@ -13410,7 +13496,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
   bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;

-  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
+  SCEVUse Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

   // Avoid negative or zero stride values
   if (!isKnownPositive(Stride))
@@ -13424,8 +13510,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
   if (canIVOverflowOnGT(RHS, Stride, IsSigned))
     return getCouldNotCompute();

-  const SCEV *Start = IV->getStart();
-  const SCEV *End = RHS;
+  SCEVUse Start = IV->getStart();
+  SCEVUse End = RHS;
   if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
     // If we know that Start >= RHS in the context of loop, then we know that
     // min(RHS, Start) = RHS at this point.
@@ -13450,8 +13536,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
   // Compute ((Start - End) + (Stride - 1)) / Stride.
   // FIXME: This can overflow.  Holding off on fixing this for now;
   // howManyGreaterThans will hopefully be gone soon.
-  const SCEV *One = getOne(Stride->getType());
-  const SCEV *BECount = getUDivExpr(
+  SCEVUse One = getOne(Stride->getType());
+  SCEVUse BECount = getUDivExpr(
       getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);

   APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
@@ -13471,7 +13557,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
                    IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
                             : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

-  const SCEV *ConstantMaxBECount =
+  SCEVUse ConstantMaxBECount =
       isa<SCEVConstant>(BECount)
           ? BECount
           : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
@@ -13479,25 +13565,25 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
   if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
     ConstantMaxBECount = BECount;

-  const SCEV *SymbolicMaxBECount =
-      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
+  SCEVUse SymbolicMaxBECount =
+      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
   return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                    Predicates);
 }

-const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
-                                                    ScalarEvolution &SE) const {
+SCEVUse SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
+                                                ScalarEvolution &SE) const {
   if (Range.isFullSet())  // Infinite loop.
     return SE.getCouldNotCompute();

   // If the start is a non-zero constant, shift the range to simplify things.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
     if (!SC->getValue()->isZero()) {
-      SmallVector<const SCEV *, 4> Operands(operands());
+      SmallVector<SCEVUse, 4> Operands(operands());
       Operands[0] = SE.getZero(SC->getType());
-      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
-                                             getNoWrapFlags(FlagNW));
+      SCEVUse Shifted =
+          SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW));
       if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
         return ShiftedAddRec->getNumIterationsInRange(
             Range.subtract(SC->getAPInt()), SE);
@@ -13507,7 +13593,7 @@
   // The only time we can solve this is when we have all constant indices.
   // Otherwise, we cannot determine the overflow conditions.
-  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
+  if (any_of(operands(), [](SCEVUse Op) { return !isa<SCEVConstant>(Op); }))
     return SE.getCouldNotCompute();

   // Okay at this point we know that all elements of the chrec are constants and
@@ -13566,7 +13652,7 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
   // simplification: it is legal to return ({rec1} + {rec2}). For example, it
   // may happen if we reach arithmetic depth limit while simplifying. So we
   // construct the returned value explicitly.
-  SmallVector<const SCEV *, 4> Ops;
+  SmallVector<SCEVUse, 4> Ops;
   // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
   // (this + Step) is {A+B,+,B+C,+...,+,N}.
   for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
@@ -13575,7 +13661,7 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
   // have been popped out earlier). This guarantees us that if the result has
   // the same last operand, then it will also not be popped out, meaning that
   // the returned value will be an AddRec.
-  const SCEV *Last = getOperand(getNumOperands() - 1);
+  SCEVUse Last = getOperand(getNumOperands() - 1);
   assert(!Last->isZero() && "Recurrency with zero step?");
   Ops.push_back(Last);
   return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
@@ -13583,8 +13669,8 @@
 }

 // Return true when S contains at least an undef value.
-bool ScalarEvolution::containsUndefs(const SCEV *S) const {
-  return SCEVExprContains(S, [](const SCEV *S) {
+bool ScalarEvolution::containsUndefs(SCEVUse S) const {
+  return SCEVExprContains(S, [](SCEVUse S) {
     if (const auto *SU = dyn_cast<SCEVUnknown>(S))
       return isa<UndefValue>(SU->getValue());
     return false;
@@ -13592,8 +13678,8 @@
 }

 // Return true when S contains a value that is a nullptr.
-bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
-  return SCEVExprContains(S, [](const SCEV *S) {
+bool ScalarEvolution::containsErasedValue(SCEVUse S) const {
+  return SCEVExprContains(S, [](SCEVUse S) {
     if (const auto *SU = dyn_cast<SCEVUnknown>(S))
       return SU->getValue() == nullptr;
     return false;
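[getPostIncExpr above materializes "this + Step" as a new recurrence rather than an add. For the affine case, a standalone check that {A,+,B} advanced by one step is {A+B,+,B} at the same iteration number:]

#include <cassert>

int main() {
  int A = 7, B = 3;                 // the recurrence {A,+,B}
  for (int It = 0; It < 10; ++It) {
    int Cur = A + B * It;           // value at iteration It
    int PostInc = (A + B) + B * It; // {A+B,+,B} at the same iteration
    assert(PostInc == Cur + B);     // i.e. exactly "this + Step"
  }
}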
@@ -13601,7 +13687,7 @@ bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
 }

 /// Return the size of an element read or written by Inst.
-const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
+SCEVUse ScalarEvolution::getElementSize(Instruction *Inst) {
   Type *Ty;
   if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
     Ty = Store->getValueOperand()->getType();
@@ -13649,6 +13735,7 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
     : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
       CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
       LoopDispositions(64), BlockDispositions(64) {
+  CouldNotCompute->setCanonical(&*CouldNotCompute);
   // To use guards for proving predicates, we need to scan every instruction in
   // relevant basic blocks, and not just terminators.  Doing this is a waste of
   // time if the IR does not actually contain any calls to
@@ -13710,6 +13797,19 @@ ScalarEvolution::~ScalarEvolution() {
   HasRecMap.clear();
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  UnsignedRanges.clear();
+  SignedRanges.clear();
+
+  BECountUsers.clear();
+  SCEVUsers.clear();
+  FoldCache.clear();
+  FoldCacheUser.clear();
+  ValuesAtScopes.clear();
+  ValuesAtScopesUsers.clear();
+  LoopDispositions.clear();
+
+  BlockDispositions.clear();
+  ConstantMultipleCache.clear();

   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -13745,7 +13845,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   if (ExitingBlocks.size() != 1)
     OS << "<multiple exits> ";

-  auto *BTC = SE->getBackedgeTakenCount(L);
+  auto BTC = SE->getBackedgeTakenCount(L);
   if (!isa<SCEVCouldNotCompute>(BTC)) {
     OS << "backedge-taken count is ";
     PrintSCEVWithTypeHint(OS, BTC);
@@ -13778,7 +13878,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
   OS << ": ";
-  auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
+  auto ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
   if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
     OS << "constant max backedge-taken count is ";
     PrintSCEVWithTypeHint(OS, ConstantBTC);
@@ -13793,7 +13893,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
   OS << ": ";
-  auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
+  auto SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
   if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
     OS << "symbolic max backedge-taken count is ";
     PrintSCEVWithTypeHint(OS, SymbolicBTC);
@@ -13807,8 +13907,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   if (ExitingBlocks.size() > 1)
     for (BasicBlock *ExitingBlock : ExitingBlocks) {
       OS << "  symbolic max exit count for " << ExitingBlock->getName() << ": ";
-      auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
-                                       ScalarEvolution::SymbolicMaximum);
+      auto ExitBTC =
+          SE->getExitCount(L, ExitingBlock, ScalarEvolution::SymbolicMaximum);
       PrintSCEVWithTypeHint(OS, ExitBTC);
       if (isa<SCEVCouldNotCompute>(ExitBTC)) {
         // Retry with predicates.
@@ -13828,7 +13928,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   }

   SmallVector<const SCEVPredicate *, 4> Preds;
-  auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
+  auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
   if (PBT != BTC) {
     assert(!Preds.empty() && "Different predicated BTC, but no predicates");
     OS << "Loop ";
@@ -13846,7 +13946,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   }
   Preds.clear();

-  auto *PredConstantMax =
+  auto PredConstantMax =
       SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
   if (PredConstantMax != ConstantBTC) {
     assert(!Preds.empty() &&
@@ -13866,7 +13966,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   }
   Preds.clear();

-  auto *PredSymbolicMax =
+  auto PredSymbolicMax =
       SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
   if (SymbolicBTC != PredSymbolicMax) {
     assert(!Preds.empty() &&
@@ -13942,7 +14042,7 @@ void ScalarEvolution::print(raw_ostream &OS) const {
     if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
       OS << I << '\n';
       OS << "  -->  ";
-      const SCEV *SV = SE.getSCEV(&I);
+      SCEVUse SV = SE.getSCEV(&I);
       SV->print(OS);
       if (!isa<SCEVCouldNotCompute>(SV)) {
         OS << " U: ";
@@ -13953,7 +14053,7 @@ void ScalarEvolution::print(raw_ostream &OS) const {
       const Loop *L = LI.getLoopFor(I.getParent());

-      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
+      SCEVUse AtUse = SE.getSCEVAtScope(SV, L);
       if (AtUse != SV) {
         OS << "  -->  ";
         AtUse->print(OS);
@@ -13967,7 +14067,7 @@ void ScalarEvolution::print(raw_ostream &OS) const {
       if (L) {
         OS << "\t\t" "Exits: ";
-        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
+        SCEVUse ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
         if (!SE.isLoopInvariant(ExitValue, L)) {
           OS << "<<Unknown>>";
         } else {
@@ -14016,7 +14116,7 @@ void ScalarEvolution::print(raw_ostream &OS) const {
 }

 ScalarEvolution::LoopDisposition
-ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
+ScalarEvolution::getLoopDisposition(SCEVUse S, const Loop *L) {
   auto &Values = LoopDispositions[S];
   for (auto &V : Values) {
     if (V.getPointer() == L)
@@ -14035,7 +14135,7 @@ ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
 }

 ScalarEvolution::LoopDisposition
-ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
+ScalarEvolution::computeLoopDisposition(SCEVUse S, const Loop *L) {
   switch (S->getSCEVType()) {
   case scConstant:
   case scVScale:
-      for (const auto *Op : AR->operands())
+      for (const auto Op : AR->operands())
         if (!isLoopInvariant(Op, L))
           return LoopVariant;
@@ -14083,7 +14183,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scSMinExpr:
   case scSequentialUMinExpr: {
     bool HasVarying = false;
-    for (const auto *Op : S->operands()) {
+    for (const auto Op : S->operands()) {
       LoopDisposition D = getLoopDisposition(Op, L);
       if (D == LoopVariant)
         return LoopVariant;
@@ -14106,16 +14206,16 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   llvm_unreachable("Unknown SCEV kind!");
 }

-bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
+bool ScalarEvolution::isLoopInvariant(SCEVUse S, const Loop *L) {
   return getLoopDisposition(S, L) == LoopInvariant;
 }

-bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
+bool ScalarEvolution::hasComputableLoopEvolution(SCEVUse S, const Loop *L) {
   return getLoopDisposition(S, L) == LoopComputable;
 }

 ScalarEvolution::BlockDisposition
-ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
+ScalarEvolution::getBlockDisposition(SCEVUse S, const BasicBlock *BB) {
   auto &Values = BlockDispositions[S];
   for (auto &V : Values) {
     if (V.getPointer() == BB)
@@ -14134,7 +14234,7 @@ ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
 }

 ScalarEvolution::BlockDisposition
-ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
+ScalarEvolution::computeBlockDisposition(SCEVUse S, const BasicBlock *BB) {
   switch (S->getSCEVType()) {
   case scConstant:
   case scVScale:
@@ -14164,7 +14264,7 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
   case scSMinExpr:
   case scSequentialUMinExpr: {
     bool Proper = true;
-    for (const SCEV *NAryOp : S->operands()) {
+    for (SCEVUse NAryOp : S->operands()) {
       BlockDisposition D = getBlockDisposition(NAryOp, BB);
       if (D == DoesNotDominateBlock)
         return DoesNotDominateBlock;
@@ -14189,16 +14289,16 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
   llvm_unreachable("Unknown SCEV kind!");
 }

-bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
+bool ScalarEvolution::dominates(SCEVUse S, const BasicBlock *BB) {
   return getBlockDisposition(S, BB) >= DominatesBlock;
 }

-bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
+bool ScalarEvolution::properlyDominates(SCEVUse S, const BasicBlock *BB) {
   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
 }

-bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
-  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
+bool ScalarEvolution::hasOperand(SCEVUse S, SCEVUse Op) const {
+  return SCEVExprContains(S, [&](SCEVUse Expr) { return Expr == Op; });
 }

 void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
@@ -14208,7 +14308,7 @@ void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
   auto It = BECounts.find(L);
   if (It != BECounts.end()) {
     for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
-      for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
+      for (SCEVUse S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
         if (!isa<SCEVConstant>(S)) {
           auto UserIt = BECountUsers.find(S);
           assert(UserIt != BECountUsers.end());
@@ -14220,25 +14320,25 @@ void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
     }
   }
 }

-void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
-  SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(),
-                                        SCEVs.end());
-  SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
+void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
+  SmallPtrSet<SCEVUse, 8> ToForget(SCEVs.begin(), SCEVs.end());
+  SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());

   while (!Worklist.empty()) {
-    const SCEV *Curr = Worklist.pop_back_val();
+    SCEVUse Curr = Worklist.pop_back_val();
     auto Users = SCEVUsers.find(Curr);
     if (Users != SCEVUsers.end())
-      for (const auto *User : Users->second)
+      for (const auto User : Users->second)
         if (ToForget.insert(User).second)
           Worklist.push_back(User);
   }

-  for (const auto *S : ToForget)
+  for (const auto S : ToForget)
     forgetMemoizedResultsImpl(S);

   for (auto I = PredicatedSCEVRewrites.begin();
        I != PredicatedSCEVRewrites.end();) {
-    std::pair<const SCEV *, const Loop *> Entry = I->first;
+    std::pair<SCEVUse, const Loop *> Entry = I->first;
     if (ToForget.count(Entry.first))
       PredicatedSCEVRewrites.erase(I++);
     else
@@ -14246,7 +14346,7 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
   }
 }

-void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
+void ScalarEvolution::forgetMemoizedResultsImpl(SCEVUse S) {
   LoopDispositions.erase(S);
   BlockDispositions.erase(S);
   UnsignedRanges.erase(S);
@@ -14301,14 +14401,13 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
   FoldCacheUser.erase(S);
 }

-void
-ScalarEvolution::getUsedLoops(const SCEV *S,
-                              SmallPtrSetImpl<const Loop *> &LoopsUsed) {
+void ScalarEvolution::getUsedLoops(SCEVUse S,
+                                   SmallPtrSetImpl<const Loop *> &LoopsUsed) {
   struct FindUsedLoops {
     FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
         : LoopsUsed(LoopsUsed) {}
     SmallPtrSetImpl<const Loop *> &LoopsUsed;
-    bool follow(const SCEV *S) {
+    bool follow(SCEVUse S) {
       if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
         LoopsUsed.insert(AR->getLoop());
       return true;
@@ -14340,8 +14439,8 @@ void ScalarEvolution::getReachableBlocks(
     }

     if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
-      const SCEV *L = getSCEV(Cmp->getOperand(0));
-      const SCEV *R = getSCEV(Cmp->getOperand(1));
+      SCEVUse L = getSCEV(Cmp->getOperand(0));
+      SCEVUse R = getSCEV(Cmp->getOperand(1));
       if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
         Worklist.push_back(TrueBB);
         continue;
@@ -14368,15 +14467,15 @@ void ScalarEvolution::verify() const {
   struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
     SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}

-    const SCEV *visitConstant(const SCEVConstant *Constant) {
+    SCEVUse visitConstant(const SCEVConstant *Constant) {
       return SE.getConstant(Constant->getAPInt());
     }

-    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    SCEVUse visitUnknown(const SCEVUnknown *Expr) {
       return SE.getUnknown(Expr->getValue());
     }

-    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+    SCEVUse visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
       return SE.getCouldNotCompute();
     }
   };
@@ -14385,7 +14484,7 @@ void ScalarEvolution::verify() const {
   SmallPtrSet<const BasicBlock *, 16> ReachableBlocks;
   SE2.getReachableBlocks(ReachableBlocks, F);

-  auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
+  auto GetDelta = [&](SCEVUse Old, SCEVUse New) -> SCEVUse {
     if (containsUndefs(Old) || containsUndefs(New)) {
       // SCEV treats "undef" as an unknown but consistent value (i.e. it does
       // not propagate undef aggressively). This means we can (and do) fail
@@ -14396,7 +14495,7 @@ void ScalarEvolution::verify() const {
     }

     // Unless VerifySCEVStrict is set, we only compare constant deltas.
-    const SCEV *Delta = SE2.getMinusSCEV(Old, New);
+    SCEVUse Delta = SE2.getMinusSCEV(Old, New);
     if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
       return nullptr;
@@ -14418,9 +14517,9 @@ void ScalarEvolution::verify() const {
     if (It == BackedgeTakenCounts.end())
       continue;

-    auto *CurBECount =
+    auto CurBECount =
         SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
-    auto *NewBECount = SE2.getBackedgeTakenCount(L);
+    auto NewBECount = SE2.getBackedgeTakenCount(L);

     if (CurBECount == SE2.getCouldNotCompute() ||
         NewBECount == SE2.getCouldNotCompute()) {
@@ -14439,7 +14538,7 @@ void ScalarEvolution::verify() const {
         SE.getTypeSizeInBits(NewBECount->getType()))
       CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());

-    const SCEV *Delta = GetDelta(CurBECount, NewBECount);
+    SCEVUse Delta = GetDelta(CurBECount, NewBECount);
     if (Delta && !Delta->isZero()) {
       dbgs() << "Trip Count for " << *L << " Changed!\n";
       dbgs() << "Old: " << *CurBECount << "\n";
@@ -14477,9 +14576,9 @@ void ScalarEvolution::verify() const {
     if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
       if (!ReachableBlocks.contains(I->getParent()))
         continue;
-      const SCEV *OldSCEV = SCM.visit(KV.second);
-      const SCEV *NewSCEV = SE2.getSCEV(I);
-      const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
+      SCEVUse OldSCEV = SCM.visit(KV.second);
+      SCEVUse NewSCEV = SE2.getSCEV(I);
+      SCEVUse Delta = GetDelta(OldSCEV, NewSCEV);
       if (Delta && !Delta->isZero()) {
         dbgs() << "SCEV for value " << *I << " changed!\n"
                << "Old: " << *OldSCEV << "\n"
@@ -14508,7 +14607,7 @@ void ScalarEvolution::verify() const {
   // Verify integrity of SCEV users.
   for (const auto &S : UniqueSCEVs) {
-    for (const auto *Op : S.operands()) {
+    for (const auto Op : S.operands()) {
       // We do not store dependencies of constants.
       if (isa<SCEVConstant>(Op))
         continue;
@@ -14523,10 +14622,10 @@ void ScalarEvolution::verify() const {
   // Verify integrity of ValuesAtScopes users.
   for (const auto &ValueAndVec : ValuesAtScopes) {
-    const SCEV *Value = ValueAndVec.first;
+    SCEVUse Value = ValueAndVec.first;
     for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
       const Loop *L = LoopAndValueAtScope.first;
-      const SCEV *ValueAtScope = LoopAndValueAtScope.second;
+      SCEVUse ValueAtScope = LoopAndValueAtScope.second;
       if (!isa<SCEVConstant>(ValueAtScope)) {
         auto It = ValuesAtScopesUsers.find(ValueAtScope);
         if (It != ValuesAtScopesUsers.end() &&
@@ -14540,10 +14639,10 @@ void ScalarEvolution::verify() const {
   }

   for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
-    const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
+    SCEVUse ValueAtScope = ValueAtScopeAndVec.first;
     for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
       const Loop *L = LoopAndValue.first;
-      const SCEV *Value = LoopAndValue.second;
+      SCEVUse Value = LoopAndValue.second;
       assert(!isa<SCEVConstant>(Value));
       auto It = ValuesAtScopes.find(Value);
       if (It != ValuesAtScopes.end() &&
@@ -14561,7 +14660,7 @@ void ScalarEvolution::verify() const {
       Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
   for (const auto &LoopAndBEInfo : BECounts) {
     for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
-      for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
+      for (SCEVUse S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
         if (!isa<SCEVConstant>(S)) {
           auto UserIt = BECountUsers.find(S);
           if (UserIt != BECountUsers.end() &&
@@ -14736,22 +14835,22 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }

-const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
-                                                        const SCEV *RHS) {
+const SCEVPredicate *ScalarEvolution::getEqualPredicate(SCEVUse LHS,
+                                                        SCEVUse RHS) {
   return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
 }

 const SCEVPredicate *
 ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
-                                     const SCEV *LHS, const SCEV *RHS) {
+                                     SCEVUse LHS, SCEVUse RHS) {
   FoldingSetNodeID ID;
   assert(LHS->getType() == RHS->getType() &&
          "Type mismatch between LHS and RHS");
   // Unique this node based on the arguments
   ID.AddInteger(SCEVPredicate::P_Compare);
   ID.AddInteger(Pred);
-  ID.AddPointer(LHS);
-  ID.AddPointer(RHS);
+  ID.AddPointer(LHS.getRawPointer());
+  ID.AddPointer(RHS.getRawPointer());
   void *IP = nullptr;
   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
     return S;
@@ -14767,6 +14866,7 @@ const SCEVPredicate *ScalarEvolution::getWrapPredicate(
   FoldingSetNodeID ID;
   // Unique this node based on the arguments
   ID.AddInteger(SCEVPredicate::P_Wrap);
+  // TODO: Use SCEVUse
   ID.AddPointer(AR);
   ID.AddInteger(AddedFlags);
   void *IP = nullptr;
@@ -14791,14 +14891,14 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
   ///
   /// If \p NewPreds is non-null, rewrite is free to add further predicates to
   /// \p NewPreds such that the result will be an AddRecExpr.
-  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
-                             SmallVectorImpl<const SCEVPredicate *> *NewPreds,
-                             const SCEVPredicate *Pred) {
+  static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE,
+                         SmallVectorImpl<const SCEVPredicate *> *NewPreds,
+                         const SCEVPredicate *Pred) {
     SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
     return Rewriter.visit(S);
   }

-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+  SCEVUse visitUnknown(const SCEVUnknown *Expr) {
     if (Pred) {
       if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
         for (const auto *Pred : U->getPredicates())
@@ -14815,13 +14915,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
     return convertToAddRecWithPreds(Expr);
   }

-  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
-    const SCEV *Operand = visit(Expr->getOperand());
+  SCEVUse visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    SCEVUse Operand = visit(Expr->getOperand());
     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
     if (AR && AR->getLoop() == L && AR->isAffine()) {
       // This couldn't be folded because the operand didn't have the nuw
       // flag. Add the nusw flag as an assumption that we could make.
-      const SCEV *Step = AR->getStepRecurrence(SE);
+      SCEVUse Step = AR->getStepRecurrence(SE);
       Type *Ty = Expr->getType();
       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
         return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
@@ -14831,13 +14931,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
     return SE.getZeroExtendExpr(Operand, Expr->getType());
   }

-  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
-    const SCEV *Operand = visit(Expr->getOperand());
+  SCEVUse visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    SCEVUse Operand = visit(Expr->getOperand());
     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
     if (AR && AR->getLoop() == L && AR->isAffine()) {
       // This couldn't be folded because the operand didn't have the nsw
       // flag. Add the nssw flag as an assumption that we could make.
-      const SCEV *Step = AR->getStepRecurrence(SE);
+      SCEVUse Step = AR->getStepRecurrence(SE);
       Type *Ty = Expr->getType();
       if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
         return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
@@ -14875,11 +14975,10 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
   // If \p Expr does not meet these conditions (is not a PHI node, or we
   // couldn't create an AddRec for it, or couldn't add the predicate), we just
   // return \p Expr.
-  const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
+  SCEVUse convertToAddRecWithPreds(const SCEVUnknown *Expr) {
     if (!isa<PHINode>(Expr->getValue()))
       return Expr;
-    std::optional<
-        std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+    std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>>
         PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
     if (!PredicatedRewrite)
       return Expr;
@@ -14902,15 +15001,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {

 } // end anonymous namespace

-const SCEV *
-ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
-                                       const SCEVPredicate &Preds) {
+SCEVUse ScalarEvolution::rewriteUsingPredicate(SCEVUse S, const Loop *L,
+                                               const SCEVPredicate &Preds) {
   return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
 }

 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
-    const SCEV *S, const Loop *L,
-    SmallVectorImpl<const SCEVPredicate *> &Preds) {
+    SCEVUse S, const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
   SmallVector<const SCEVPredicate *, 4> TransformPreds;
   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
@@ -14931,9 +15028,9 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
     : FastID(ID), Kind(Kind) {}

 SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
-                                           const ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS)
-    : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
+                                           const ICmpInst::Predicate Pred,
+                                           SCEVUse LHS, SCEVUse RHS)
+    : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
   assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
   assert(LHS != RHS && "LHS and RHS are the same SCEV");
 }
@@ -15059,9 +15156,8 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
   Preds = std::make_unique<SCEVUnionPredicate>(Empty);
 }

-void ScalarEvolution::registerUser(const SCEV *User,
-                                   ArrayRef<const SCEV *> Ops) {
-  for (const auto *Op : Ops)
+void ScalarEvolution::registerUser(SCEVUse User, ArrayRef<SCEVUse> Ops) {
+  for (const auto Op : Ops)
     // We do not expect that forgetting cached data for SCEVConstants will ever
     // open any prospects for sharpening or introduce any correctness issues,
     // so we don't bother storing their dependencies.
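(Editorial sketch, not part of the patch.) registerUser() above records a reverse edge from each operand to the expression built on top of it, skipping constants, and forgetMemoizedResults() earlier in this patch walks those edges with a worklist so that forgetting one expression also invalidates everything that transitively uses it. A self-contained sketch of that invalidation scheme; `Node`, `Users`, and `forget` are illustrative names, not the LLVM API:

#include <cstdio>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {};

// Reverse-dependency map: operand -> expressions that use it.
std::unordered_map<const Node *, std::unordered_set<const Node *>> Users;

void registerUser(const Node *User, const std::vector<const Node *> &Ops) {
  for (const Node *Op : Ops)
    Users[Op].insert(User);
}

// Transitively collect everything reachable over user edges, then drop it.
void forget(const Node *Root) {
  std::unordered_set<const Node *> ToForget{Root};
  std::vector<const Node *> Worklist{Root};
  while (!Worklist.empty()) {
    const Node *Curr = Worklist.back();
    Worklist.pop_back();
    auto It = Users.find(Curr);
    if (It == Users.end())
      continue;
    for (const Node *User : It->second)
      if (ToForget.insert(User).second) // only enqueue on first visit
        Worklist.push_back(User);
  }
  std::printf("would invalidate %zu node(s)\n", ToForget.size());
}

int main() {
  Node A, B, C;
  registerUser(&B, {&A}); // B is built from A
  registerUser(&C, {&B}); // C is built from B
  forget(&A);             // invalidates A, B, and C
  return 0;
}

Skipping constants in the real code is a space optimization: forgetting a constant never enables a sharper result, so its user edges are never needed.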
@@ -15069,8 +15165,8 @@ void ScalarEvolution::registerUser(const SCEV *User,
     SCEVUsers[Op].insert(User);
 }

-const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
-  const SCEV *Expr = SE.getSCEV(V);
+SCEVUse PredicatedScalarEvolution::getSCEV(Value *V) {
+  SCEVUse Expr = SE.getSCEV(V);
   RewriteEntry &Entry = RewriteMap[Expr];

   // If we already have an entry and the version matches, return it.
@@ -15082,13 +15178,13 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   if (Entry.second)
     Expr = Entry.second;

-  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
+  SCEVUse NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
   Entry = {Generation, NewSCEV};

   return NewSCEV;
 }

-const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
+SCEVUse PredicatedScalarEvolution::getBackedgeTakenCount() {
   if (!BackedgeCount) {
     SmallVector<const SCEVPredicate *, 4> Preds;
     BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
@@ -15137,7 +15233,7 @@ void PredicatedScalarEvolution::updateGeneration() {
   // If the generation number wrapped recompute everything.
   if (++Generation == 0) {
     for (auto &II : RewriteMap) {
-      const SCEV *Rewritten = II.second.second;
+      SCEVUse Rewritten = II.second.second;
       II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
     }
   }
@@ -15145,7 +15241,7 @@ void PredicatedScalarEvolution::setNoOverflow(
     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
-  const SCEV *Expr = getSCEV(V);
+  SCEVUse Expr = getSCEV(V);
   const auto *AR = cast<SCEVAddRecExpr>(Expr);

   auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
@@ -15161,7 +15257,7 @@ bool PredicatedScalarEvolution::hasNoOverflow(
     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
-  const SCEV *Expr = getSCEV(V);
+  SCEVUse Expr = getSCEV(V);
   const auto *AR = cast<SCEVAddRecExpr>(Expr);

   Flags = SCEVWrapPredicate::clearFlags(
@@ -15176,7 +15272,7 @@ bool PredicatedScalarEvolution::hasNoOverflow(
 }

 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
-  const SCEV *Expr = this->getSCEV(V);
+  SCEVUse Expr = this->getSCEV(V);
   SmallVector<const SCEVPredicate *, 4> NewPreds;
   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
@@ -15206,7 +15302,7 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
     if (!SE.isSCEVable(I.getType()))
       continue;

-    auto *Expr = SE.getSCEV(&I);
+    auto Expr = SE.getSCEV(&I);
     auto II = RewriteMap.find(Expr);

     if (II == RewriteMap.end())
@@ -15227,8 +15323,7 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
 // for URem with constant power-of-2 second operands.
 // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
 // 4, A / B becomes X / 8).
-bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
-                                const SCEV *&RHS) {
+bool ScalarEvolution::matchURem(SCEVUse Expr, SCEVUse &LHS, SCEVUse &RHS) {
   if (Expr->getType()->isPointerTy())
     return false;
@@ -15253,13 +15348,13 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
   if (Add == nullptr || Add->getNumOperands() != 2)
     return false;

-  const SCEV *A = Add->getOperand(1);
+  SCEVUse A = Add->getOperand(1);
   const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));

   if (Mul == nullptr)
     return false;

-  const auto MatchURemWithDivisor = [&](const SCEV *B) {
+  const auto MatchURemWithDivisor = [&](SCEVUse B) {
     // (SomeExpr + (-(SomeExpr / B) * B)).
     if (Expr == getURemExpr(A, B)) {
       LHS = A;
@@ -15361,10 +15456,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     const BasicBlock *Block, const BasicBlock *Pred,
     SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
   SmallVector<const SCEV *> ExprsToRewrite;
-  auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
-                              const SCEV *RHS,
-                              DenseMap<const SCEV *, const SCEV *>
-                                  &RewriteMap) {
+  auto CollectCondition = [&](ICmpInst::Predicate Predicate, SCEVUse LHS,
+                              SCEVUse RHS,
+                              DenseMap<SCEVUse, SCEVUse> &RewriteMap) {
     // WARNING: It is generally unsound to apply any wrap flags to the proposed
     // replacement SCEV which isn't directly implied by the structure of that
     // SCEV. In particular, using contextual facts to imply flags is *NOT*
@@ -15399,7 +15493,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
       if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
         return false;
       auto I = RewriteMap.find(LHSUnknown);
-      const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
+      SCEVUse RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
       RewriteMap[LHSUnknown] = SE.getUMaxExpr(
           SE.getConstant(ExactRegion.getUnsignedMin()),
           SE.getUMinExpr(RewrittenLHS,
@@ -15414,8 +15508,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
     // the non-constant operand and in \p LHS the constant operand.
     auto IsMinMaxSCEVWithNonNegativeConstant =
-        [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
-            const SCEV *&RHS) {
+        [&](SCEVUse Expr, SCEVTypes &SCTy, SCEVUse &LHS, SCEVUse &RHS) {
           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
             if (MinMax->getNumOperands() != 2)
               return false;
@@ -15433,7 +15526,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // Checks whether Expr is a non-negative constant, and Divisor is a positive
     // constant, and returns their APInt in ExprVal and in DivisorVal.
-    auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
+    auto GetNonNegExprAndPosDivisor = [&](SCEVUse Expr, SCEVUse Divisor,
                                           APInt &ExprVal, APInt &DivisorVal) {
       auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
       auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
@@ -15447,8 +15540,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // Return a new SCEV that modifies \p Expr to the closest number divides by
     // \p Divisor and greater or equal than Expr.
     // For now, only handle constant Expr and Divisor.
-    auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                           const SCEV *Divisor) {
+    auto GetNextSCEVDividesByDivisor = [&](SCEVUse Expr, SCEVUse Divisor) {
       APInt ExprVal;
       APInt DivisorVal;
       if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
@@ -15463,8 +15555,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // Return a new SCEV that modifies \p Expr to the closest number divides by
     // \p Divisor and less or equal than Expr.
     // For now, only handle constant Expr and Divisor.
-    auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                               const SCEV *Divisor) {
+    auto GetPreviousSCEVDividesByDivisor = [&](SCEVUse Expr, SCEVUse Divisor) {
       APInt ExprVal;
       APInt DivisorVal;
       if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
@@ -15477,10 +15568,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
     // recursively. This is done by aligning up/down the constant value to the
     // Divisor.
-    std::function<const SCEV *(const SCEV *, const SCEV *)>
-        ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
-                                           const SCEV *Divisor) {
-          const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
+    std::function<SCEVUse(SCEVUse, SCEVUse)> ApplyDivisibiltyOnMinMaxExpr =
+        [&](SCEVUse MinMaxExpr, SCEVUse Divisor) {
+          SCEVUse MinMaxLHS = nullptr, MinMaxRHS = nullptr;
           SCEVTypes SCTy;
           if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
                                                    MinMaxRHS))
@@ -15489,10 +15579,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
               isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
           assert(SE.isKnownNonNegative(MinMaxLHS) &&
                  "Expected non-negative operand!");
-          auto *DivisibleExpr =
+          auto DivisibleExpr =
               IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
                     : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
-          SmallVector<const SCEV *> Ops = {
+          SmallVector<SCEVUse> Ops = {
              ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
           return SE.getMinMaxExpr(SCTy, Ops);
         };
@@ -15502,15 +15592,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
       // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
       // explicitly express that.
-      const SCEV *URemLHS = nullptr;
-      const SCEV *URemRHS = nullptr;
+      SCEVUse URemLHS = nullptr;
+      SCEVUse URemRHS = nullptr;
       if (SE.matchURem(LHS, URemLHS, URemRHS)) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
           auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
+          SCEVUse RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
           RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
+          const auto Multiple =
              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
           RewriteMap[LHSUnknown] = Multiple;
           ExprsToRewrite.push_back(LHSUnknown);
@@ -15533,8 +15622,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // and \p FromRewritten are the same (i.e. there has been no rewrite
     // registered for \p From), then puts this value in the list of rewritten
     // expressions.
-    auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
-                          const SCEV *To) {
+    auto AddRewrite = [&](SCEVUse From, SCEVUse FromRewritten, SCEVUse To) {
       if (From == FromRewritten)
         ExprsToRewrite.push_back(From);
       RewriteMap[From] = To;
@@ -15543,7 +15631,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // Checks whether \p S has already been rewritten. In that case returns the
     // existing rewrite because we want to chain further rewrites onto the
     // already rewritten value. Otherwise returns \p S.
-    auto GetMaybeRewritten = [&](const SCEV *S) {
+    auto GetMaybeRewritten = [&](SCEVUse S) {
       auto I = RewriteMap.find(S);
       return I != RewriteMap.end() ? I->second : S;
     };
@@ -15555,13 +15643,13 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
    // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
    // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
    // DividesBy.
-    std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
-        [&](const SCEV *Expr, const SCEV *&DividesBy) {
+    std::function<bool(SCEVUse, SCEVUse &)> HasDivisibiltyInfo =
+        [&](SCEVUse Expr, SCEVUse &DividesBy) {
          if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
            if (Mul->getNumOperands() != 2)
              return false;
-           auto *MulLHS = Mul->getOperand(0);
-           auto *MulRHS = Mul->getOperand(1);
+           auto MulLHS = Mul->getOperand(0);
+           auto MulRHS = Mul->getOperand(1);
            if (isa<SCEVConstant>(MulLHS))
              std::swap(MulLHS, MulRHS);
            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
@@ -15577,8 +15665,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
         };

     // Return true if Expr known to divide by \p DividesBy.
-    std::function<bool(const SCEV *, const SCEV *)> IsKnownToDivideBy =
-        [&](const SCEV *Expr, const SCEV *DividesBy) {
+    std::function<bool(SCEVUse, SCEVUse)> IsKnownToDivideBy =
+        [&](SCEVUse Expr, SCEVUse DividesBy) {
          if (SE.getURemExpr(Expr, DividesBy)->isZero())
            return true;
          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
@@ -15587,8 +15675,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
          return false;
        };

-    const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
-    const SCEV *DividesBy = nullptr;
+    SCEVUse RewrittenLHS = GetMaybeRewritten(LHS);
+    SCEVUse DividesBy = nullptr;
     if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
       // Check that the whole expression is divided by DividesBy
       DividesBy =
@@ -15605,50 +15693,50 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // We cannot express strict predicates in SCEV, so instead we replace them
     // with non-strict ones against plus or minus one of RHS depending on the
     // predicate.
-    const SCEV *One = SE.getOne(RHS->getType());
+    SCEVUse One = SE.getOne(RHS->getType());
     switch (Predicate) {
-      case CmpInst::ICMP_ULT:
-        if (RHS->getType()->isPointerTy())
-          return;
-        RHS = SE.getUMaxExpr(RHS, One);
-        [[fallthrough]];
-      case CmpInst::ICMP_SLT: {
-        RHS = SE.getMinusSCEV(RHS, One);
-        RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-        break;
-      }
-      case CmpInst::ICMP_UGT:
-      case CmpInst::ICMP_SGT:
-        RHS = SE.getAddExpr(RHS, One);
-        RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-        break;
-      case CmpInst::ICMP_ULE:
-      case CmpInst::ICMP_SLE:
-        RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-        break;
-      case CmpInst::ICMP_UGE:
-      case CmpInst::ICMP_SGE:
-        RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-        break;
-      default:
-        break;
+    case CmpInst::ICMP_ULT:
+      if (RHS->getType()->isPointerTy())
+        return;
+      RHS = SE.getUMaxExpr(RHS, One);
+      [[fallthrough]];
+    case CmpInst::ICMP_SLT: {
+      RHS = SE.getMinusSCEV(RHS, One);
+      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+      break;
+    }
+    case CmpInst::ICMP_UGT:
+    case CmpInst::ICMP_SGT:
+      RHS = SE.getAddExpr(RHS, One);
+      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+      break;
+    case CmpInst::ICMP_ULE:
+    case CmpInst::ICMP_SLE:
+      RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+      break;
+    case CmpInst::ICMP_UGE:
+    case CmpInst::ICMP_SGE:
+      RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+      break;
+    default:
+      break;
     }

-    SmallVector<const SCEV *, 16> Worklist(1, LHS);
-    SmallPtrSet<const SCEV *, 16> Visited;
+    SmallVector<SCEVUse, 16> Worklist(1, LHS);
+    SmallPtrSet<SCEVUse, 16> Visited;

     auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
       append_range(Worklist, S->operands());
     };

     while (!Worklist.empty()) {
-      const SCEV *From = Worklist.pop_back_val();
+      SCEVUse From = Worklist.pop_back_val();
       if (isa<SCEVConstant>(From))
         continue;
       if (!Visited.insert(From).second)
         continue;
-      const SCEV *FromRewritten = GetMaybeRewritten(From);
-      const SCEV *To = nullptr;
+      SCEVUse FromRewritten = GetMaybeRewritten(From);
+      SCEVUse To = nullptr;

       switch (Predicate) {
       case CmpInst::ICMP_ULT:
@@ -15766,8 +15854,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
       auto Predicate =
           EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
-      const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
-      const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
+      const auto LHS = SE.getSCEV(Cmp->getOperand(0));
+      const auto RHS = SE.getSCEV(Cmp->getOperand(1));
       CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
       continue;
     }
@@ -15799,7 +15887,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
   // sub-expressions.
   if (ExprsToRewrite.size() > 1) {
     for (const SCEV *Expr : ExprsToRewrite) {
-      const SCEV *RewriteTo = Guards.RewriteMap[Expr];
+      SCEVUse RewriteTo = Guards.RewriteMap[Expr];
       Guards.RewriteMap.erase(Expr);
       Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
     }
@@ -15812,7 +15900,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
   /// replacement is loop invariant in the loop of the AddRec.
   class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
-    const DenseMap<const SCEV *, const SCEV *> &Map;
+    const DenseMap<SCEVUse, SCEVUse> &Map;

     SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
@@ -15841,12 +15929,12 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
       // If we didn't find the exact ZExt expr in the map, check if there's
      // an entry for a smaller ZExt we can use instead.
       Type *Ty = Expr->getType();
-      const SCEV *Op = Expr->getOperand(0);
+      SCEVUse Op = Expr->getOperand(0);
       unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
       while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
              Bitwidth > Op->getType()->getScalarSizeInBits()) {
         Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
-        auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
+        auto NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
         auto I = Map.find(NarrowExt);
         if (I != Map.end())
           return SE.getZeroExtendExpr(I->second, Ty);
@@ -15884,7 +15972,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
       SmallVector<const SCEV *> Operands;
       bool Changed = false;
-      for (const auto *Op : Expr->operands()) {
+      for (const auto Op : Expr->operands()) {
         Operands.push_back(
             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
         Changed |= Op != Operands.back();
@@ -15900,7 +15988,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
       SmallVector<const SCEV *> Operands;
       bool Changed = false;
-      for (const auto *Op : Expr->operands()) {
+      for (const auto Op : Expr->operands()) {
         Operands.push_back(
             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
         Changed |= Op != Operands.back();
@@ -15921,11 +16009,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
   return Rewriter.visit(Expr);
 }

-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+SCEVUse ScalarEvolution::applyLoopGuards(SCEVUse Expr, const Loop *L) {
   return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
 }

-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
-                                             const LoopGuards &Guards) {
+SCEVUse ScalarEvolution::applyLoopGuards(SCEVUse Expr,
+                                         const LoopGuards &Guards) {
   return Guards.rewrite(Expr);
 }
diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
index 800b96c45aecc..1e2ab97446e59 100644
--- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
+++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp
@@ -563,6 +563,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStoresForCommoningChains(
     const SCEV *BaseSCEV =
         ChainIdx ? SE->getAddExpr(Bucket.BaseSCEV,
                                   Bucket.Elements[BaseElemIdx].Offset)
+                       .getPointer()
                 : Bucket.BaseSCEV;
     const SCEVAddRecExpr *BasePtrSCEV = cast<SCEVAddRecExpr>(BaseSCEV);
@@ -596,6 +597,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStoresForCommoningChains(
       const SCEV *OffsetSCEV =
           BaseElemIdx ? SE->getMinusSCEV(Bucket.Elements[Idx].Offset,
                                          Bucket.Elements[BaseElemIdx].Offset)
+                            .getPointer()
                      : Bucket.Elements[Idx].Offset;
       // Make sure offset is able to expand. Only need to check one time as the
       // Elements are sorted by offset.
diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
index e706a6f83b1e7..191fc08cc0cb8 100644
--- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -424,8 +424,8 @@ bool InductiveRangeCheck::reassociateSubLHS(
   auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp,
                                      const SCEV *LHS,
                                      const SCEV *RHS) -> const SCEV * {
-    const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
-                                              SCEV::NoWrapFlags, unsigned);
+    SCEVUse (ScalarEvolution::*Operation)(SCEVUse, SCEVUse, SCEV::NoWrapFlags,
+                                          unsigned);
     switch (BinOp) {
     default:
       llvm_unreachable("Unsupported binary op");
@@ -766,7 +766,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
   const SCEV *Zero = SE.getZero(M->getType());

   // This function returns SCEV equal to 1 if X is non-negative 0 otherwise.
-  auto SCEVCheckNonNegative = [&](const SCEV *X) {
+  auto SCEVCheckNonNegative = [&](const SCEV *X) -> const SCEV * {
     const Loop *L = IndVar->getLoop();
     const SCEV *Zero = SE.getZero(X->getType());
     const SCEV *One = SE.getOne(X->getType());
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 05cf638d3f09d..e43ae298e8757 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -846,8 +846,8 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
     return false;
   }

-  const SCEV *PointerStrideSCEV = Ev->getOperand(1);
-  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
+  SCEVUse PointerStrideSCEV = Ev->getOperand(1);
+  SCEVUse MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
   if (!PointerStrideSCEV || !MemsetSizeSCEV)
     return false;
@@ -889,9 +889,9 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
     // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV
     IsNegStride = PointerStrideSCEV->isNonConstantNegative();
-    const SCEV *PositiveStrideSCEV =
-        IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV)
-                    : PointerStrideSCEV;
+    SCEVUse PositiveStrideSCEV = IsNegStride
+                                     ? SE->getNegativeSCEV(PointerStrideSCEV)
+                                     : PointerStrideSCEV;
     LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
                       << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
                       << "\n");
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index d51d043f9f0d9..9cab0774d045c 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -3444,8 +3444,9 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // be signed.
       const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
       Accum = SE.getAddExpr(Accum, IncExpr);
-      LeftOverExpr = LeftOverExpr ?
-        SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
+      LeftOverExpr = LeftOverExpr
+                         ? SE.getAddExpr(LeftOverExpr, IncExpr).getPointer()
+                         : IncExpr;
     }

   // Look through each base to see if any can produce a nice addressing mode.
@@ -3846,7 +3847,7 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
     for (const SCEV *S : Add->operands()) {
       const SCEV *Remainder = CollectSubexprs(S, C, Ops, L, SE, Depth+1);
       if (Remainder)
-        Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
+        Ops.push_back(C ? SE.getMulExpr(C, Remainder).getPointer() : Remainder);
     }
     return nullptr;
   } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
@@ -3859,7 +3860,7 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
     // Split the non-zero AddRec unless it is part of a nested recurrence that
     // does not pertain to this loop.
     if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) {
-      Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
+      Ops.push_back(C ? SE.getMulExpr(C, Remainder).getPointer() : Remainder);
       Remainder = nullptr;
     }

     if (Remainder != AR->getStart()) {
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 791d528823972..70dba295e6ceb 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -509,8 +509,8 @@ class LoopCompare {

 Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   // Recognize the canonical representation of an unsimplified urem.
-  const SCEV *URemLHS = nullptr;
-  const SCEV *URemRHS = nullptr;
+  SCEVUse URemLHS = nullptr;
+  SCEVUse URemRHS = nullptr;
   if (SE.matchURem(S, URemLHS, URemRHS)) {
     Value *LHS = expand(URemLHS);
     Value *RHS = expand(URemRHS);
diff --git a/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll b/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll
index 326ee75e135b0..bc9e8004ec5df 100644
--- a/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll
+++ b/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll
@@ -846,11 +846,9 @@ failed:
 define i32 @test_litter_conditions_constant(i32 %start, i32 %len) {
 ; CHECK-LABEL: @test_litter_conditions_constant(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[START:%.*]], -1
-; CHECK-NEXT:    [[RANGE_CHECK_FIRST_ITER:%.*]] = icmp ult i32 [[TMP0]], [[LEN:%.*]]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ [[START]], [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ [[START:%.*]], [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ]
 ; CHECK-NEXT:    [[CANONICAL_IV:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[CANONICAL_IV_NEXT:%.*]], [[BACKEDGE]] ]
 ; CHECK-NEXT:    [[CONSTANT_CHECK:%.*]] = icmp ult i32 [[CANONICAL_IV]], 65635
 ; CHECK-NEXT:    br i1 [[CONSTANT_CHECK]], label [[CONSTANT_CHECK_PASSED:%.*]], label [[CONSTANT_CHECK_FAILED:%.*]]
@@ -860,8 +858,10 @@ define i32 @test_litter_conditions_constant(i32 %start, i32 %len) {
 ; CHECK-NEXT:    [[AND_1:%.*]] = and i1 [[ZERO_CHECK]], [[FAKE_1]]
 ; CHECK-NEXT:    br i1 [[AND_1]], label [[RANGE_CHECK_BLOCK:%.*]], label [[FAILED_1:%.*]]
 ; CHECK:       range_check_block:
+; CHECK-NEXT:    [[IV_MINUS_1:%.*]] = add i32 [[IV]], -1
+; CHECK-NEXT:    [[RANGE_CHECK:%.*]] = icmp ult i32 [[IV_MINUS_1]], [[LEN:%.*]]
 ; CHECK-NEXT:    [[FAKE_2:%.*]] = call i1 @cond()
-; CHECK-NEXT:    [[AND_2:%.*]] = and i1 [[RANGE_CHECK_FIRST_ITER]], [[FAKE_2]]
+; CHECK-NEXT:    [[AND_2:%.*]] = and i1 [[RANGE_CHECK]], [[FAKE_2]]
 ; CHECK-NEXT:    br i1 [[AND_2]], label [[BACKEDGE]], label [[FAILED_2:%.*]]
 ; CHECK:       backedge:
 ; CHECK-NEXT:    [[IV_NEXT]] = add i32 [[IV]], -1
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index c72cecbba3cb8..d39bd8f343db0 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -63,10 +63,10 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
     return SE.computeConstantDifference(LHS, RHS);
   }
-  static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
-                        const SCEV *&RHS) {
-    return SE.matchURem(Expr, LHS, RHS);
-  }
+static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, SCEVUse &LHS,
+                      SCEVUse &RHS) {
+  return SE.matchURem(Expr, LHS, RHS);
+}

 static bool isImpliedCond(
     ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
@@ -1521,8 +1521,8 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
   runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
     for (auto *N : {"rem1", "rem2", "rem3", "rem5"}) {
       auto *URemI = getInstructionByName(F, N);
-      auto *S = SE.getSCEV(URemI);
-      const SCEV *LHS, *RHS;
+      const SCEV *S = SE.getSCEV(URemI);
+      SCEVUse LHS, RHS;
       EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
       EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
       EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
@@ -1534,8 +1534,8 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
     // match results are extended to the size of the input expression.
     auto *Ext = getInstructionByName(F, "ext");
     auto *URem1 = getInstructionByName(F, "rem4");
-    auto *S = SE.getSCEV(Ext);
-    const SCEV *LHS, *RHS;
+    const SCEV *S = SE.getSCEV(Ext);
+    SCEVUse LHS, RHS;
     EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
     EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
     // RHS and URem1->getOperand(1) have different widths, so compare the
@@ -1661,11 +1661,11 @@ TEST_F(ScalarEvolutionsTest, ForgetValueWithOverflowInst) {
     auto *ExtractValue = getInstructionByName(F, "extractvalue");
     auto *IV = getInstructionByName(F, "iv");

-    auto *ExtractValueScev = SE.getSCEV(ExtractValue);
+    auto ExtractValueScev = SE.getSCEV(ExtractValue);
     EXPECT_NE(ExtractValueScev, nullptr);

     SE.forgetValue(IV);
-    auto *ExtractValueScevForgotten = SE.getExistingSCEV(ExtractValue);
+    auto ExtractValueScevForgotten = SE.getExistingSCEV(ExtractValue);
     EXPECT_EQ(ExtractValueScevForgotten, nullptr);
   });
 }
@@ -1706,4 +1706,59 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
   });
 }

+TEST_F(ScalarEvolutionsTest, SCEVUseWithFlags) {
+  Type *Ty = IntegerType::get(Context, 32);
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Context), {Ty, Ty, Ty}, false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+  BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
+  ReturnInst::Create(Context, nullptr, BB);
+
+  Value *V0 = F->getArg(0);
+  Value *V1 = F->getArg(1);
+  Value *V2 = F->getArg(2);
+
+  ScalarEvolution SE = buildSE(*F);
+
+  const SCEV *S0 = SE.getSCEV(V0);
+  const SCEV *S1 = SE.getSCEV(V1);
+  const SCEV *S2 = SE.getSCEV(V2);
+
+  SCEVUse AddNoFlags = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2));
+  SCEVUse AddWithFlag2 = {AddNoFlags, 2};
+  SCEVUse MulNoFlags = SE.getMulExpr(AddNoFlags, S1);
+  SCEVUse MulFlags2 = SE.getMulExpr(AddWithFlag2, S1);
+  EXPECT_EQ(AddNoFlags.getCanonical(), AddWithFlag2.getCanonical());
+  EXPECT_EQ(MulNoFlags.getCanonical(), MulFlags2.getCanonical());
+
+  SCEVUse AddWithFlag1 = {AddNoFlags, 1};
+  SCEVUse MulFlags1 = SE.getMulExpr(AddWithFlag1, S1);
+  EXPECT_EQ(MulNoFlags.getCanonical(), MulFlags1.getCanonical());
+  EXPECT_EQ(MulFlags1.getCanonical(), MulFlags2.getCanonical());
+
+  SCEVUse AddNoFlags2 = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2));
+  EXPECT_EQ(AddNoFlags.getCanonical(), AddNoFlags2.getCanonical());
+  EXPECT_EQ(AddNoFlags2.getCanonical(), AddWithFlag2.getCanonical());
+
+  SCEVUse MulFlags22 = SE.getMulExpr(AddWithFlag2, S1);
+  EXPECT_EQ(MulFlags22.getCanonical(), MulFlags2.getCanonical());
+  EXPECT_EQ(MulNoFlags.getCanonical(), MulFlags22.getCanonical());
+
+  SCEVUse MulNoFlags2 = SE.getMulExpr(AddNoFlags, S1);
+  EXPECT_EQ(MulNoFlags.getCanonical(), MulNoFlags2.getCanonical());
+  EXPECT_EQ(MulNoFlags2.getCanonical(), MulFlags2.getCanonical());
+  EXPECT_EQ(MulNoFlags2.getCanonical(), MulFlags22.getCanonical());
+
+  SE.getAddExpr(MulNoFlags, S2);
+  SE.getAddExpr(MulFlags1, S2);
+  SE.getAddExpr(MulFlags2, S2);
+  SCEVUse AddMulNoFlags = SE.getAddExpr(MulNoFlags, S2);
+  SCEVUse AddMulFlags1 = SE.getAddExpr(MulFlags1, S2);
+  SCEVUse AddMulFlags2 = SE.getAddExpr(MulFlags2, S2);
+
+  EXPECT_EQ(AddMulNoFlags.getCanonical(), AddMulFlags1.getCanonical());
+  EXPECT_EQ(AddMulNoFlags.getCanonical(), AddMulFlags2.getCanonical());
+  EXPECT_EQ(AddMulFlags1.getCanonical(), AddMulFlags2.getCanonical());
+}
+
 } // end namespace llvm
diff --git a/polly/include/polly/Support/ScopHelper.h b/polly/include/polly/Support/ScopHelper.h
index 13852ecb18ee7..7ec5abd5dc22f 100644
--- a/polly/include/polly/Support/ScopHelper.h
+++ b/polly/include/polly/Support/ScopHelper.h
@@ -14,6 +14,7 @@
 #define POLLY_SUPPORT_IRHELPER_H

 #include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/ValueHandle.h"
@@ -37,7 +38,7 @@ class Scop;
 class ScopStmt;

 /// Same as llvm/Analysis/ScalarEvolutionExpressions.h
-using LoopToScevMapT = llvm::DenseMap<llvm::Loop *, const llvm::SCEV *>;
+using LoopToScevMapT = llvm::DenseMap<llvm::Loop *, llvm::SCEVUse>;

 /// Enumeration of assumptions Polly can take.
 enum AssumptionKind {
diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp
index 6d50e297ef715..5f6c24e9277e1 100644
--- a/polly/lib/Support/ScopHelper.cpp
+++ b/polly/lib/Support/ScopHelper.cpp
@@ -411,38 +411,38 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
     return GenSE.getMulExpr(NewOps);
   }
   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getUMaxExpr(NewOps);
   }
   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getSMaxExpr(NewOps);
   }
   const SCEV *visitUMinExpr(const SCEVUMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getUMinExpr(NewOps);
   }
   const SCEV *visitSMinExpr(const SCEVSMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getSMinExpr(NewOps);
   }
   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getUMinExpr(NewOps, /*Sequential=*/true);
   }
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     const Loop *L = E->getLoop();
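(Editorial sketch, not part of the patch.) The SCEVUseWithFlags test above expects uses created with different low-bit flags to resolve to the same canonical node. A minimal sketch of that equality rule, assuming each node records its flag-free canonical form; `Expr` and `Use` are illustrative types, not LLVM's:

#include <cassert>

struct Expr {
  const Expr *Canonical; // every node knows its flag-free canonical form
};

struct Use {
  const Expr *Ptr = nullptr;
  unsigned Flags = 0; // would live in the pointer's low bits in LLVM

  // Equal if the tagged values match exactly, or if both resolve to the
  // same canonical node once flags are stripped.
  friend bool operator==(const Use &A, const Use &B) {
    if (A.Ptr == B.Ptr && A.Flags == B.Flags)
      return true;
    return A.Ptr && B.Ptr && A.Ptr->Canonical == B.Ptr->Canonical;
  }
};

int main() {
  Expr Canon{nullptr};
  Canon.Canonical = &Canon; // canonical nodes point at themselves
  Expr Variant{&Canon};     // same value, created with extra flags
  assert((Use{&Canon, 0} == Use{&Variant, 2})); // equal via canonical form
  return 0;
}

This two-step comparison keeps the fast path (a single pointer compare) for the common case while still treating flag-carrying variants of one expression as interchangeable.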