Skip to content

Commit 4b321e9

Browse files
authored
[SYCL][CUDA] Reland adding bf16 support to NVPTX (#11055)
This reverts commit bfc1956. With the workaround landed in #10695, we should now be able to reland the community commit llvm/llvm-project@250f2bb. This reverts the revert in order to reland it.
1 parent 7a77037 commit 4b321e9

24 files changed

+1706
-370
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
176176
AND(SM_86, PTX72))
177177
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
178178
AND(SM_86, PTX72))
179-
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
180-
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
181-
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
182-
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
179+
TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
180+
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
181+
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
182+
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
183+
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
184+
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
183185
AND(SM_86, PTX72))
184-
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
185-
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
186-
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
186+
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
187+
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
188+
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
189+
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
190+
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
187191
AND(SM_86, PTX72))
188-
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
192+
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
189193
AND(SM_86, PTX72))
190194
BUILTIN(__nvvm_fmin_f, "fff", "")
191195
BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
@@ -218,16 +222,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
218222
AND(SM_86, PTX72))
219223
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
220224
AND(SM_86, PTX72))
221-
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
222-
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
223-
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
224-
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
225+
TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
226+
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
227+
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
228+
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
229+
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
230+
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
225231
AND(SM_86, PTX72))
226-
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
227-
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
228-
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
232+
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
233+
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
234+
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
235+
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
236+
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
229237
AND(SM_86, PTX72))
230-
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
238+
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
231239
AND(SM_86, PTX72))
232240
BUILTIN(__nvvm_fmax_f, "fff", "")
233241
BUILTIN(__nvvm_fmax_ftz_f, "fff", "")
@@ -361,10 +369,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
361369
TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
362370
TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
363371
TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
364-
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
365-
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
366-
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
367-
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
372+
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70))
373+
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70))
374+
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
375+
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
368376
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
369377
BUILTIN(__nvvm_fma_rn_f, "ffff", "")
370378
BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
@@ -553,20 +561,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
553561
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
554562
BUILTIN(__nvvm_f2h_rn, "Usf", "")
555563

556-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
557-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
558-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
559-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
564+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70))
565+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70))
566+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70))
567+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70))
560568

561569
TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
562570
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
563571
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
564572
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
565573

566-
TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
567-
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
568-
TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
569-
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
574+
TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70))
575+
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70))
576+
TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70))
577+
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
570578

571579
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
572580

@@ -2649,10 +2657,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70))
26492657

26502658

26512659
// bf16, bf16x2 abs, neg
2652-
TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70))
2653-
TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
2654-
TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70))
2655-
TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
2660+
TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70))
2661+
TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
2662+
TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70))
2663+
TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
26562664

26572665
TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78))
26582666
TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", "", AND(SM_90, PTX78))

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4575,13 +4575,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
45754575
// CHECK-LABEL: nvvm_cvt_sm80
45764576
__device__ void nvvm_cvt_sm80() {
45774577
#if __CUDA_ARCH__ >= 800
4578-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
4578+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
45794579
__nvvm_ff2bf16x2_rn(1, 1);
4580-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
4580+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
45814581
__nvvm_ff2bf16x2_rn_relu(1, 1);
4582-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
4582+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
45834583
__nvvm_ff2bf16x2_rz(1, 1);
4584-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
4584+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
45854585
__nvvm_ff2bf16x2_rz_relu(1, 1);
45864586

45874587
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
@@ -4593,13 +4593,13 @@ __device__ void nvvm_cvt_sm80() {
45934593
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
45944594
__nvvm_ff2f16x2_rz_relu(1, 1);
45954595

4596-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
4596+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
45974597
__nvvm_f2bf16_rn(1);
4598-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
4598+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
45994599
__nvvm_f2bf16_rn_relu(1);
4600-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
4600+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
46014601
__nvvm_f2bf16_rz(1);
4602-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
4602+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
46034603
__nvvm_f2bf16_rz_relu(1);
46044604

46054605
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
@@ -4608,32 +4608,32 @@ __device__ void nvvm_cvt_sm80() {
46084608
// CHECK: ret void
46094609
}
46104610

4611+
#define NAN32 0x7FBFFFFF
4612+
#define NAN16 (__bf16)0x7FBF
4613+
#define BF16 (__bf16)0.1f
4614+
#define BF16_2 (__bf16)0.2f
4615+
#define NANBF16 (__bf16)0xFFC1
4616+
#define BF16X2 {(__bf16)0.1f, (__bf16)0.1f}
4617+
#define BF16X2_2 {(__bf16)0.2f, (__bf16)0.2f}
4618+
#define NANBF16X2 {NANBF16, NANBF16}
4619+
46114620
// CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80
46124621
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
46134622
#if __CUDA_ARCH__ >= 800
46144623

4615-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1)
4616-
__nvvm_abs_bf16(0xFFFF);
4617-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1)
4618-
__nvvm_abs_bf16x2(0xFFFFFFFF);
4624+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
4625+
__nvvm_abs_bf16(BF16);
4626+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
4627+
__nvvm_abs_bf16x2(BF16X2);
46194628

