@@ -790,7 +790,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
790
790
setOperationAction(ISD::FROUND, V8Narrow, Legal);
791
791
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Legal);
792
792
setOperationAction(ISD::FRINT, V8Narrow, Legal);
793
- setOperationAction(ISD::FSQRT, V8Narrow, Expand);
793
+ setOperationAction(ISD::FSQRT, V8Narrow, Expand);
794
794
setOperationAction(ISD::FSUB, V8Narrow, Legal);
795
795
setOperationAction(ISD::FTRUNC, V8Narrow, Legal);
796
796
setOperationAction(ISD::SETCC, V8Narrow, Expand);
@@ -1147,8 +1147,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1147
1147
1148
1148
for (auto Op :
1149
1149
{ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
1150
- ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL ,
1151
- ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
1150
+ ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::LRINT ,
1151
+ ISD::LLRINT, ISD::MUL, ISD:: STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
1152
1152
ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_ROUND})
1153
1153
setOperationAction(Op, MVT::v1i64, Expand);
1154
1154
@@ -1355,6 +1355,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1355
1355
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1356
1356
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1357
1357
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1358
+ setOperationAction(ISD::LRINT, VT, Custom);
1359
+ setOperationAction(ISD::LLRINT, VT, Custom);
1358
1360
setOperationAction(ISD::MGATHER, VT, Custom);
1359
1361
setOperationAction(ISD::MSCATTER, VT, Custom);
1360
1362
setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1420,6 +1422,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1420
1422
for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
1421
1423
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1422
1424
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1425
+ setOperationAction(ISD::LRINT, VT, Custom);
1426
+ setOperationAction(ISD::LLRINT, VT, Custom);
1423
1427
}
1424
1428
1425
1429
// Legalize unpacked bitcasts to REINTERPRET_CAST.
@@ -1522,6 +1526,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1522
1526
setOperationAction(ISD::FFLOOR, VT, Custom);
1523
1527
setOperationAction(ISD::FNEARBYINT, VT, Custom);
1524
1528
setOperationAction(ISD::FRINT, VT, Custom);
1529
+ setOperationAction(ISD::LRINT, VT, Custom);
1530
+ setOperationAction(ISD::LLRINT, VT, Custom);
1525
1531
setOperationAction(ISD::FROUND, VT, Custom);
1526
1532
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1527
1533
setOperationAction(ISD::FTRUNC, VT, Custom);
@@ -1785,9 +1791,9 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
1785
1791
setOperationAction(ISD::SREM, VT, Expand);
1786
1792
setOperationAction(ISD::FREM, VT, Expand);
1787
1793
1788
- for (unsigned Opcode :
1789
- { ISD::FP_TO_SINT , ISD::FP_TO_UINT , ISD::FP_TO_SINT_SAT ,
1790
- ISD::FP_TO_UINT_SAT, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1794
+ for (unsigned Opcode : {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1795
+ ISD::FP_TO_UINT_SAT , ISD::LRINT , ISD::LLRINT ,
1796
+ ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1791
1797
setOperationAction(Opcode, VT, Custom);
1792
1798
1793
1799
if (!VT.isFloatingPoint())
@@ -1947,6 +1953,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT,
1947
1953
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1948
1954
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1949
1955
setOperationAction(ISD::FRINT, VT, Custom);
1956
+ setOperationAction(ISD::LRINT, VT, Custom);
1957
+ setOperationAction(ISD::LLRINT, VT, Custom);
1950
1958
setOperationAction(ISD::FROUND, VT, Custom);
1951
1959
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1952
1960
setOperationAction(ISD::FSQRT, VT, Custom);
@@ -4371,6 +4379,54 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
4371
4379
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4372
4380
}
4373
4381
4382
+ SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4383
+ SelectionDAG &DAG) const {
4384
+ EVT VT = Op.getValueType();
4385
+ SDValue Src = Op.getOperand(0);
4386
+ SDLoc DL(Op);
4387
+
4388
+ assert(VT.isVector() && "Expected vector type");
4389
+
4390
+ EVT ContainerVT = VT;
4391
+ EVT SrcVT = Src.getValueType();
4392
+ EVT CastVT =
4393
+ ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4394
+
4395
+ if (VT.isFixedLengthVector()) {
4396
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT);
4397
+ CastVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4398
+ Src = convertToScalableVector(DAG, CastVT, Src);
4399
+ }
4400
+
4401
+ // First, round the floating-point value into a floating-point register with
4402
+ // the current rounding mode.
4403
+ SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4404
+
4405
+ // In the case of vector filled with f32, ftrunc will convert it to an i32,
4406
+ // but a vector filled with i32 isn't legal. So, FP_EXTEND the f32 into the
4407
+ // required size.
4408
+ size_t SrcSz = SrcVT.getScalarSizeInBits();
4409
+ size_t ContainerSz = ContainerVT.getScalarSizeInBits();
4410
+ if (ContainerSz > SrcSz) {
4411
+ EVT WidenedVT = MVT::getVectorVT(MVT::getFloatingPointVT(ContainerSz),
4412
+ ContainerVT.getVectorElementCount());
4413
+ FOp = DAG.getNode(ISD::FP_EXTEND, DL, WidenedVT, FOp.getOperand(0));
4414
+ }
4415
+
4416
+ // Finally, truncate the rounded floating point to an integer, rounding to
4417
+ // zero.
4418
+ SDValue Pred = getPredicateForVector(DAG, DL, ContainerVT);
4419
+ SDValue Undef = DAG.getUNDEF(ContainerVT);
4420
+ SDValue Truncated =
4421
+ DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, DL, ContainerVT,
4422
+ {Pred, FOp.getOperand(0), Undef}, FOp->getFlags());
4423
+
4424
+ if (!VT.isFixedLengthVector())
4425
+ return Truncated;
4426
+
4427
+ return convertFromScalableVector(DAG, VT, Truncated);
4428
+ }
4429
+
4374
4430
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
4375
4431
SelectionDAG &DAG) const {
4376
4432
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6628,10 +6684,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6628
6684
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
6629
6685
case ISD::VECTOR_INTERLEAVE:
6630
6686
return LowerVECTOR_INTERLEAVE(Op, DAG);
6631
- case ISD::LROUND:
6632
- case ISD::LLROUND:
6633
6687
case ISD::LRINT:
6634
- case ISD::LLRINT: {
6688
+ case ISD::LLRINT:
6689
+ if (Op.getValueType().isVector())
6690
+ return LowerVectorXRINT(Op, DAG);
6691
+ [[fallthrough]];
6692
+ case ISD::LROUND:
6693
+ case ISD::LLROUND: {
6635
6694
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
6636
6695
Op.getOperand(0).getValueType() == MVT::bf16) &&
6637
6696
"Expected custom lowering of rounding operations only for f16");
0 commit comments