diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index aed10c2de4372..63ecb65457868 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3664,32 +3664,6 @@ static bool IsVMerge(SDNode *N) {
   return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
 }
 
-static bool IsVMv(SDNode *N) {
-  return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V;
-}
-
-static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
-  switch (LMUL) {
-  case RISCVII::LMUL_F8:
-    return RISCV::PseudoVMSET_M_B1;
-  case RISCVII::LMUL_F4:
-    return RISCV::PseudoVMSET_M_B2;
-  case RISCVII::LMUL_F2:
-    return RISCV::PseudoVMSET_M_B4;
-  case RISCVII::LMUL_1:
-    return RISCV::PseudoVMSET_M_B8;
-  case RISCVII::LMUL_2:
-    return RISCV::PseudoVMSET_M_B16;
-  case RISCVII::LMUL_4:
-    return RISCV::PseudoVMSET_M_B32;
-  case RISCVII::LMUL_8:
-    return RISCV::PseudoVMSET_M_B64;
-  case RISCVII::LMUL_RESERVED:
-    llvm_unreachable("Unexpected LMUL");
-  }
-  llvm_unreachable("Unknown VLMUL enum");
-}
-
 // Try to fold away VMERGE_VVM instructions into their true operands:
 //
 // %true = PseudoVADD_VV ...
@@ -3704,35 +3678,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
 // If %true is masked, then we can use its mask instead of vmerge's if vmerge's
 // mask is all ones.
 //
-// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
-// VMERGE_VVM with an all ones mask.
-//
 // The resulting VL is the minimum of the two VLs.
 //
 // The resulting policy is the effective policy the vmerge would have had,
 // i.e. whether or not it's passthru operand was implicit-def.
 bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
   SDValue Passthru, False, True, VL, Mask, Glue;
-  // A vmv.v.v is equivalent to a vmerge with an all-ones mask.
-  if (IsVMv(N)) {
-    Passthru = N->getOperand(0);
-    False = N->getOperand(0);
-    True = N->getOperand(1);
-    VL = N->getOperand(2);
-    // A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
-    // mask later below.
-  } else {
-    assert(IsVMerge(N));
-    Passthru = N->getOperand(0);
-    False = N->getOperand(1);
-    True = N->getOperand(2);
-    Mask = N->getOperand(3);
-    VL = N->getOperand(4);
-    // We always have a glue node for the mask at v0.
-    Glue = N->getOperand(N->getNumOperands() - 1);
-  }
-  assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
-  assert(!Glue || Glue.getValueType() == MVT::Glue);
+  assert(IsVMerge(N));
+  Passthru = N->getOperand(0);
+  False = N->getOperand(1);
+  True = N->getOperand(2);
+  Mask = N->getOperand(3);
+  VL = N->getOperand(4);
+  // We always have a glue node for the mask at v0.
+  Glue = N->getOperand(N->getNumOperands() - 1);
+  assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
+  assert(Glue.getValueType() == MVT::Glue);
 
   // If the EEW of True is different from vmerge's SEW, then we can't fold.
   if (True.getSimpleValueType() != N->getSimpleValueType(0))
@@ -3780,7 +3741,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
 
   // If True is masked then the vmerge must have either the same mask or an all
   // 1s mask, since we're going to keep the mask from True.
-  if (IsMasked && Mask) {
+  if (IsMasked) {
     // FIXME: Support mask agnostic True instruction which would have an
     // undef passthru operand.
     SDValue TrueMask =
@@ -3810,11 +3771,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
     SmallVector<const SDNode *, 4> LoopWorklist;
     SmallPtrSet<const SDNode *, 16> Visited;
     LoopWorklist.push_back(False.getNode());
-    if (Mask)
-      LoopWorklist.push_back(Mask.getNode());
+    LoopWorklist.push_back(Mask.getNode());
     LoopWorklist.push_back(VL.getNode());
-    if (Glue)
-      LoopWorklist.push_back(Glue.getNode());
+    LoopWorklist.push_back(Glue.getNode());
     if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
       return false;
   }
@@ -3875,21 +3834,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
     Glue = True->getOperand(True->getNumOperands() - 1);
     assert(Glue.getValueType() == MVT::Glue);
   }
-  // If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
-  // an all-ones mask to use.
-  else if (IsVMv(N)) {
-    unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
-    unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
-    ElementCount EC = N->getValueType(0).getVectorElementCount();
-    MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
-
-    SDValue AllOnesMask =
-        SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
-    SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
-                                            RISCV::V0, AllOnesMask, SDValue());
-    Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
-    Glue = MaskCopy.getValue(1);
-  }
 
   unsigned MaskedOpc = Info->MaskedPseudo;
 #ifndef NDEBUG
@@ -3968,7 +3912,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
     if (N->use_empty() || !N->isMachineOpcode())
       continue;
 
-    if (IsVMerge(N) || IsVMv(N))
+    if (IsVMerge(N))
       MadeChange |= performCombineVMergeAndVOps(N);
   }
   return MadeChange;
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index 979677ee92332..2abed1ac984e3 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
   bool convertToWholeRegister(MachineInstr &MI) const;
   bool convertToUnmasked(MachineInstr &MI) const;
   bool convertVMergeToVMv(MachineInstr &MI) const;
