diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index 57d9066616f8b..f7d916a05144f 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -252,11 +252,12 @@ layout. ```c++ namespace sycl::ext::oneapi::experimental::matrix { -template void joint_matrix_store(Group g, - const joint_matrix &res, - multi_ptr dest, size_t stride, layout Layout); + const joint_matrix &res, + multi_ptr dest, size_t stride, layout Layout); } // namespace sycl::ext::oneapi::experimental::matrix ``` @@ -371,15 +372,17 @@ joint_matrix_apply(sg, C, [=](T &x) { }); ``` -=== Support for the TF32 Data Type -Some devices support the TF32 floating point type for matrix -elements. This type has a 19 bit format with one sign bit, 8 exponent -bits (offering the same range as float), and 10 mantissa bits -(offering the same precision as sycl::half). Use of this type can -accelerate the joint_matrix_mad operation by reducing its -precision. In order to declare a `joint_matrix` object with this -element type, use `matrix::precision::tf32` in place of the `T` -template parameter. +=== Support for Machine Learning Types +Some devices support special matrix element types that are commonly +used in machine learning algorithms. +These types are unusual because the type of the matrix element is +different from the way the data is stored in memory. As a result, each +of these elements has two types. There is an abstract identifier for +the element type, which is an incomplete type defined in the +`sycl::ext::oneapi::experimental::matrix::precision` namespace, and +there is a corresponding storage format type. The following synopsis +lists the abstract types and the table shows the associated storage +format type. 
```c++ namespace sycl::ext::oneapi::experimental::matrix::precision { @@ -389,94 +392,85 @@ class tf32; } // namespace sycl::ext::oneapi::experimental::matrix::precision ``` -For example: +[frame="none",options="header",cols="20%,20%,60%"] +|====================== +| `joint_matrix` element type | Storage type | Description +|precision::tf32 | float | The TF32 type has a 19 bit format with one +sign bit, 8 exponent bits (offering the same range as float), and 10 +mantissa bits (offering the same precision as sycl::half). +|====================== + +In order to declare a `joint_matrix` with one of these element types, +use the abstract type like so: ```c++ joint_matrix tA; ``` -Whenever the application loads, stores, fills, or accesses the -elements of a TF32 matrix, the application sees the elements as -float. There are special overloads of these functions for TF32 for -this purpose. +Operations on these matrices use the functions described above, but +there are different constraints on the template parameters as +described below. -==== TF32 load -These overloads of `joint_matrix_load` load float values into a TF32 -matrix. It is unspecified whether the implementation loads all 32 bits -into the joint matrix or if it only loads the relevant 19 bits. +==== load +The template parameter `T2` must either be the storage format type +that corresponds to the abstract type `T1` or it must be a +const-qualified version of that storage format type. 
For example: ```c++ -namespace sycl::ext::oneapi::experimental::matrix { +joint_matrix tA; -template -void joint_matrix_load(Group g, - joint_matrix &res, - multi_ptr src, size_t stride, layout Layout); +float *buf = malloc_shared(M*K, q); +auto pBuf = address_space_cast(buf); -template -void joint_matrix_load(Group g, - joint_matrix &res, - multi_ptr src, size_t stride, layout Layout); +joint_matrix_load(sg, tA, pBuf + Offset, Stride); +``` -// Only available when Layout != layout::dynamic -template -void joint_matrix_load(Group g, - joint_matrix &res, - multi_ptr src, size_t stride); +==== store +The template parameter `T2` must be the storage format type that +corresponds to the abstract type `T1`. For example: -// Only available when Layout != layout::dynamic -template -void joint_matrix_load(Group g, - joint_matrix &res, - multi_ptr src, size_t stride); +```c++ +joint_matrix tC; -} // namespace sycl::ext::oneapi::experimental::matrix +float *buf = malloc_shared(M*K, q); +auto pBuf = address_space_cast(buf); + +joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major); ``` -==== TF32 store -This overload of joint_matrix_store stores float values from a TF32 -matrix. +==== fill +The template parameter `Tv` must be implicitly convertible to the +storage format type that corresponds to the abstract type `T`. For example: ```c++ -namespace sycl::ext::oneapi::experimental::matrix { - -template -void joint_matrix_store(Group g, - const joint_matrix &res, - multi_ptr dest, size_t stride, layout Layout); - -} // namespace sycl::ext::oneapi::experimental::matrix +joint_matrix tA; +float v = 42.0; +joint_matrix_fill(sg, tA, v); ``` -==== TF32 fill -When `joint_matrix_fill` is called for a TF32 matrix, the type `Tv` -(the type of the fill value) must be implicitly convertible to -`float`. It is unspecified whether the implementation writes all 32 -bits of the value into the joint matrix or if it only writes the -relevant 19 bits. 
+==== copy +There is no special constraint for the `joint_matrix_copy` +function. The template parameters `T1` and `T2` correspond to the +element types of the `src` and `dest` matrices. -==== TF32 element-wise operations -When `joint_matrix_apply` is called for a TF32 matrix, the Callable -object func is called with a single argument of type `float &`. When the -application changes this value, it is unspecified whether the -implementation writes back all 32 bits of the element into the joint -matrix or if it only write the relevant 19 bits. +```c++ +joint_matrix tA; +joint_matrix tC; +joint_matrix_copy(sg, tC, tA); +``` -In the example below, `C` is a joint matrix of type `precision::tf32`. +==== Element-wise operations +The Callable function type `F` must be invocable with a single argument +whose type is a reference to the storage format type that corresponds +to the abstract type `T`. For example, in the case where `C` is a +joint matrix of type `precision::tf32`: ```c++ -joint_matrix_apply(sg, C, [=](float &x) { +joint_matrix tC; +joint_matrix_apply(sg, tC, [=](float &x) { x *= alpha; }); ``` @@ -887,7 +881,8 @@ is shown in a single column in the table below. This is currently available in devices with the architecture `architecture::intel_gpu_pvc`, `architecture::intel_gpu_dg2_g10`, `architecture::intel_gpu_dg2_g11`, and -`architecture::intel_gpu_dg2_g12`. In these architectures' +`architecture::intel_gpu_dg2_g12`. +In these architectures' implementation, the type of the C matrix must be the same as the type of the D matrix. Therefore, that common type is shown in a single column in the table below. @@ -897,27 +892,32 @@ column in the table below. 
| A type | B type | C and D type | M | N | K | device .2+| `matrix_type::uint8` .2+| `matrix_type::uint8` .2+| `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 -|`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10, +|`architecture::intel_gpu_pvc` +|8|`architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12` .2+| `matrix_type::uint8` .2+| `matrix_type::sint8` .2+| `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 | -`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10, +`architecture::intel_gpu_pvc` +|8|`architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12` .2+| `matrix_type::sint8` .2+| `matrix_type::uint8` .2+| `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 | -`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10, +`architecture::intel_gpu_pvc` +|8|`architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12` .2+| `matrix_type::sint8` .2+| `matrix_type::sint8` .2+| `matrix_type::sint32` .2+| +<=+ 8 | 16 .2+| 32 | -`architecture::intel_gpu_pvc`|8|`architecture::intel_gpu_dg2_g10, +`architecture::intel_gpu_pvc` +|8|`architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12` .2+|`matrix_type::fp16` .2+| `matrix_type::fp16` .2+| `matrix_type::fp32` .2+| +<=+ 8 | 16 .2+| 16 | -`architecture::intel_gpu_pvc`|8| `architecture::intel_gpu_dg2_g10, +`architecture::intel_gpu_pvc` +|8| `architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12` .4+| `matrix_type::bf16` .4+| `matrix_type::bf16` .4+| -`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc` | -32 | 64 | 16 +`matrix_type::fp32` | 16 | 16 | 16 .3+|`architecture::intel_gpu_pvc` +|32 | 64 | 16 .2+| +<=+ 8 | 16 .2+| 16 |8 | `architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12`