
[SYCL][CUDA][MATRIX] joint_matrix_bmad implementation #5363


Closed · wants to merge 20 commits

Changes from 13 commits
169 changes: 154 additions & 15 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -81,6 +81,34 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)

// single-bit
template <matrix_layout Layout>
struct joint_matrix<
uint32_t, matrix_use::a, 8, 4, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
joint_matrix() {
static_assert((Layout == matrix_layout::row_major),
"For the matrix_use::a case, matrix_layout::row_major must "
"be used for Bitwise MAD");
};
int32_t data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
uint32_t, matrix_use::b, 4, 8, Layout, sycl::sub_group,
typename std::enable_if_t<Layout == matrix_layout::row_major ||
Layout == matrix_layout::col_major>> {
joint_matrix() {
static_assert((Layout == matrix_layout::col_major),
"For the matrix_use::b case, matrix_layout::col_major must "
"be used for Bitwise MAD");
};
int32_t data[1];
};
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 8, int32_t, 2)

#undef __SYCL_JOINT_MATRIX_OVERLOAD
} // namespace experimental::matrix

@@ -235,6 +263,28 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}

} else if constexpr (std::is_same<T, double>::value) {
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
__dmma_m8n8k4_ld_a(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::b) {
__dmma_m8n8k4_ld_b(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::accumulator) {
__dmma_m8n8k4_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (NumRows == 8 && NumCols == 4) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
__bmma_m8n8k128_ld_a_b1(res.data, tileptr, stride * 32,
get_layout_id<Layout>());
} else if constexpr (NumRows == 4 && NumCols == 8) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
__bmma_m8n8k128_ld_b_b1(res.data, tileptr, stride * 32,
get_layout_id<Layout>());
} else if constexpr (std::is_same<T, int32_t>::value) {
if constexpr (NumRows == 16 && NumCols == 16) {
__imma_m16n16k16_ld_c(res.data, src.get(), stride,
@@ -245,6 +295,9 @@
} else if constexpr (NumRows == 32 && NumCols == 8) {
__imma_m32n8k16_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 8 && NumCols == 8) {
__bmma_m8n8k128_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -257,20 +310,6 @@
__hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, double>::value) {
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
__dmma_m8n8k4_ld_a(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::b) {
__dmma_m8n8k4_ld_b(res.data, src.get(), stride,
get_layout_id<Layout>());
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
matrix_use::accumulator) {
__dmma_m8n8k4_ld_c(res.data, src.get(), stride,
get_layout_id<Layout>());
}
}
}
};
@@ -339,6 +378,9 @@ struct joint_matrix_store_impl<
} else if constexpr (std::is_same<T, double>::value) {
__dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
get_layout_id<Layout>());
} else if constexpr (std::is_same<T, int32_t>::value) {
__bmma_m8n8k128_st_c_i32(dst.get(), src.data, stride,
get_layout_id<Layout>());
}
}
};
@@ -366,6 +408,31 @@ struct joint_matrix_mad_impl {
C);
};

template <std::size_t M, std::size_t K, std::size_t N,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
class BinaryOperation, typename Cond = void>
struct joint_matrix_bmad_impl {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
M, N, LayoutC, sycl::sub_group>
bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::sub_group>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::sub_group>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
N, LayoutC, sycl::sub_group>
C,
BinaryOperation Op);
};

template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();
@@ -495,14 +562,59 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (std::is_same<T1, double>::value) {
} else if constexpr (M == 8 && N == 8 && K == 4) {
Contributor:
is this change related to bmad addition?

Contributor Author:
No, this is a superficial/non-important change that I made just for better consistency of the if constexpr statements in this function.

__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
return D;
}
};

template <std::size_t M, std::size_t K, std::size_t N,
sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
class BinaryOperation>
struct joint_matrix_bmad_impl<
M, K, N, LayoutC, BinaryOperation,
typename std::enable_if_t<(
LayoutC ==
sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major ||
LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
col_major)>> {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
M, N, LayoutC, sycl::sub_group>
bmad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M,
K, sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
sycl::sub_group>
A,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
uint32_t, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K,
N, sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
sycl::sub_group>
B,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
N, LayoutC, sycl::sub_group>
C,
BinaryOperation Op) {
sycl::ext::oneapi::experimental::matrix::joint_matrix<
int32_t,
sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, N,
LayoutC, sycl::sub_group>
D;
if constexpr (std::is_same<BinaryOperation,
sycl::bit_and<uint32_t>>::value) {
__bmma_m8n8k128_mma_and_popc_b1(D.data, A.data, B.data, C.data, 1);
} else if constexpr (std::is_same<BinaryOperation,
sycl::bit_xor<uint32_t>>::value) {
__bmma_m8n8k128_mma_xor_popc_b1(D.data, A.data, B.data, C.data, 1);
}
return D;
}
};

} // namespace detail

namespace experimental::matrix {
@@ -573,6 +685,33 @@ joint_matrix_mad(
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

template <typename Group, std::size_t M, std::size_t K, std::size_t N,
matrix_layout LayoutC, class BinaryOperation>
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_bmad(
Group sg,
joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major, Group>
A,
joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major, Group>
B,
joint_matrix<int32_t, matrix_use::accumulator, M, N, LayoutC, Group> C,
BinaryOperation Op) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
return sycl::ext::oneapi::detail::joint_matrix_bmad_impl<M, K, N, LayoutC,
BinaryOperation>{}
.bmad(A, B, C, Op);
#else
(void)sg;
(void)A;
(void)B;
(void)C;
(void)Op;
throw runtime_error("joint_matrix_bmad is "
"only supported by CUDA devices",
PI_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
@@ -0,0 +1,76 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, (K * 32) define the sizes of dimensions of the three matrix types (a,
// b, accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of accumulator,
// number of rows of a.
constexpr int N = 8; // number of cols of accumulator,
// number of cols of b.
constexpr int K = 4; // number of cols of a/number of rows of b divided by 32
Contributor (@dkhaldi, Feb 22, 2022):

"number of cols of a/number of rows of b divided by 32"

should be:

"number of bits in cols of a / number of bits in rows of B, divided by 32."

If this is true, do we need to add the "divided by 32" in the code example? I meant earlier to add the "multiply by 32" in the implementation code to explain that this is how we get the number of bits used by the intrinsics. But at the user-level code, is this needed?

Contributor Author (@JackAKirk, Feb 23, 2022):

Here K = 4 is not the number of cols in a subgroup matrix: you have to multiply by 32, since K gives a dimension of the arrays A/B which hold the single-bit matrix elements in the uint32_t storage type. There are 32 single-bit matrix elements per uint32_t storage element.

I tried to describe the purpose of these bitwise matrix multiplications here without going into too much detail. I added references for full details on the origins of the single-bit models and how they use the bitwise matrix multiplications within them. I have not found references to the usage of such "bitwise matrix multiplications" outside of such models (although of course this does not mean they don't/won't exist in the future), but I think that this functionality was introduced specifically with such use cases in mind.

It is important for such users to understand that each bit is considered an element of the matrix by joint_matrix_bmad (the matrix element is "quantized" to a single bit), which is why in the original implementation I set K = 128. However, as you pointed out, this leads to lots of factors of 32, because we have to divide by 32 to get the number of uint32_t array elements that are used to store the matrix.
In the current implementation it is nice that these factors are gone, but there should still be proper documentation (see here) describing the relationship between "K" and the actual number of (single-bit) matrix elements. Since this is experimental, I think it is normal to expect that once people start using this there could be feedback suggesting small changes to the interface: I'm not sure whether the interface I originally set up that led to the factors of 32, or the one you suggested, is preferable for users, but I imagine that at this experimental stage it can (and most likely will!) be changed in some way in the future anyway.

I could add back the naming scheme A -> A_Packed, B -> B_Packed that I originally used before switching from K = 128 to K = 4, to make it clearer that I am calling "a" the matrix and "A_Packed" a packed array representation of the matrix? Then I could also add some more detailed description in both the tests and the implementation? I did not want to go into too much detail in the tests/implementation because I thought that the proper place for such descriptions would be in the documentation of the extension. This is why I kept things concise here and did not mention details in the implementation.
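
A minimal sketch of this storage relationship, for illustration only (not part of the PR): each uint32_t element of the A/B arrays packs 32 single-bit matrix elements, so the K = 4 storage dimension corresponds to K * 32 = 128 single-bit cols of a (and rows of b). The bit order within each word below is an assumption made for the sketch; the actual b1 packing is defined by the PTX WMMA specification.

#include <cstddef>
#include <cstdint>

constexpr std::size_t K = 4;           // uint32_t words per row of A (storage dimension)
constexpr std::size_t K_bits = K * 32; // single-bit matrix elements per row of A (128)

// Pack one row of 128 single-bit elements (each 0 or 1) into 4 uint32_t words;
// 32 single-bit elements land in each word. Bit ordering is illustrative only.
void pack_row(const std::uint8_t (&bits)[K_bits], std::uint32_t (&packed)[K]) {
  for (std::size_t w = 0; w < K; ++w) {
    std::uint32_t word = 0;
    for (std::size_t b = 0; b < 32; ++b)
      word |= static_cast<std::uint32_t>(bits[w * 32 + b] & 1u) << b;
    packed[w] = word;
  }
}

With the original K = 128 interface the user would count single-bit elements directly and divide by 32 for storage; with K = 4 the storage dimension is what appears in the joint_matrix type.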


// Each bit of each uint32_t A/B array element is an element of a single-bit
// matrix. joint_matrix_bmad performs Binary Dot Products on these matrices (see
// M. Rastegari et al. Computer Vision – ECCV 2016, 525-542 and A. Li et al.
// IEEE Transactions on Parallel and Distributed Systems, 32(7):1878-1891,
// 2021)
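// A sketch of the semantics (following the CUDA b1 MMA-with-popc operations):
// for each (i, j), D[i][j] = C[i][j] + popcount(Op(row_i(A), col_j(B))), where
// Op is the bitwise AND or XOR passed to joint_matrix_bmad and the popcount
// runs over all K * 32 single-bit elements of the row/column pair.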
uint32_t A[M * K];
uint32_t B[K * N];
int32_t C[M * N];
int32_t D[M * N];

int main() {

buffer<uint32_t, 1> bufA(A, range<1>(M * K));
buffer<uint32_t, 1> bufB(B, range<1>(K * N));
buffer<int32_t, 1> bufC(C, range<1>(M * N));
buffer<int32_t, 1> bufD(D, range<1>(M * N));

queue q;

q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
auto accD = bufD.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class row_col>(
nd_range<2>({1, 32}, {1, 32}),
[=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
sycl::sub_group sg = item.get_sub_group();

joint_matrix<int32_t, matrix_use::accumulator, M, N,
matrix_layout::row_major>
sub_c;

joint_matrix<uint32_t, matrix_use::a, M, K, matrix_layout::row_major>
sub_a;

joint_matrix<uint32_t, matrix_use::b, K, N, matrix_layout::col_major>
sub_b;

//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 8) #{{.*}}
joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
//CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.a.row.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
//CHECK: tail call i32 @llvm.nvvm.wmma.m8n8k128.load.b.col.stride.b1.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 128) #{{.*}}
joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1(i32 %3, i32 %4, i32 %1, i32 %2) #{{.*}}
sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
sycl::bit_xor<uint32_t>());
//CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1(i32 %3, i32 %4, i32 %6, i32 %7) #{{.*}}
sub_c = joint_matrix_bmad(sg, sub_a, sub_b, sub_c,
sycl::bit_and<uint32_t>());
//CHECK: tail call void @llvm.nvvm.wmma.m8n8k128.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %9, i32 %10, i32 8) #{{.*}}
joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
});
});

return 0;
};