
Commit 711ba58

[SYCL][CUDA][Matrix] Add initial support for Tensorcore matrix ext (#4696)

Initial implementation of the new matrix extension supporting Nvidia Tensor Cores (#4695), adapted from the AMX matrix extension. Only the double data type is initially supported for matrix elements.

Signed-off-by: jack.kirk <jack.kirk@codeplay.com>

1 parent 9340c96 commit 711ba58
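
For orientation, here is a minimal kernel-side sketch of the API this commit introduces, assuming only what the header below defines. `mad_sketch` and its pointer parameters are illustrative names, the strides are the leading dimensions of row-major operands, and buffer/accessor setup is elided; the test file at the end of the commit is the complete, authoritative example.

// Sketch only (not part of the commit): D = A * B + C at sub_group scope for
// the one supported configuration, double elements with shape m8n8k4.
#include <CL/sycl.hpp>
using namespace sycl::ext::oneapi::experimental::matrix;

void mad_sketch(sycl::nd_item<2> item, sycl::global_ptr<double> a,
                sycl::global_ptr<double> b, sycl::global_ptr<double> c,
                sycl::global_ptr<double> d) {
  sycl::sub_group sg = item.get_sub_group();
  joint_matrix<double, matrix_use::a, 8, 4, matrix_layout::row_major> sub_a;
  joint_matrix<double, matrix_use::b, 4, 8, matrix_layout::row_major> sub_b;
  joint_matrix<double, matrix_use::accumulator, 8, 8,
               matrix_layout::row_major>
      sub_c;
  joint_matrix_load(sg, sub_a, a, 4); // stride: leading dimension of A (8x4)
  joint_matrix_load(sg, sub_b, b, 8); // stride: leading dimension of B (4x8)
  joint_matrix_load(sg, sub_c, c, 8); // stride: leading dimension of C (8x8)
  sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
  joint_matrix_store(sg, sub_c, d, 8);
}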

File tree

3 files changed: +363 −0 lines changed
sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 259 additions & 0 deletions

#pragma once

#include <CL/sycl/detail/defines_elementary.hpp>
#include <immintrin.h>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental::matrix {

enum class matrix_use { a, b, accumulator };

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

template <typename T, matrix_use MT, size_t Rows = sycl::dynamic_extent,
          size_t Cols = sycl::dynamic_extent,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group, typename Cond = void>
struct joint_matrix {
  joint_matrix(Group g) {}
};

// The enable_if_t usage in this file disables the matrix_layout::packed_a and
// matrix_layout::packed_b cases, which are not compatible with the Nvidia
// CUDA backend.
template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::b, 4, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<(Layout == matrix_layout::row_major ||
                               Layout == matrix_layout::col_major)>> {
  double data[1];
};

template <matrix_layout Layout>
struct joint_matrix<
    double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group,
    typename std::enable_if_t<Layout == matrix_layout::row_major ||
                              Layout == matrix_layout::col_major>> {
  double data[2];
};

} // namespace experimental::matrix

namespace detail {
using namespace experimental;

template <typename T, matrix::matrix_use MT, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_load_impl {
  void load(matrix::joint_matrix<T, MT, NumRows, NumCols, Layout> &res,
            multi_ptr<T, Space> src, size_t stride);
};

template <matrix::matrix_layout Layout> constexpr int get_layout_id();

template <> constexpr int get_layout_id<matrix::matrix_layout::row_major>() {
  return 0;
}

template <> constexpr int get_layout_id<matrix::matrix_layout::col_major>() {
  return 1;
}

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::a, 8, 4, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::b, 4, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void
  load(matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, Layout> &res,
       multi_ptr<double, Space> src, size_t stride) {
#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_load_impl<
    double, matrix::matrix_use::accumulator, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void load(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                                 Layout> &res,
            multi_ptr<double, Space> src, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id<Layout>());
#endif
#endif
  }
};

template <typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout, access::address_space Space,
          typename Cond = void>
struct joint_matrix_store_impl {
  void store(matrix::joint_matrix<T, matrix::matrix_use::accumulator, NumRows,
                                  NumCols, Layout> &src,
             multi_ptr<T, Space> dst, size_t stride);
};

template <matrix::matrix_layout Layout, access::address_space Space>
struct joint_matrix_store_impl<
    double, 8, 8, Layout, Space,
    typename std::enable_if_t<Layout == matrix::matrix_layout::row_major ||
                              Layout == matrix::matrix_layout::col_major>> {
  void store(matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8,
                                  8, Layout> &src,
             multi_ptr<double, Space> dst, size_t stride) {

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
                           get_layout_id<Layout>());
#endif
#endif
  }
};

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
          matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
          matrix::matrix_layout LayoutC, typename Cond = void>
struct joint_matrix_mad_impl {
  matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
  mad(matrix::joint_matrix<T1, matrix::matrix_use::a, M, K, LayoutA> A,
      matrix::joint_matrix<T1, matrix::matrix_use::b, K, N, LayoutB> B,
      matrix::joint_matrix<T2, matrix::matrix_use::accumulator, M, N, LayoutC>
          C);
};

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB>
constexpr int get_layout_pair_id();

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::row_major>() {
  return 0;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::row_major,
                                 matrix::matrix_layout::col_major>() {
  return 1;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::row_major>() {
  return 2;
}

template <>
constexpr int get_layout_pair_id<matrix::matrix_layout::col_major,
                                 matrix::matrix_layout::col_major>() {
  return 3;
}

template <matrix::matrix_layout LayoutA, matrix::matrix_layout LayoutB,
          matrix::matrix_layout LayoutC>
struct joint_matrix_mad_impl<
    double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC,
    typename std::enable_if_t<(LayoutA == matrix::matrix_layout::row_major ||
                               LayoutA == matrix::matrix_layout::col_major) &&
                              (LayoutB == matrix::matrix_layout::row_major ||
                               LayoutB == matrix::matrix_layout::col_major) &&
                              (LayoutC == matrix::matrix_layout::row_major ||
                               LayoutC == matrix::matrix_layout::col_major)>> {
  matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8, LayoutC>
  mad(matrix::joint_matrix<double, matrix::matrix_use::a, 8, 4, LayoutA> A,
      matrix::joint_matrix<double, matrix::matrix_use::b, 4, 8, LayoutB> B,
      matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                           LayoutC>
          C) {
    matrix::joint_matrix<double, matrix::matrix_use::accumulator, 8, 8,
                         LayoutC>
        D;

#ifdef __NVPTX__
#ifdef __SYCL_DEVICE_ONLY__
    __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
                          get_layout_pair_id<LayoutA, LayoutB>(), 0);
#endif
#endif

    return D;
  }
};

} // namespace detail

namespace experimental::matrix {

template <typename Group, typename T, matrix_use MT, size_t NumRows,
          size_t NumCols, matrix_layout Layout, access::address_space Space>
void joint_matrix_load(
    Group sg, joint_matrix<T, MT, NumRows, NumCols, Layout, Group> &res,
    multi_ptr<T, Space> src, size_t stride) {
  detail::joint_matrix_load_impl<T, MT, NumRows, NumCols, Layout, Space>{}.load(
      res, src, stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
void joint_matrix_store(Group sg,
                        joint_matrix<T, matrix_use::accumulator, NumRows,
                                     NumCols, Layout, Group> &src,
                        multi_ptr<T, Space> dst, size_t stride) {
  detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout, Space>{}.store(
      src, dst, stride);
}

template <typename Group, typename T1, typename T2, std::size_t M,
          std::size_t K, std::size_t N, matrix_layout LayoutA,
          matrix_layout LayoutB, matrix_layout LayoutC>
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
joint_matrix_mad(
    Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
    joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
    joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
  return detail::joint_matrix_mad_impl<T1, T2, M, K, N, LayoutA, LayoutB,
                                       LayoutC>{}
      .mad(A, B, C);
}

} // namespace experimental::matrix
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
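
One detail worth calling out: the specializations above store double data[1] for the a and b fragments but double data[2] for the accumulator. This follows from distributing each m8n8k4 fragment over the 32 work-items of the sub-group (one warp). A hedged illustration of the arithmetic; subgroup_size is my name, not the commit's:

// Illustration only (not part of the commit): per-work-item fragment storage
// implied by the m8n8k4 shape and the 32-wide sub-group.
constexpr int subgroup_size = 32;
static_assert(8 * 4 / subgroup_size == 1, "a (8x4) and b (4x8): one double each");
static_assert(8 * 8 / subgroup_size == 2, "accumulator (8x8): two doubles");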

sycl/include/sycl/ext/oneapi/matrix/matrix.hpp

Lines changed: 3 additions & 0 deletions

@@ -25,3 +25,6 @@
 #include <sycl/ext/oneapi/matrix/matrix-jit.hpp>
 #include <sycl/ext/oneapi/matrix/static-query.hpp>
 #endif
+#if (SYCL_EXT_ONEAPI_MATRIX == 3)
+#include <sycl/ext/oneapi/matrix/matrix-tensorcore.hpp>
+#endif
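
For context, user code opts in through the same feature macro this diff gates on. A hedged sketch follows; the compile invocation mirrors the test's RUN line below, and app.cpp is a placeholder name:

// Sketch (not part of the commit): enabling the extension from user code.
// Compile with, e.g. (double-precision Tensor Cores require sm_80 hardware):
//   clang++ -fsycl -fsycl-targets=nvptx64-nvidia-cuda \
//     -Xsycl-target-backend --cuda-gpu-arch=sm_80 \
//     -DSYCL_EXT_ONEAPI_MATRIX=3 app.cpp
#include <CL/sycl.hpp>
#if SYCL_EXT_ONEAPI_MATRIX == 3
// sycl::ext::oneapi::experimental::matrix Tensor Core support is available.
#endif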
Lines changed: 101 additions & 0 deletions

// REQUIRES: gpu, cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o - | FileCheck %s

#include <CL/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// M, N, K define the sizes of dimensions of the three matrix types (a, b,
// accumulator) used per subgroup operation.
constexpr int M = 8; // number of rows of a,
                     // number of rows of accumulator.
constexpr int N = 8; // number of cols of b,
                     // number of cols of accumulator.
constexpr int K = 4; // number of cols of a / number of rows of b.

double A[M * K];
double B[K * N];
double C[M * N];
double D[M * N];

int main() {

  buffer<double, 1> bufA(A, range<1>(M * K));
  buffer<double, 1> bufB(B, range<1>(K * N));
  buffer<double, 1> bufC(C, range<1>(M * N));
  buffer<double, 1> bufD(D, range<1>(M * N));

  queue q;

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class row_row>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::row_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::row_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::row_major>
              sub_b;

          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), N);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), K);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), N);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), N);
        });
  });

  q.submit([&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);
    auto accD = bufD.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class col_col>(
        nd_range<2>({1, 32}, {1, 32}),
        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
          sycl::sub_group sg = item.get_sub_group();

          joint_matrix<double, matrix_use::accumulator, M, N,
                       matrix_layout::col_major>
              sub_c;

          joint_matrix<double, matrix_use::a, M, K, matrix_layout::col_major>
              sub_a;

          joint_matrix<double, matrix_use::b, K, N, matrix_layout::col_major>
              sub_b;

          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_c, accC.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i54, i32 8) #{{.*}}
          joint_matrix_load(sg, sub_a, accA.get_pointer(), M);
          //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i65, i32 4) #{{.*}}
          joint_matrix_load(sg, sub_b, accB.get_pointer(), K);
          //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %11, double %12, double %9, double %10) #{{.*}}
          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %add.ptr.i76, double %14, double %15, i32 8) #{{.*}}
          joint_matrix_store(sg, sub_c, accD.get_pointer(), M);
        });
  });

  return 0;
}
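
A note on the FileCheck directives: each joint_matrix_load, joint_matrix_mad, and joint_matrix_store call is expected to lower to the matching @llvm.nvvm.wmma.m8n8k4.* intrinsic for the chosen layout, with the element stride carried through as the trailing i32 argument (for example, loading the row-major 8x4 a fragment with stride K yields load.a.row.stride.f64 with i32 4).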
