Skip to content

Commit 96da39e

Browse files
authored
[SYCL] Add group algorithms for MUL/OR/XOR/AND operations (#2339)
This patch only adds the operations but does not fully enable them, as more changes are needed in the driver to conditionally or unconditionally raise the SPIR-V standard version from 1.1 to 1.3. That is needed to avoid assertion failures when SPIR-V 1.3 operations are used under -spirv-max-version=1.1, which is currently the default. This patch also: - adds a few test cases to the corresponding LIT tests, which are temporarily turned off (until the CPU/GPU RT is ready to handle the new operations). - fixes the device check in 3 LIT tests to enable them for PI_LEVEL0. Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent 8f8d34c commit 96da39e

File tree

6 files changed

+145
-17
lines changed

6 files changed

+145
-17
lines changed

clang/lib/Sema/SPIRVBuiltins.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -917,11 +917,13 @@ foreach name = ["GroupBroadcast"] in {
917917
}
918918
}
919919

920-
foreach name = ["GroupIAdd"] in {
920+
foreach name = ["GroupIAdd", "GroupNonUniformIMul", "GroupNonUniformBitwiseOr",
921+
"GroupNonUniformBitwiseXor", "GroupNonUniformBitwiseAnd"] in {
921922
def : SPVBuiltin<name, [AIGenTypeN, UInt, UInt, AIGenTypeN], Attr.Convergent>;
922923
}
923924

924-
foreach name = ["GroupFAdd", "GroupFMin", "GroupFMax"] in {
925+
foreach name = ["GroupFAdd", "GroupFMin", "GroupFMax",
926+
"GroupNonUniformFMul"] in {
925927
def : SPVBuiltin<name, [FGenTypeN, UInt, UInt, FGenTypeN], Attr.Convergent>;
926928
}
927929

sycl/include/CL/sycl/ONEAPI/functional.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ template <> struct maximum<void> {
5353
#endif
5454

5555
template <typename T = void> using plus = std::plus<T>;
56+
template <typename T = void> using multiplies = std::multiplies<T>;
5657
template <typename T = void> using bit_or = std::bit_or<T>;
5758
template <typename T = void> using bit_xor = std::bit_xor<T>;
5859
template <typename T = void> using bit_and = std::bit_and<T>;
@@ -103,6 +104,16 @@ __SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, ONEAPI::plus<T>)
103104
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, ONEAPI::plus<T>)
104105
__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, ONEAPI::plus<T>)
105106

107+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, ONEAPI::multiplies<T>)
108+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul, ONEAPI::multiplies<T>)
109+
__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, ONEAPI::multiplies<T>)
110+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr, ONEAPI::bit_or<T>)
111+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr, ONEAPI::bit_or<T>)
112+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor, ONEAPI::bit_xor<T>)
113+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor, ONEAPI::bit_xor<T>)
114+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd, ONEAPI::bit_and<T>)
115+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd, ONEAPI::bit_and<T>)
116+
106117
#undef __SYCL_CALC_OVERLOAD
107118

108119
template <typename T, __spv::GroupOperation O, __spv::Scope::Flag S,

sycl/include/CL/sycl/ONEAPI/group_algorithm.hpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,39 @@ template <typename T, typename V> struct identity<T, ONEAPI::plus<V>> {
8686
};
8787

8888
template <typename T, typename V> struct identity<T, ONEAPI::minimum<V>> {
89-
static constexpr T value = (std::numeric_limits<T>::max)();
89+
static constexpr T value = std::numeric_limits<T>::has_infinity
90+
? std::numeric_limits<T>::infinity()
91+
: (std::numeric_limits<T>::max)();
9092
};
9193

9294
template <typename T, typename V> struct identity<T, ONEAPI::maximum<V>> {
93-
static constexpr T value = std::numeric_limits<T>::lowest();
95+
static constexpr T value =
96+
std::numeric_limits<T>::has_infinity
97+
? static_cast<T>(-std::numeric_limits<T>::infinity())
98+
: std::numeric_limits<T>::lowest();
99+
};
100+
101+
// Specialization of `identity` for ONEAPI::multiplies: exposes `value`, the
// constant 1 converted to T (the neutral element of multiplication).
template <typename T, typename V> struct identity<T, ONEAPI::multiplies<V>> {
102+
static constexpr T value = static_cast<T>(1); // 1 * x == x
103+
};
104+
105+
// Specialization of `identity` for ONEAPI::bit_or: exposes `value`, which is
// 0 (the neutral element of bitwise OR).
template <typename T, typename V> struct identity<T, ONEAPI::bit_or<V>> {
106+
static constexpr T value = 0; // 0 | x == x
107+
};
108+
109+
// Specialization of `identity` for ONEAPI::bit_xor: exposes `value`, which is
// 0 (the neutral element of bitwise XOR).
template <typename T, typename V> struct identity<T, ONEAPI::bit_xor<V>> {
110+
static constexpr T value = 0; // 0 ^ x == x
111+
};
112+
113+
// Specialization of `identity` for ONEAPI::bit_and: exposes `value`, the
// bitwise complement of a zero of type T.
// NOTE(review): presumably T is always an integral type here, so that the
// complement is the all-ones pattern (the neutral element of bitwise AND) --
// confirm against the users of `identity`.
template <typename T, typename V> struct identity<T, ONEAPI::bit_and<V>> {
114+
static constexpr T value = ~static_cast<T>(0); // complement of zero
94115
};
95116

