Skip to content

[X86] Use NSW/NUW flags on ISD::TRUNCATE nodes to improve X86 PACKSS/PACKUS lowering #123956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20819,7 +20819,8 @@ static SDValue truncateVectorWithPACKSS(EVT DstVT, SDValue In, const SDLoc &DL,
static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
SDValue In, const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
const X86Subtarget &Subtarget,
const SDNodeFlags Flags = SDNodeFlags()) {
// Requires SSE2.
if (!Subtarget.hasSSE2())
return SDValue();
Expand Down Expand Up @@ -20865,7 +20866,8 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
// e.g. Masks, zext_in_reg, etc.
// Pre-SSE41 we can only use PACKUSWB.
KnownBits Known = DAG.computeKnownBits(In);
if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) {
if ((Flags.hasNoUnsignedWrap() && NumDstEltBits <= NumPackedZeroBits) ||
(NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we take NSW/NUW into account when countMinLeadingZeros/ComputeNumSignBits so that other cases may benefit from it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be tricky as its the TRUNC user node that implies the min leading sign/zero count on the source node, not the source node itself. So we'd have to check for users of a node and retroactively adjust the analysis.

I'll do some experiments but I'm not confident

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First thing I've noticed - as soon as the ISD::TRUNCATE node disappears we lose this extra analysis on the upper bits of the source node, which is acceptable for the x86 PACK nodes, but might not be for other nodes that were relying on access to that information?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. Thanks for the investigation!

PackOpcode = X86ISD::PACKUS;
return In;
}
Expand All @@ -20884,7 +20886,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
return SDValue();

unsigned MinSignBits = NumSrcEltBits - NumPackedSignBits;
if (MinSignBits < NumSignBits) {
if (Flags.hasNoSignedWrap() || MinSignBits < NumSignBits) {
PackOpcode = X86ISD::PACKSS;
return In;
}
Expand All @@ -20906,10 +20908,9 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
/// This function lowers a vector truncation of 'extended sign-bits' or
/// 'extended zero-bits' values.
/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
const SDLoc &DL,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
static SDValue LowerTruncateVecPackWithSignBits(
MVT DstVT, SDValue In, const SDLoc &DL, const X86Subtarget &Subtarget,
SelectionDAG &DAG, const SDNodeFlags Flags = SDNodeFlags()) {
MVT SrcVT = In.getSimpleValueType();
MVT DstSVT = DstVT.getVectorElementType();
MVT SrcSVT = SrcVT.getVectorElementType();
Expand All @@ -20931,8 +20932,8 @@ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
}

unsigned PackOpcode;
if (SDValue Src =
matchTruncateWithPACK(PackOpcode, DstVT, In, DL, DAG, Subtarget))
if (SDValue Src = matchTruncateWithPACK(PackOpcode, DstVT, In, DL, DAG,
Subtarget, Flags))
return truncateVectorWithPACK(PackOpcode, DstVT, Src, DL, DAG, Subtarget);

return SDValue();
Expand Down Expand Up @@ -21102,8 +21103,8 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
// Pre-AVX512 (or prefer-256bit) see if we can make use of PACKSS/PACKUS.
if (!Subtarget.hasAVX512() ||
(InVT.is512BitVector() && VT.is256BitVector()))
if (SDValue SignPack =
LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
if (SDValue SignPack = LowerTruncateVecPackWithSignBits(
VT, In, DL, Subtarget, DAG, Op->getFlags()))
return SignPack;

// Pre-AVX512 see if we can make use of PACKSS/PACKUS.
Expand All @@ -21120,8 +21121,8 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
// Attempt to truncate with PACKUS/PACKSS even on AVX512 if we'd have to
// concat from subvectors to use VPTRUNC etc.
if (!Subtarget.hasAVX512() || isFreeToSplitVector(In.getNode(), DAG))
if (SDValue SignPack =
LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
if (SDValue SignPack = LowerTruncateVecPackWithSignBits(
VT, In, DL, Subtarget, DAG, Op->getFlags()))
return SignPack;

// vpmovqb/w/d, vpmovdb/w, vpmovwb
Expand Down Expand Up @@ -33578,10 +33579,10 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,

// See if there are sufficient leading bits to perform a PACKUS/PACKSS.
unsigned PackOpcode;
if (SDValue Src =
matchTruncateWithPACK(PackOpcode, VT, In, dl, DAG, Subtarget)) {
if (SDValue Res = truncateVectorWithPACK(PackOpcode, VT, Src,
dl, DAG, Subtarget)) {
if (SDValue Src = matchTruncateWithPACK(PackOpcode, VT, In, dl, DAG,
Subtarget, N->getFlags())) {
if (SDValue Res =
truncateVectorWithPACK(PackOpcode, VT, Src, dl, DAG, Subtarget)) {
Res = widenSubVector(WidenVT, Res, false, Subtarget, DAG, dl);
Results.push_back(Res);
return;
Expand Down
Loading
Loading