From b6c12a70f2084c23637d653c3b4bb407e07bb87c Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 4 Aug 2023 09:31:12 -0700 Subject: [PATCH 1/2] Swapped bf16 builtins with inline asm. Signed-off-by: JackAKirk --- libclc/ptx-nvidiacl/libspirv/math/fabs.cl | 8 ++++++-- libclc/ptx-nvidiacl/libspirv/math/fma.cl | 8 ++++++-- libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl | 8 ++++++-- libclc/ptx-nvidiacl/libspirv/math/fmax.cl | 8 ++++++-- libclc/ptx-nvidiacl/libspirv/math/fmin.cl | 8 ++++++-- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 8 ++++++-- 6 files changed, 36 insertions(+), 12 deletions(-) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl index 4f12e85310a01..6f8c681621e43 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl @@ -18,12 +18,16 @@ // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fabs(ushort x) { - return __nvvm_abs_bf16(x); + ushort res; + __asm__ volatile("abs.bf16 %0,%1;" : "=h"(res) : "h"(x)); + return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fabs, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fabs(uint x) { - return __nvvm_abs_bf16x2(x); + uint res; + __asm__ volatile("abs.bf16x2 %0,%1;" : "=r"(res) : "r"(x)); + return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fabs, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma.cl b/libclc/ptx-nvidiacl/libspirv/math/fma.cl index 738951681e66e..d453629702626 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fma.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma.cl @@ -50,14 +50,18 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fma, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fma(ushort x, ushort y, ushort z) { - return __nvvm_fma_rn_bf16(x, y, z); + ushort res; + __asm__ volatile("fma.rn.bf16 %0,%1,%2,%3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); + return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma, ushort, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fma(uint x, uint y, uint z) { - return __nvvm_fma_rn_bf16x2(x, y, z); + uint res; + __asm__ volatile("fma.rn.bf16x2 %0,%1,%2,%3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); + return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fma, uint, uint, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl index b48e25c7c628d..8f47cf7ba9755 100755 --- a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl @@ -40,7 +40,9 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_fma_relu, _CLC_DEF _CLC_OVERLOAD ushort __clc_fma_relu(ushort x, ushort y, ushort z) { if (__clc_nvvm_reflect_arch() >= 800) { - return __nvvm_fma_rn_relu_bf16(x, y, z); + ushort res; + __asm__ volatile("fma.rn.relu.bf16 %0,%1,%2,%3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); + return res; } __builtin_trap(); __builtin_unreachable(); @@ -50,7 +52,9 @@ _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma_relu, _CLC_DEF _CLC_OVERLOAD uint __clc_fma_relu(uint x, uint y, uint z) { if (__clc_nvvm_reflect_arch() >= 800) { - return __nvvm_fma_rn_relu_bf16x2(x, y, z); + uint res; + __asm__ volatile("fma.rn.relu.bf16x2 %0,%1,%2,%3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); + return res; } __builtin_trap(); __builtin_unreachable(); diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl index d9dac6e752513..af53469ba03bb 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl @@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmax, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmax(ushort x, ushort y) { - return __nvvm_fmax_bf16(x, y); + ushort res; + __asm__ volatile("max.bf16 %0,%1,%2;" : "=h"(res) : "h"(x), "h"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmax, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmax(uint x, uint y) { - return __nvvm_fmax_bf16x2(x, y); + uint res; + __asm__ volatile("max.bf16x2 %0,%1,%2;" : "=r"(res) : "r"(x), "r"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmax, uint, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl index 167e65cdc5ec8..3995af0a50e44 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl @@ -53,14 +53,18 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmin, half // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmin(ushort x, ushort y) { - return __nvvm_fmin_bf16(x, y); + ushort res; + __asm__ volatile("min.bf16 %0,%1,%2;" : "=h"(res) : "h"(x), "h"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmin, ushort, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmin(uint x, uint y) { - return __nvvm_fmin_bf16x2(x, y); + uint res; + __asm__ volatile("min.bf16x2 %0,%1,%2;" : "=r"(res) : "r"(x), "r"(y)); + return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmin, uint, uint) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 13d077b0f1d65..ca82c0b279c69 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -61,7 +61,9 @@ class bfloat16 { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) #if (__SYCL_CUDA_ARCH__ >= 800) - return __nvvm_f2bf16_rn(a); + detail::Bfloat16StorageT res; + asm volatile("cvt.rn.bf16.f32 %0,%1;" : "=h"(res) : "f"(a)); + return res; #else return from_float_fallback(a); #endif @@ -117,7 +119,9 @@ class bfloat16 { friend bfloat16 operator-(bfloat16 &lhs) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ (__SYCL_CUDA_ARCH__ >= 800) - return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value)); + detail::Bfloat16StorageT res; + asm volatile("neg.bf16 %0,%1;" : "=h"(res) : "h"(lhs.value)); + return detail::bitsToBfloat16(res); #elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)}; #else From 6342e90bfd194406b18ba28932506424af70e5ee Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 8 Aug 2023 05:54:13 -0700 Subject: [PATCH 2/2] Removed all volatile usage from asm. Improved formating. Signed-off-by: JackAKirk --- libclc/ptx-nvidiacl/libspirv/math/fabs.cl | 4 ++-- libclc/ptx-nvidiacl/libspirv/math/fma.cl | 4 ++-- libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl | 8 ++++++-- libclc/ptx-nvidiacl/libspirv/math/fmax.cl | 4 ++-- libclc/ptx-nvidiacl/libspirv/math/fmin.cl | 4 ++-- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 4 ++-- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl index 6f8c681621e43..2042076877a06 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fabs.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fabs.cl @@ -19,7 +19,7 @@ // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fabs(ushort x) { ushort res; - __asm__ volatile("abs.bf16 %0,%1;" : "=h"(res) : "h"(x)); + __asm__("abs.bf16 %0, %1;" : "=h"(res) : "h"(x)); return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fabs, ushort) @@ -27,7 +27,7 @@ _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fabs, ushort) // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fabs(uint x) { uint res; - __asm__ volatile("abs.bf16x2 %0,%1;" : "=r"(res) : "r"(x)); + __asm__("abs.bf16x2 %0, %1;" : "=r"(res) : "r"(x)); return res; } _CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fabs, uint) diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma.cl b/libclc/ptx-nvidiacl/libspirv/math/fma.cl index d453629702626..e68d6f52b885f 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fma.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma.cl @@ -51,7 +51,7 @@ _CLC_TERNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fma, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fma(ushort x, ushort y, ushort z) { ushort res; - __asm__ volatile("fma.rn.bf16 %0,%1,%2,%3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); + __asm__("fma.rn.bf16 %0, %1, %2, %3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma, ushort, @@ -60,7 +60,7 @@ _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma, ushort, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fma(uint x, uint y, uint z) { uint res; - __asm__ volatile("fma.rn.bf16x2 %0,%1,%2,%3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); + __asm__("fma.rn.bf16x2 %0, %1, %2, %3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); return res; } _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fma, uint, diff --git a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl index 8f47cf7ba9755..845b53237c794 100755 --- a/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fma_relu.cl @@ -41,7 +41,9 @@ _CLC_DEF _CLC_OVERLOAD ushort __clc_fma_relu(ushort x, ushort y, ushort z) { if (__clc_nvvm_reflect_arch() >= 800) { ushort res; - __asm__ volatile("fma.rn.relu.bf16 %0,%1,%2,%3;" : "=h"(res) : "h"(x), "h"(y), "h"(z)); + __asm__("fma.rn.relu.bf16 %0, %1, %2, %3;" + : "=h"(res) + : "h"(x), "h"(y), "h"(z)); return res; } __builtin_trap(); @@ -53,7 +55,9 @@ _CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fma_relu, _CLC_DEF _CLC_OVERLOAD uint __clc_fma_relu(uint x, uint y, uint z) { if (__clc_nvvm_reflect_arch() >= 800) { uint res; - __asm__ volatile("fma.rn.relu.bf16x2 %0,%1,%2,%3;" : "=r"(res) : "r"(x), "r"(y), "r"(z)); + __asm__("fma.rn.relu.bf16x2 %0, %1, %2, %3;" + : "=r"(res) + : "r"(x), "r"(y), "r"(z)); return res; } __builtin_trap(); diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl index af53469ba03bb..780270d411522 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmax.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmax.cl @@ -54,7 +54,7 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmax, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmax(ushort x, ushort y) { ushort res; - __asm__ volatile("max.bf16 %0,%1,%2;" : "=h"(res) : "h"(x), "h"(y)); + __asm__("max.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y)); return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmax, ushort, @@ -63,7 +63,7 @@ _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmax, ushort, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmax(uint x, uint y) { uint res; - __asm__ volatile("max.bf16x2 %0,%1,%2;" : "=r"(res) : "r"(x), "r"(y)); + __asm__("max.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y)); return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmax, uint, diff --git a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl index 3995af0a50e44..1cff1cf73754c 100644 --- a/libclc/ptx-nvidiacl/libspirv/math/fmin.cl +++ b/libclc/ptx-nvidiacl/libspirv/math/fmin.cl @@ -54,7 +54,7 @@ _CLC_BINARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_fmin, half // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD ushort __clc_fmin(ushort x, ushort y) { ushort res; - __asm__ volatile("min.bf16 %0,%1,%2;" : "=h"(res) : "h"(x), "h"(y)); + __asm__("min.bf16 %0, %1, %2;" : "=h"(res) : "h"(x), "h"(y)); return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmin, ushort, @@ -63,7 +63,7 @@ _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, __clc_fmin, ushort, // Requires at least sm_80 _CLC_DEF _CLC_OVERLOAD uint __clc_fmin(uint x, uint y) { uint res; - __asm__ volatile("min.bf16x2 %0,%1,%2;" : "=r"(res) : "r"(x), "r"(y)); + __asm__("min.bf16x2 %0, %1, %2;" : "=r"(res) : "r"(x), "r"(y)); return res; } _CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, __clc_fmin, uint, diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index ca82c0b279c69..2032168ad08a7 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -62,7 +62,7 @@ class bfloat16 { #if defined(__NVPTX__) #if (__SYCL_CUDA_ARCH__ >= 800) detail::Bfloat16StorageT res; - asm volatile("cvt.rn.bf16.f32 %0,%1;" : "=h"(res) : "f"(a)); + asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a)); return res; #else return from_float_fallback(a); @@ -120,7 +120,7 @@ class bfloat16 { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ (__SYCL_CUDA_ARCH__ >= 800) detail::Bfloat16StorageT res; - asm volatile("neg.bf16 %0,%1;" : "=h"(res) : "h"(lhs.value)); + asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value)); return detail::bitsToBfloat16(res); #elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__) return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};