+  bool foldVMV_V_V(MachineInstr &MI);
 
   bool isAllOnesMask(const MachineInstr *MaskDef) const;
   std::optional<unsigned> getConstant(const MachineOperand &VL) const;
@@ -324,6 +325,148 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
   return true;
 }
 
+/// Given two VL operands, returns the one known to be the smallest or nullptr
+/// if unknown.
+static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
+                                           const MachineOperand *RHS) {
+  if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
+      LHS->getReg() == RHS->getReg())
+    return LHS;
+  if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
+    return RHS;
+  if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
+    return LHS;
+  if (!LHS->isImm() || !RHS->isImm())
+    return nullptr;
+  return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
+}
+
+/// Check if it's safe to move From down to To, checking that no physical
+/// registers are clobbered.
+static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
+  assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
+  SmallVector<Register> PhysUses;
+  for (const MachineOperand &MO : From.all_uses())
+    if (MO.getReg().isPhysical())
+      PhysUses.push_back(MO.getReg());
+  bool SawStore = false;
+  for (auto II = From.getIterator(); II != To.getIterator(); II++) {
+    for (Register PhysReg : PhysUses)
+      if (II->definesRegister(PhysReg, nullptr))
+        return false;
+    // Do not stop at the first store: later instructions must still be
+    // checked for clobbers of From's physical register uses.
+    if (II->mayStore())
+      SawStore = true;
+  }
+  return From.isSafeToMove(SawStore);
+}
+
+/// Returns the SEW/LMUL ratio of MI, computed from the LMUL encoded in its
+/// TSFlags and its log2(SEW) immediate operand.
+static unsigned getSEWLMULRatio(const MachineInstr &MI) {
+  RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
+  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+  return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
+}
+
+/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
+/// into it.
+///
+///   %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
+///   %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
+///
+/// ->
+///
+///   %y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy
+///
+/// Returns true if MI was folded into its source and erased.
+bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
+  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
+    return false;
+
+  MachineOperand &Passthru = MI.getOperand(1);
+
+  if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
+    return false;
+
+  MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
+  if (!Src || Src->hasUnmodeledSideEffects() ||
+      Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
+      !RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
+      !RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
+      !RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
+    return false;
+
+  // Src needs to have the same VLMAX as MI
+  if (getSEWLMULRatio(MI) != getSEWLMULRatio(*Src))
+    return false;
+
+  // Src needs to have the same passthru as VMV_V_V
+  MachineOperand &SrcPassthru = Src->getOperand(1);
+  if (SrcPassthru.getReg() != RISCV::NoRegister &&
+      SrcPassthru.getReg() != Passthru.getReg())
+    return false;
+
+  // Because Src and MI have the same passthru, we can use either AVL as long as
+  // it's the smaller of the two.
+  //
+  // (src pt, ..., vl=5)       x x x x x|. . .
+  // (vmv.v.v pt, src, vl=3)   x x x|. . . . .
+  // ->
+  // (src pt, ..., vl=3)       x x x|. . . . .
+  //
+  // (src pt, ..., vl=3)       x x x|. . . . .
+  // (vmv.v.v pt, src, vl=6)   x x x . . .|. .
+  // ->
+  // (src pt, ..., vl=3)       x x x|. . . . .
+  MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
+  const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
+  if (!MinVL)
+    return false;
+
+  bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
+  bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
+      TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
+
+  if (VLChanged && (ActiveElementsAffectResult || Src->mayRaiseFPException()))
+    return false;
+
+  // If Src ends up using MI's passthru/VL, move it so it can access it.
+  // TODO: We don't need to do this if they already dominate Src.
+  if (!SrcVL.isIdenticalTo(*MinVL) || !SrcPassthru.isIdenticalTo(Passthru)) {
+    if (!isSafeToMove(*Src, MI))
+      return false;
+    Src->moveBefore(&MI);
+  }
+
+  if (SrcPassthru.getReg() != Passthru.getReg()) {
+    SrcPassthru.setReg(Passthru.getReg());
+    // If Src is masked then its passthru needs to be in VRNoV0.
+    if (Passthru.getReg() != RISCV::NoRegister)
+      MRI->constrainRegClass(Passthru.getReg(),
+                             TII->getRegClass(Src->getDesc(), 1, TRI,
+                                              *Src->getParent()->getParent()));
+  }
+
+  if (MinVL->isImm())
+    SrcVL.ChangeToImmediate(MinVL->getImm());
+  else if (MinVL->isReg())
+    SrcVL.ChangeToRegister(MinVL->getReg(), false);
+
+  // Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
+  // passthru is undef.
+  Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
+      .setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);
+
+  MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
+  // Remove MI's V0Defs entry while MI is still alive, then delete MI itself.
+  V0Defs.erase(&MI);
+  MI.eraseFromParent();
+
+  return true;
+}
+
 bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
   if (skipFunction(MF.getFunction()))
     return false;
@@ -358,11 +501,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
   }
 
   for (MachineBasicBlock &MBB : MF) {
-    for (MachineInstr &MI : MBB) {
+    for (MachineInstr &MI : make_early_inc_range(MBB)) {
       Changed |= convertToVLMAX(MI);
       Changed |= convertToUnmasked(MI);
       Changed |= convertToWholeRegister(MI);
       Changed |= convertVMergeToVMv(MI);
+      Changed |= foldVMV_V_V(MI);
     }
   }
 