Skip to content

Commit 741b358

Browse files
committed
ISel/AArch64: custom lower vector ISD::LRINT, ISD::LLRINT
Since 98c90a1 (ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering), ISD::LRINT and ISD::LLRINT now have vector variants, that are custom lowered on RISCV, and scalarized on all other targets. Since 2302e4c (Reland "VectorUtils: mark xrint as trivially vectorizable"), lrint and llrint are trivially vectorizable, so all the vectorizers in-tree will produce vector variants when possible. Add a custom lowering for AArch64 to custom-lower the vector variants natively using a combination of frintx, fcvte, and fcvtzs.
1 parent 14774ad commit 741b358

File tree

6 files changed

+2384
-1109
lines changed

6 files changed

+2384
-1109
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
790790
setOperationAction(ISD::FROUND, V8Narrow, Legal);
791791
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Legal);
792792
setOperationAction(ISD::FRINT, V8Narrow, Legal);
793-
setOperationAction(ISD::FSQRT, V8Narrow, Expand);
793+
setOperationAction(ISD::FSQRT, V8Narrow, Expand);
794794
setOperationAction(ISD::FSUB, V8Narrow, Legal);
795795
setOperationAction(ISD::FTRUNC, V8Narrow, Legal);
796796
setOperationAction(ISD::SETCC, V8Narrow, Expand);
@@ -1147,8 +1147,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11471147

11481148
for (auto Op :
11491149
{ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
1150-
ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
1151-
ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
1150+
ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::LRINT,
1151+
ISD::LLRINT, ISD::MUL, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
11521152
ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_ROUND})
11531153
setOperationAction(Op, MVT::v1i64, Expand);
11541154

@@ -1355,6 +1355,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
13551355
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
13561356
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
13571357
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1358+
setOperationAction(ISD::LRINT, VT, Custom);
1359+
setOperationAction(ISD::LLRINT, VT, Custom);
13581360
setOperationAction(ISD::MGATHER, VT, Custom);
13591361
setOperationAction(ISD::MSCATTER, VT, Custom);
13601362
setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1420,6 +1422,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14201422
for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
14211423
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
14221424
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1425+
setOperationAction(ISD::LRINT, VT, Custom);
1426+
setOperationAction(ISD::LLRINT, VT, Custom);
14231427
}
14241428

14251429
// Legalize unpacked bitcasts to REINTERPRET_CAST.
@@ -1522,6 +1526,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15221526
setOperationAction(ISD::FFLOOR, VT, Custom);
15231527
setOperationAction(ISD::FNEARBYINT, VT, Custom);
15241528
setOperationAction(ISD::FRINT, VT, Custom);
1529+
setOperationAction(ISD::LRINT, VT, Custom);
1530+
setOperationAction(ISD::LLRINT, VT, Custom);
15251531
setOperationAction(ISD::FROUND, VT, Custom);
15261532
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
15271533
setOperationAction(ISD::FTRUNC, VT, Custom);
@@ -1785,9 +1791,9 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
17851791
setOperationAction(ISD::SREM, VT, Expand);
17861792
setOperationAction(ISD::FREM, VT, Expand);
17871793

1788-
for (unsigned Opcode :
1789-
{ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1790-
ISD::FP_TO_UINT_SAT, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1794+
for (unsigned Opcode : {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1795+
ISD::FP_TO_UINT_SAT, ISD::LRINT, ISD::LLRINT,
1796+
ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
17911797
setOperationAction(Opcode, VT, Custom);
17921798

17931799
if (!VT.isFloatingPoint())
@@ -1947,6 +1953,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT,
19471953
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
19481954
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
19491955
setOperationAction(ISD::FRINT, VT, Custom);
1956+
setOperationAction(ISD::LRINT, VT, Custom);
1957+
setOperationAction(ISD::LLRINT, VT, Custom);
19501958
setOperationAction(ISD::FROUND, VT, Custom);
19511959
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
19521960
setOperationAction(ISD::FSQRT, VT, Custom);
@@ -4371,6 +4379,54 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
43714379
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
43724380
}
43734381

4382+
SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4383+
SelectionDAG &DAG) const {
4384+
EVT VT = Op.getValueType();
4385+
SDValue Src = Op.getOperand(0);
4386+
SDLoc DL(Op);
4387+
4388+
assert(VT.isVector() && "Expected vector type");
4389+
4390+
EVT ContainerVT = VT;
4391+
EVT SrcVT = Src.getValueType();
4392+
EVT CastVT =
4393+
ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4394+
4395+
if (VT.isFixedLengthVector()) {
4396+
ContainerVT = getContainerForFixedLengthVector(DAG, VT);
4397+
CastVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4398+
Src = convertToScalableVector(DAG, CastVT, Src);
4399+
}
4400+
4401+
// First, round the floating-point value into a floating-point register with
4402+
// the current rounding mode.
4403+
SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4404+
4405+
// In the case of vector filled with f32, ftrunc will convert it to an i32,
4406+
// but a vector filled with i32 isn't legal. So, FP_EXTEND the f32 into the
4407+
// required size.
4408+
size_t SrcSz = SrcVT.getScalarSizeInBits();
4409+
size_t ContainerSz = ContainerVT.getScalarSizeInBits();
4410+
if (ContainerSz > SrcSz) {
4411+
EVT WidenedVT = MVT::getVectorVT(MVT::getFloatingPointVT(ContainerSz),
4412+
ContainerVT.getVectorElementCount());
4413+
FOp = DAG.getNode(ISD::FP_EXTEND, DL, WidenedVT, FOp.getOperand(0));
4414+
}
4415+
4416+
// Finally, truncate the rounded floating point to an integer, rounding to
4417+
// zero.
4418+
SDValue Pred = getPredicateForVector(DAG, DL, ContainerVT);
4419+
SDValue Undef = DAG.getUNDEF(ContainerVT);
4420+
SDValue Truncated =
4421+
DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, DL, ContainerVT,
4422+
{Pred, FOp.getOperand(0), Undef}, FOp->getFlags());
4423+
4424+
if (!VT.isFixedLengthVector())
4425+
return Truncated;
4426+
4427+
return convertFromScalableVector(DAG, VT, Truncated);
4428+
}
4429+
43744430
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
43754431
SelectionDAG &DAG) const {
43764432
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6628,10 +6684,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
66286684
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
66296685
case ISD::VECTOR_INTERLEAVE:
66306686
return LowerVECTOR_INTERLEAVE(Op, DAG);
6631-
case ISD::LROUND:
6632-
case ISD::LLROUND:
66336687
case ISD::LRINT:
6634-
case ISD::LLRINT: {
6688+
case ISD::LLRINT:
6689+
if (Op.getValueType().isVector())
6690+
return LowerVectorXRINT(Op, DAG);
6691+
[[fallthrough]];
6692+
case ISD::LROUND:
6693+
case ISD::LLROUND: {
66356694
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
66366695
Op.getOperand(0).getValueType() == MVT::bf16) &&
66376696
"Expected custom lowering of rounding operations only for f16");

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,7 @@ class AArch64TargetLowering : public TargetLowering {
11551155
SDValue LowerVectorFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
11561156
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
11571157
SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
1158+
SDValue LowerVectorXRINT(SDValue Op, SelectionDAG &DAG) const;
11581159
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11591160
SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11601161
SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)