[SYCL][CUDA][MATRIX] joint_matrix_bmad implementation #5363
@@ -81,6 +81,34 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

// single-bit
template <matrix_layout Layout>
struct joint_matrix<
    uint32_t, matrix_use::a, 8, 4, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  joint_matrix() {
    static_assert((Layout == matrix_layout::row_major),
                  "For the matrix_use::a case, matrix_layout::row_major must "
                  "be used for Bitwise MAD");
  };
  int32_t data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
    uint32_t, matrix_use::b, 4, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  joint_matrix() {
    static_assert((Layout == matrix_layout::col_major),
                  "For the matrix_use::b case, matrix_layout::col_major must "
                  "be used for Bitwise MAD");
  };
  int32_t data[1];
};
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 8, int32_t, 2)

#undef __SYCL_JOINT_MATRIX_OVERLOAD
} // namespace experimental::matrix

@@ -235,6 +263,28 @@ struct joint_matrix_load_impl<
                              get_layout_id<Layout>());
      }

    } else if constexpr (std::is_same<T, double>::value) {
      if constexpr (Use ==
                    sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
        __dmma_m8n8k4_ld_a(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
                                      matrix_use::b) {
        __dmma_m8n8k4_ld_b(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
                                      matrix_use::accumulator) {
        __dmma_m8n8k4_ld_c(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      }
    } else if constexpr (NumRows == 8 && NumCols == 4) {
      int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
      __bmma_m8n8k128_ld_a_b1(res.data, tileptr, stride * 32,
                              get_layout_id<Layout>());
    } else if constexpr (NumRows == 4 && NumCols == 8) {
      int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
      __bmma_m8n8k128_ld_b_b1(res.data, tileptr, stride * 32,
                              get_layout_id<Layout>());
    } else if constexpr (std::is_same<T, int32_t>::value) {
      if constexpr (NumRows == 16 && NumCols == 16) {
        __imma_m16n16k16_ld_c(res.data, src.get(), stride,

@@ -245,6 +295,9 @@ struct joint_matrix_load_impl<
      } else if constexpr (NumRows == 32 && NumCols == 8) {
        __imma_m32n8k16_ld_c(res.data, src.get(), stride,
                             get_layout_id<Layout>());
      } else if constexpr (NumRows == 8 && NumCols == 8) {
        __bmma_m8n8k128_ld_c(res.data, src.get(), stride,
                             get_layout_id<Layout>());
      }
    } else if constexpr (std::is_same<T, float>::value) {
      if constexpr (NumRows == 16 && NumCols == 16) {

@@ -257,20 +310,6 @@ struct joint_matrix_load_impl<
        __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
                                 get_layout_id<Layout>());
      }
    } else if constexpr (std::is_same<T, double>::value) {
      if constexpr (Use ==
                    sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
        __dmma_m8n8k4_ld_a(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
                                      matrix_use::b) {
        __dmma_m8n8k4_ld_b(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
                                      matrix_use::accumulator) {
        __dmma_m8n8k4_ld_c(res.data, src.get(), stride,
                           get_layout_id<Layout>());
      }
    }
  }
};

@@ -339,6 +378,9 @@ struct joint_matrix_store_impl<
    } else if constexpr (std::is_same<T, double>::value) {
      __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
                             get_layout_id<Layout>());
    } else if constexpr (std::is_same<T, int32_t>::value) {
      __bmma_m8n8k128_st_c_i32(dst.get(), src.data, stride,
                               get_layout_id<Layout>());
    }
  }
};

@@ -366,6 +408,31 @@ struct joint_matrix_mad_impl {
        C);
  };

template <std::size_t M, std::size_t K, std::size_t N,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
          class BinaryOperation, typename Cond = void>
struct joint_matrix_bmad_impl {
  sycl::ext::oneapi::experimental::matrix::joint_matrix<
      int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
      M, N, LayoutC, sycl::sub_group>
  bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
           K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
           sycl::sub_group>
           A,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
           N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
           sycl::sub_group>
           B,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           int32_t,
           sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
           N, LayoutC, sycl::sub_group>
           C,
       BinaryOperation Op);
};

template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

@@ -495,14 +562,59 @@ struct joint_matrix_mad_impl<
              get_layout_pair_id<LayoutA, LayoutB>(), 0);
        }
      }
    } else if constexpr (std::is_same<T1, double>::value) {
    } else if constexpr (M == 8 && N == 8 && K == 4) {

Reviewer: is this change related to the bmad addition?
Author: No, this is a superficial/non-important change that I made just for better consistency of the …

      __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
                            get_layout_pair_id<LayoutA, LayoutB>(), 0);
    }
    return D;
  }
};

template <std::size_t M, std::size_t K, std::size_t N,
          sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
          class BinaryOperation>
struct joint_matrix_bmad_impl<
    M, K, N, LayoutC, BinaryOperation,
    typename std::enable_if_t<(
        LayoutC ==
            sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major ||
        LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
                       col_major)>> {
  sycl::ext::oneapi::experimental::matrix::joint_matrix<
      int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
      M, N, LayoutC, sycl::sub_group>
  bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
           K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
           sycl::sub_group>
           A,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
           N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
           sycl::sub_group>
           B,
       sycl::ext::oneapi::experimental::matrix::joint_matrix<
           int32_t,
           sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
           N, LayoutC, sycl::sub_group>
           C,
       BinaryOperation Op) {
    sycl::ext::oneapi::experimental::matrix::joint_matrix<
        int32_t,
        sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N,
        LayoutC, sycl::sub_group>
        D;
    if constexpr (std::is_same<BinaryOperation,
                               sycl::bit_and<uint32_t>>::value) {
      __bmma_m8n8k128_mma_and_popc_b1(D.data, A.data, B.data, C.data, 1);
    } else if constexpr (std::is_same<BinaryOperation,
                                      sycl::bit_xor<uint32_t>>::value) {
      __bmma_m8n8k128_mma_xor_popc_b1(D.data, A.data, B.data, C.data, 1);
    }
    return D;
  }
};

} // namespace detail

namespace experimental::matrix {

@@ -573,6 +685,33 @@ joint_matrix_mad(
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename Group, std::size_t M, std::size_t K, std::size_t N,
          matrix_layout LayoutC, class BinaryOperation>
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_bmad(
    Group sg,
    joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major, Group>
        A,
    joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major, Group>
        B,
    joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group> C,
    BinaryOperation Op) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  return sycl::ext::oneapi::detail::joint_matrix_bmad_impl<M, K, N, LayoutC,
                                                           BinaryOperation>{}
      .bmad(A, B, C, Op);
#else
  (void)sg;
  (void)A;
  (void)B;
  (void)C;
  (void)Op;
  throw runtime_error("joint_matrix_bmad is "
                      "only supported by CUDA devices",
                      PI_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
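
Editorial note: the joint_matrix_bmad entry point above ultimately computes, for each accumulator element, a popcount-reduced "dot product" of a 128-bit row of "a" with a 128-bit column of "b". The following host-side code is not part of the patch; it is a minimal reference sketch under the storage assumptions used in the test below (A stored row-major as M x K uint32_t words, B stored column-major as N columns of K words), and bmad_reference is an illustrative name.

#include <bit>        // std::popcount (C++20)
#include <cstddef>
#include <cstdint>
#include <functional> // std::bit_and, std::bit_xor

// Illustrative reference, not part of this patch:
// D[i][j] = C[i][j] + sum over k of popcount(op(A_row_word, B_col_word)),
// i.e. each accumulator element adds the popcount of the element-wise
// AND/XOR of a 128-bit row of "a" with a 128-bit column of "b".
template <class BinaryOp>
void bmad_reference(const uint32_t *A, // M x K words, row-major
                    const uint32_t *B, // K x N words, column-major
                    const int32_t *C, int32_t *D, std::size_t M, std::size_t N,
                    std::size_t K, BinaryOp op) {
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j) {
      int32_t acc = C[i * N + j];
      for (std::size_t k = 0; k < K; ++k)
        acc += std::popcount(op(A[i * K + k], B[j * K + k]));
      D[i * N + j] = acc;
    }
}

Called as, for example, bmad_reference(A, B, C, D, 8, 8, 4, std::bit_xor<uint32_t>{}), this should, under the stated assumptions, mirror the accumulator produced by the XOR variant of joint_matrix_bmad in the test below.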
@@ -0,0 +1,76 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, (K * 32) define the sizes of dimensions of the three matrix types (a,
// b, accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
                     // number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
                     // number of cols of b.
constexpr int K = 4; // number of cols of a/number of rows of b divided by 32
Reviewer: ... should be: ...
Author: Here K=4 is not the number of cols in a subgroup matrix: you have to multiply by 32, since K gives a dimension of the arrays A/B which hold the single-bit matrix elements in uint32_t storage type. There are 32 single-bit matrix elements per uint32_t storage type. I tried to describe the purpose of these bitwise matrix multiplications here without going into too much detail. I added references for full details on the origins of the single-bit models and how they use the bitwise matrix multiplications within them. I do not find references to the usage of such "bitwise matrix multiplications" outside of such models (although of course this does not mean they don't exist/will exist in the future), but I think that it is important for such users to understand that each bit is considered an element of the matrix.
I could add back the naming scheme A -> A_Packed, B -> B_Packed that I originally used when I switched from K=128 to K=4, to make it clearer that I am calling "a" the matrix and "A_Packed" a packed array representation of the matrix?
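
To make that packing concrete, here is a small hypothetical helper (not part of the patch) that packs the 128 single-bit elements of one logical matrix row into the K = 4 uint32_t storage words; the bit ordering (element 0 in bit 0 of word 0) is an assumption for illustration only.

#include <cstdint>

// Packs one logical row of 128 single-bit matrix elements into K = 4
// uint32_t words (32 elements per word). Hypothetical example; the exact
// bit order expected by the hardware is not specified here.
void pack_row(const bool (&bits)[128], uint32_t (&packed)[4]) {
  for (int w = 0; w < 4; ++w) {
    packed[w] = 0;
    for (int b = 0; b < 32; ++b)
      if (bits[w * 32 + b])
        packed[w] |= (1u << b);
  }
}

With such a packing, the A array in the test below holds M * K = 32 words, i.e. an 8 x 128 single-bit matrix, and B holds K * N = 32 words for a 128 x 8 single-bit matrix.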

// Each bit of each uint32_t A/B array element is an element of a single-bit
// matrix. joint_matrix_bmad performs Binary Dot Products on these matrices (see
// M. Rastegari et al. Computer Vision – ECCV 2016, 525-542 and A. Li et al.
// IEEE Transactions on Parallel and Distributed Systems, 32(7):1878-1891,
// 2021).
uint32_t A[M * K];
uint32_t B[K * N];
int32_t C[M * N];
int32_t D[M * N];

int main() {

  buffer<uint32_t, 1> bufA(A, range<1>(M * K));
  buffer<uint32_t, 1> bufB(B, range<1>(K * N));
  buffer<int32_t, 1> bufC(C, range<1>(M * N));
  buffer<int32_t, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<int32_t, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_xor<uint32_t>());
          //CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7) #{{.*}}
          sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
                                    sycl::bit_and<uint32_t>());
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %9, i32 %10, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  return 0;
};
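
A side note on the binarized-network references cited in the comment at the top of this test: in those formulations, +1 is encoded as bit 1 and -1 as bit 0, and a signed dot product is recovered from the same XOR + popcount reduction that joint_matrix_bmad performs. The following host-side sketch is illustrative only; the helper name and the +1/-1 encoding are assumptions taken from those papers, not from this patch.

#include <bit>
#include <cstdint>

// Signed dot product of two 128-element {+1, -1} vectors, each packed into
// K = 4 uint32_t words: positions where the signs differ are counted with
// XOR + popcount, and (#equal - #different) gives the dot product.
int32_t signed_dot_from_xor_popc(const uint32_t (&a)[4],
                                 const uint32_t (&b)[4]) {
  int32_t differing = 0;
  for (int k = 0; k < 4; ++k)
    differing += std::popcount(a[k] ^ b[k]);
  return 128 - 2 * differing;
}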