96117
template <typename T>
97118
using native_op_list =
98119
type_list<ONEAPI::plus<T>, ONEAPI::bit_or<T>, ONEAPI::bit_xor<T>,
99-
ONEAPI::bit_and<T>, ONEAPI::maximum<T>, ONEAPI::minimum<T>>;
120+
ONEAPI::bit_and<T>, ONEAPI::maximum<T>, ONEAPI::minimum<T>,
121+
ONEAPI::multiplies<T>>;
100122

101123
template <typename T, typename BinaryOperation> struct is_native_op {
102124
static constexpr bool value =

sycl/test/group-algorithm/exclusive_scan.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
// RUN: %GPU_RUN_PLACEHOLDER %t.out
88
// RUN: %ACC_RUN_PLACEHOLDER %t.out
99

10+
// TODO: enable compile+runtime checks for operations defined in SPIR-V 1.3.
11+
// That requires either adding a switch to clang (-spirv-max-version=1.3) or
12+
// raising the spirv version from 1.1 to 1.3 for spirv translator
13+
// unconditionally. Using operators specific for spirv 1.3 and higher with
14+
// -spirv-max-version=1.1 being set by default causes assert/check fails
15+
// in spirv translator.
16+
// RUNx: %clangxx -fsycl -fsycl-targets=%sycl_triple -DSPIRV_1_3 %s -o %t13.out
17+
1018
#include <CL/sycl.hpp>
1119
#include <algorithm>
1220
#include <cassert>
@@ -120,10 +128,27 @@ void test(queue q, InputContainer input, OutputContainer output,
120128
assert(std::equal(output.begin(), output.begin() + N, expected.begin()));
121129
}
122130

131+
// Decides whether the test should run on device D, based on the name of its
// platform and, for OpenCL platforms, on the device version string.
//
// Returns true when the platform name contains "Level-Zero", or when it
// contains "OpenCL" and the version extracted from the device version string
// compares >= "2.0". Returns false in every other case.
//
// NOTE(review): the version check is a lexicographic std::string comparison of
// a 3-character substring, so a two-digit major version such as "10.0" would
// compare less than "2.0".
// NOTE(review): std::string::substr throws std::out_of_range if Offset + 7 is
// greater than Version.size(), e.g. when the version string ends right after
// "OpenCL".
bool isSupportedDevice(device D) {
132+
std::string PlatformName = D.get_platform().get_info<info::platform::name>();
133+
if (PlatformName.find("Level-Zero") != std::string::npos)
134+
return true; // Level-Zero platforms are accepted without a version check.
135+
136+
if (PlatformName.find("OpenCL") != std::string::npos) {
137+
std::string Version = D.get_info<info::device::version>();
138+
size_t Offset = Version.find("OpenCL");
139+
if (Offset == std::string::npos)
140+
return false; // No "OpenCL" marker in the device version string.
141+
// Skip the 6 characters of "OpenCL" plus one more (presumably the separating
// space) and keep the next 3 characters, e.g. "2.0".
Version = Version.substr(Offset + 7, 3);
142+
if (Version >= std::string("2.0"))
143+
return true;
144+
}
145+
146+
return false; // Neither Level-Zero nor an OpenCL device reporting >= "2.0".
147+
}
148+
123149
int main() {
124150
queue q;
125-
std::string version = q.get_device().get_info<info::device::version>();
126-
if (version < std::string("2.0")) {
151+
if (!isSupportedDevice(q.get_device())) {
127152
std::cout << "Skipping test\n";
128153
return 0;
129154
}
@@ -134,14 +159,20 @@ int main() {
134159
std::iota(input.begin(), input.end(), 0);
135160
std::fill(output.begin(), output.end(), 0);
136161

137-
#if __cplusplus >= 201402L
138162
test(q, input, output, plus<>(), 0);
139163
test(q, input, output, minimum<>(), std::numeric_limits<int>::max());
140164
test(q, input, output, maximum<>(), std::numeric_limits<int>::lowest());
141-
#endif
165+
142166
test(q, input, output, plus<int>(), 0);
143167
test(q, input, output, minimum<int>(), std::numeric_limits<int>::max());
144168
test(q, input, output, maximum<int>(), std::numeric_limits<int>::lowest());
145169

170+
#ifdef SPIRV_1_3
171+
test(q, input, output, multiplies<int>(), 1);
172+
test(q, input, output, bit_or<int>(), 0);
173+
test(q, input, output, bit_xor<int>(), 0);
174+
test(q, input, output, bit_and<int>(), ~0);
175+
#endif // SPIRV_1_3
176+
146177
std::cout << "Test passed." << std::endl;
147178
}

sycl/test/group-algorithm/inclusive_scan.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
// RUN: %GPU_RUN_PLACEHOLDER %t.out
88
// RUN: %ACC_RUN_PLACEHOLDER %t.out
99

10+
// TODO: enable compile+runtime checks for operations defined in SPIR-V 1.3.
11+
// That requires either adding a switch to clang (-spirv-max-version=1.3) or
12+
// raising the spirv version from 1.1 to 1.3 for spirv translator
13+
// unconditionally. Using operators specific for spirv 1.3 and higher with
14+
// -spirv-max-version=1.1 being set by default causes assert/check fails
15+
// in spirv translator.
16+
// RUNx: %clangxx -fsycl -fsycl-targets=%sycl_triple -DSPIRV_1_3 %s -o %t13.out
17+
1018
#include <CL/sycl.hpp>
1119
#include <algorithm>
1220
#include <cassert>
@@ -120,10 +128,27 @@ void test(queue q, InputContainer input, OutputContainer output,
120128
assert(std::equal(output.begin(), output.begin() + N, expected.begin()));
121129
}
122130

131+
// Decides whether the test should run on device D, based on the name of its
// platform and, for OpenCL platforms, on the device version string.
//
// Returns true when the platform name contains "Level-Zero", or when it
// contains "OpenCL" and the version extracted from the device version string
// compares >= "2.0". Returns false in every other case.
//
// NOTE(review): the version check is a lexicographic std::string comparison of
// a 3-character substring, so a two-digit major version such as "10.0" would
// compare less than "2.0".
// NOTE(review): std::string::substr throws std::out_of_range if Offset + 7 is
// greater than Version.size(), e.g. when the version string ends right after
// "OpenCL".
bool isSupportedDevice(device D) {
132+
std::string PlatformName = D.get_platform().get_info<info::platform::name>();
133+
if (PlatformName.find("Level-Zero") != std::string::npos)
134+
return true; // Level-Zero platforms are accepted without a version check.
135+
136+
if (PlatformName.find("OpenCL") != std::string::npos) {
137+
std::string Version = D.get_info<info::device::version>();
138+
size_t Offset = Version.find("OpenCL");
139+
if (Offset == std::string::npos)
140+
return false; // No "OpenCL" marker in the device version string.
141+
// Skip the 6 characters of "OpenCL" plus one more (presumably the separating
// space) and keep the next 3 characters, e.g. "2.0".
Version = Version.substr(Offset + 7, 3);
142+
if (Version >= std::string("2.0"))
143+
return true;
144+
}
145+
146+
return false; // Neither Level-Zero nor an OpenCL device reporting >= "2.0".
147+
}
148+
123149
int main() {
124150
queue q;
125-
std::string version = q.get_device().get_info<info::device::version>();
126-
if (version < std::string("2.0")) {
151+
if (!isSupportedDevice(q.get_device())) {
127152
std::cout << "Skipping test\n";
128153
return 0;
129154
}
@@ -134,14 +159,20 @@ int main() {
134159
std::iota(input.begin(), input.end(), 0);
135160
std::fill(output.begin(), output.end(), 0);
136161

137-
#if __cplusplus >= 201402L
138162
test(q, input, output, plus<>(), 0);
139163
test(q, input, output, minimum<>(), std::numeric_limits<int>::max());
140164
test(q, input, output, maximum<>(), std::numeric_limits<int>::lowest());
141-
#endif
165+
142166
test(q, input, output, plus<int>(), 0);
143167
test(q, input, output, minimum<int>(), std::numeric_limits<int>::max());
144168
test(q, input, output, maximum<int>(), std::numeric_limits<int>::lowest());
145169

170+
#ifdef SPIRV_1_3
171+
test(q, input, output, multiplies<int>(), 1);
172+
test(q, input, output, bit_or<int>(), 0);
173+
test(q, input, output, bit_xor<int>(), 0);
174+
test(q, input, output, bit_and<int>(), ~0);
175+
#endif // SPIRV_1_3
176+
146177
std::cout << "Test passed." << std::endl;
147178
}

sycl/test/group-algorithm/reduce.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
// RUN: %GPU_RUN_PLACEHOLDER %t.out
88
// RUN: %ACC_RUN_PLACEHOLDER %t.out
99

10+
// TODO: enable compile+runtime checks for operations defined in SPIR-V 1.3.
11+
// That requires either adding a switch to clang (-spirv-max-version=1.3) or
12+
// raising the spirv version from 1.1 to 1.3 for spirv translator
13+
// unconditionally. Using operators specific for spirv 1.3 and higher with
14+
// -spirv-max-version=1.1 being set by default causes assert/check fails
15+
// in spirv translator.
16+
// RUNx: %clangxx -fsycl -fsycl-targets=%sycl_triple -DSPIRV_1_3 %s -o %t13.out
17+
1018
#include <CL/sycl.hpp>
1119
#include <algorithm>
1220
#include <cassert>
@@ -58,10 +66,27 @@ void test(queue q, InputContainer input, OutputContainer output,
5866
std::accumulate(input.begin(), input.end(), init, binary_op));
5967
}
6068

69+
// Decides whether the test should run on device D, based on the name of its
// platform and, for OpenCL platforms, on the device version string.
//
// Returns true when the platform name contains "Level-Zero", or when it
// contains "OpenCL" and the version extracted from the device version string
// compares >= "2.0". Returns false in every other case.
//
// NOTE(review): the version check is a lexicographic std::string comparison of
// a 3-character substring, so a two-digit major version such as "10.0" would
// compare less than "2.0".
// NOTE(review): std::string::substr throws std::out_of_range if Offset + 7 is
// greater than Version.size(), e.g. when the version string ends right after
// "OpenCL".
bool isSupportedDevice(device D) {
70+
std::string PlatformName = D.get_platform().get_info<info::platform::name>();
71+
if (PlatformName.find("Level-Zero") != std::string::npos)
72+
return true; // Level-Zero platforms are accepted without a version check.
73+
74+
if (PlatformName.find("OpenCL") != std::string::npos) {
75+
std::string Version = D.get_info<info::device::version>();
76+
size_t Offset = Version.find("OpenCL");
77+
if (Offset == std::string::npos)
78+
return false; // No "OpenCL" marker in the device version string.
79+
// Skip the 6 characters of "OpenCL" plus one more (presumably the separating
// space) and keep the next 3 characters, e.g. "2.0".
Version = Version.substr(Offset + 7, 3);
80+
if (Version >= std::string("2.0"))
81+
return true;
82+
}
83+
84+
return false; // Neither Level-Zero nor an OpenCL device reporting >= "2.0".
85+
}
86+
6187
int main() {
6288
queue q;
63-
std::string version = q.get_device().get_info<info::device::version>();
64-
if (version < std::string("2.0")) {
89+
if (!isSupportedDevice(q.get_device())) {
6590
std::cout << "Skipping test\n";
6691
return 0;
6792
}
@@ -72,14 +97,20 @@ int main() {
7297
std::iota(input.begin(), input.end(), 0);
7398
std::fill(output.begin(), output.end(), 0);
7499

75-
#if __cplusplus >= 201402L
76100
test(q, input, output, plus<>(), 0);
77101
test(q, input, output, minimum<>(), std::numeric_limits<int>::max());
78102
test(q, input, output, maximum<>(), std::numeric_limits<int>::lowest());
79-
#endif
103+
80104
test(q, input, output, plus<int>(), 0);
81105
test(q, input, output, minimum<int>(), std::numeric_limits<int>::max());
82106
test(q, input, output, maximum<int>(), std::numeric_limits<int>::lowest());
83107

108+
#ifdef SPIRV_1_3
109+
test(q, input, output, multiplies<int>(), 1);
110+
test(q, input, output, bit_or<int>(), 0);
111+
test(q, input, output, bit_xor<int>(), 0);
112+
test(q, input, output, bit_and<int>(), ~0);
113+
#endif // SPIRV_1_3
114+
84115
std::cout << "Test passed." << std::endl;
85116
}

0 commit comments

Comments
 (0)