4620-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1)
4621-
__nvvm_neg_bf16(0xFFFF);
4622-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1)
4623-
__nvvm_neg_bf16x2(0xFFFFFFFF);
4629+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)
4630+
__nvvm_neg_bf16(BF16);
4631+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
4632+
__nvvm_neg_bf16x2(BF16X2);
46244633
#endif
46254634
// CHECK: ret void
46264635
}
46274636

4628-
#define NAN32 0x7FBFFFFF
4629-
#define NAN16 0x7FBF
4630-
#define BF16 0x1234
4631-
#define BF16_2 0x4321
4632-
#define NANBF16 0xFFC1
4633-
#define BF16X2 0x12341234
4634-
#define BF16X2_2 0x32343234
4635-
#define NANBF16X2 0xFFC1FFC1
4636-
46374637
// CHECK-LABEL: nvvm_min_max_sm80
46384638
__device__ void nvvm_min_max_sm80() {
46394639
#if __CUDA_ARCH__ >= 800
@@ -4643,14 +4643,22 @@ __device__ void nvvm_min_max_sm80() {
46434643
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
46444644
__nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);
46454645

4646-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
4646+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
46474647
__nvvm_fmin_bf16(BF16, BF16_2);
4648-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
4648+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
4649+
__nvvm_fmin_ftz_bf16(BF16, BF16_2);
4650+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
46494651
__nvvm_fmin_nan_bf16(BF16, NANBF16);
4650-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
4652+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
4653+
__nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
4654+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
46514655
__nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
4652-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
4656+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
4657+
__nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
4658+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
46534659
__nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
4660+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
4661+
__nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
46544662
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
46554663
__nvvm_fmax_nan_f(0.1f, 0.11f);
46564664
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -4660,14 +4668,22 @@ __device__ void nvvm_min_max_sm80() {
46604668
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
46614669
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
46624670
__nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
4663-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
4671+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
46644672
__nvvm_fmax_bf16(BF16, BF16_2);
4665-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
4673+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
4674+
__nvvm_fmax_ftz_bf16(BF16, BF16_2);
4675+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
46664676
__nvvm_fmax_nan_bf16(BF16, NANBF16);
4667-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
4677+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
4678+
__nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
4679+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
46684680
__nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
4669-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
4681+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
4682+
__nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
4683+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
46704684
__nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
4685+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
4686+
__nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
46714687
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
46724688
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
46734689
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -4680,14 +4696,14 @@ __device__ void nvvm_min_max_sm80() {
46804696
// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
46814697
__device__ void nvvm_fma_bf16_bf16x2_sm80() {
46824698
#if __CUDA_ARCH__ >= 800
4683-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
4684-
__nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
4685-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
4686-
__nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
4687-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
4688-
__nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
4689-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
4690-
__nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
4699+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
4700+
__nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
4701+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
4702+
__nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
4703+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
4704+
__nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
4705+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
4706+
__nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
46914707
#endif
46924708
// CHECK: ret void
46934709
}
@@ -4696,13 +4712,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
46964712
__device__ void nvvm_min_max_sm86() {
46974713
#if __CUDA_ARCH__ >= 860
46984714

4699-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
4715+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
47004716
__nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
4701-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
4717+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
47024718
__nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
4703-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
4719+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
47044720
__nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
4705-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
4721+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
47064722
__nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
47074723
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
47084724
__nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
@@ -4713,13 +4729,13 @@ __device__ void nvvm_min_max_sm86() {
47134729
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
47144730
__nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);
47154731

4716-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
4732+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
47174733
__nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
4718-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
4734+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
47194735
__nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
4720-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
4736+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
47214737
__nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
4722-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
4738+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
47234739
__nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
47244740
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
47254741
__nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
// CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
1010
// CHECK: .param .b64 _Z8test_argPDF16bDF16b_param_0,
11-
// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
11+
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
1212
//
1313
__device__ void test_arg(__bf16 *out, __bf16 in) {
1414
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
@@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
2020
}
2121

2222

23-
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
24-
// CHECK: .param .b16 _Z8test_retDF16b_param_0
23+
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
24+
// CHECK: .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
2525
__device__ __bf16 test_ret( __bf16 in) {
2626
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
2727
return in;
@@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {
3131

3232
__device__ __bf16 external_func( __bf16 in);
3333

34-
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
35-
// CHECK: .param .b16 _Z9test_callDF16b_param_0
34+
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
35+
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
3636
__device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
3838
// CHECK: st.param.b16 [param0+0], %[[R]];
39-
// CHECK: .param .b32 retval0;
39+
// CHECK: .param .align 2 .b8 retval0[2];
4040
// CHECK: call.uni (retval0),
4141
// CHECK-NEXT: _Z13external_funcDF16b,
4242
// CHECK-NEXT: (

0 commit comments

Comments
 (0)