diff --git a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl index 4f12e85310a0..2042076877a0 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl @@ -18,12 +18,16 @@ // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fabs(ushort x) { - return __nvvm_abs_bf16(x); + ushort res; + __asm__("abs.bf16 %0, %1;" : "=h"(res) : "h"(x)); + return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fabs, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fabs(uint x) { - return __nvvm_abs_bf16x2(x); + uint res; + __asm__("abs.bf16x2 %0, %1;" : "=r"(res) : "r"(x)); + return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fabs, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma.cl b/libclc/ptx-nvidiacl/libspirv/math/fma.cl index 738951681e66..e68d6f52b885 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fma.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma.cl @@ -50,14 +50,18 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fma, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fma(ushort x, ushort y, ushort z) { - return __nvvm_fma_rn_bf16(x, y, z); + ushort res; + __asm__("fma.rn.bf16 %0, %1, %2, %3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); + return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma, ushort, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fma(uint x, uint y, uint z) { - return __nvvm_fma_rn_bf16x2(x, y, z); + uint res; + __asm__("fma.rn.bf16x2 %0, %1, %2, %3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); + return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fma, uint, uint, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl index b48e25c7c628..845b53237c79 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl @@ -40,7 +40,11 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_fma_relu, _CLC_DEF _CLC_OVERLOAD ushort __clc_fma_relu(ushort x, ushort y, ushort z) { if (__clc_nvvm_reflect_arch() >= 800) { - return __nvvm_fma_rn_relu_bf16(x, y, z); + ushort res; + __asm__("fma.rn.relu.bf16 %0, %1, %2, %3;" + : "=h"(res) + : "h"(x), "h"(y), "h"(z)); + return res; } __builtin_trap(); __builtin_unreachable(); @@ -50,7 +54,11 @@ _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma_relu, _CLC_DEF _CLC_OVERLOAD uint __clc_fma_relu(uint x, uint y, uint z) { if (__clc_nvvm_reflect_arch() >= 800) { - return __nvvm_fma_rn_relu_bf16x2(x, y, z); + uint res; + __asm__("fma.rn.relu.bf16x2 %0, %1, %2, %3;" + : "=r"(res) + : "r"(x), "r"(y), "r"(z)); + return res; } __builtin_trap(); __builtin_unreachable(); diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl index d9dac6e75251..780270d41152 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl @@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmax, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmax(ushort x, ushort y) { - return __nvvm_fmax_bf16(x, y); + ushort res; + __asm__("max.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmax, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmax(uint x, uint y) { - return __nvvm_fmax_bf16x2(x, y); + uint res; + __asm__("max.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmax, uint, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl index 167e65cdc5ec..1cff1cf73754 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl @@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmin, half // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmin(ushort x, ushort y) { - return __nvvm_fmin_bf16(x, y); + ushort res; + __asm__("min.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmin, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmin(uint x, uint y) { - return __nvvm_fmin_bf16x2(x, y); + uint res; + __asm__("min.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmin, uint, uint) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 46ffb1590088..bd3052e9a048 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -64,7 +64,9 @@ class bfloat16 { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) #if (__SYCL_CUDA_ARCH__ >= 800) - return __nvvm_f2bf16_rn(a); + detail::Bfloat16StorageT res; + asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a)); + return res; #else return from_float_fallback(a); #endif @@ -120,7 +122,9 @@ class bfloat16 { friend bfloat16 operator-(bfloat16 &lhs) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ (__SYCL_CUDA_ARCH__ >= 800) - return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value)); + detail::Bfloat16StorageT res; + asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value)); + return detail::bitsToBfloat16(res); #elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)}; #else