Skip to content

[SYCL][CUDA] Reland adding bf16 support to NVPTX #11055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions clang/include/clang/Basic/BuiltinsNVPTX.def
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmin_f, "fff", "")
BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
Expand Down Expand Up @@ -218,16 +222,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmax_f, "fff", "")
BUILTIN(__nvvm_fmax_ftz_f, "fff", "")
Expand Down Expand Up @@ -361,10 +369,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
BUILTIN(__nvvm_fma_rn_f, "ffff", "")
BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
Expand Down Expand Up @@ -553,20 +561,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
BUILTIN(__nvvm_f2h_rn, "Usf", "")

TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))

Expand Down Expand Up @@ -2649,10 +2657,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70))


// bf16, bf16x2 abs, neg
TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78))
TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", "", AND(SM_90, PTX78))
Expand Down
114 changes: 65 additions & 49 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -4575,13 +4575,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
// CHECK-LABEL: nvvm_cvt_sm80
__device__ void nvvm_cvt_sm80() {
#if __CUDA_ARCH__ >= 800
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_relu(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu(1, 1);

// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
Expand All @@ -4593,13 +4593,13 @@ __device__ void nvvm_cvt_sm80() {
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu(1, 1);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
__nvvm_f2bf16_rn(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
__nvvm_f2bf16_rn_relu(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
__nvvm_f2bf16_rz(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
__nvvm_f2bf16_rz_relu(1);

// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
Expand All @@ -4608,32 +4608,32 @@ __device__ void nvvm_cvt_sm80() {
// CHECK: ret void
}

#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
#define BF16_2 (__bf16)0.2f
#define NANBF16 (__bf16)0xFFC1
#define BF16X2 {(__bf16)0.1f, (__bf16)0.1f}
#define BF16X2_2 {(__bf16)0.2f, (__bf16)0.2f}
#define NANBF16X2 {NANBF16, NANBF16}

// CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1)
__nvvm_abs_bf16(0xFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1)
__nvvm_abs_bf16x2(0xFFFFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
__nvvm_abs_bf16(BF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
__nvvm_abs_bf16x2(BF16X2);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1)
__nvvm_neg_bf16(0xFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1)
__nvvm_neg_bf16x2(0xFFFFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)
__nvvm_neg_bf16(BF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
__nvvm_neg_bf16x2(BF16X2);
#endif
// CHECK: ret void
}

#define NAN32 0x7FBFFFFF
#define NAN16 0x7FBF
#define BF16 0x1234
#define BF16_2 0x4321
#define NANBF16 0xFFC1
#define BF16X2 0x12341234
#define BF16X2_2 0x32343234
#define NANBF16X2 0xFFC1FFC1

// CHECK-LABEL: nvvm_min_max_sm80
__device__ void nvvm_min_max_sm80() {
#if __CUDA_ARCH__ >= 800
Expand All @@ -4643,14 +4643,22 @@ __device__ void nvvm_min_max_sm80() {
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
__nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
__nvvm_fmin_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
__nvvm_fmin_ftz_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
__nvvm_fmin_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
__nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
__nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
__nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
__nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
__nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, 0.11f);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
Expand All @@ -4660,14 +4668,22 @@ __device__ void nvvm_min_max_sm80() {
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
__nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
__nvvm_fmax_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
__nvvm_fmax_ftz_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
__nvvm_fmax_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
__nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
__nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
__nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
__nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
__nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
Expand All @@ -4680,14 +4696,14 @@ __device__ void nvvm_min_max_sm80() {
// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
__device__ void nvvm_fma_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
__nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
__nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
__nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
__nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
__nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
__nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
__nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
__nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
#endif
// CHECK: ret void
}
Expand All @@ -4696,13 +4712,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
__device__ void nvvm_min_max_sm86() {
#if __CUDA_ARCH__ >= 860

// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
__nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
__nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
__nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
__nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
__nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
Expand All @@ -4713,13 +4729,13 @@ __device__ void nvvm_min_max_sm86() {
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
__nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);

// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
__nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
__nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
__nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
__nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
__nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);
Expand Down
12 changes: 6 additions & 6 deletions clang/test/CodeGenCUDA/bf16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

// CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
// CHECK: .param .b64 _Z8test_argPDF16bDF16b_param_0,
// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
Expand All @@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
}


// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
// CHECK: .param .b16 _Z8test_retDF16b_param_0
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
// CHECK: .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
__device__ __bf16 test_ret( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
return in;
Expand All @@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {

__device__ __bf16 external_func( __bf16 in);

// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
// CHECK: .param .b16 _Z9test_callDF16b_param_0
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
__device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0+0], %[[R]];
// CHECK: .param .b32 retval0;
// CHECK: .param .align 2 .b8 retval0[2];
// CHECK: call.uni (retval0),
// CHECK-NEXT: _Z13external_funcDF16b,
// CHECK-NEXT: (
Expand Down
Loading