Skip to content

Commit a7a6de2

Browse files
[SYCL] Implement the rest of geometric built-ins (#8718)
Co-authored-by: Dmitry Vodopyanov <dmitry.vodopyanov@intel.com>
1 parent 09551d9 commit a7a6de2

File tree

4 files changed

+234
-0
lines changed

4 files changed

+234
-0
lines changed

sycl/include/sycl/builtins.hpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ namespace detail {
2626
template <class T, size_t N> vec<T, 2> to_vec2(marray<T, N> x, size_t start) {
2727
return {x[start], x[start + 1]};
2828
}
29+
template <class T, size_t N> vec<T, N> to_vec(marray<T, N> x) {
30+
vec<T, N> vec;
31+
for (size_t i = 0; i < N; i++)
32+
vec[i] = x[i];
33+
return vec;
34+
}
35+
template <class T, int N> marray<T, N> to_marray(vec<T, N> x) {
36+
marray<T, N> marray;
37+
for (size_t i = 0; i < N; i++)
38+
marray[i] = x[i];
39+
return marray;
40+
}
2941
} // namespace detail
3042

3143
#ifdef __SYCL_DEVICE_ONLY__
@@ -1805,6 +1817,70 @@ fast_normalize(T p) __NOEXC {
18051817
return __sycl_std::__invoke_fast_normalize<T>(p);
18061818
}
18071819

1820+
// marray geometric functions
1821+
1822+
#define __SYCL_MARRAY_GEOMETRIC_FUNCTION_OVERLOAD_IMPL(NAME, ...) \
1823+
vec<detail::marray_element_t<T>, T::size()> result_v; \
1824+
result_v = NAME(__VA_ARGS__); \
1825+
return detail::to_marray(result_v);
1826+
1827+
template <typename T>
1828+
std::enable_if_t<detail::is_gencrossmarray<T>::value, T> cross(T p0,
1829+
T p1) __NOEXC {
1830+
__SYCL_MARRAY_GEOMETRIC_FUNCTION_OVERLOAD_IMPL(cross, detail::to_vec(p0),
1831+
detail::to_vec(p1))
1832+
}
1833+
1834+
template <typename T>
1835+
std::enable_if_t<detail::is_gengeomarray<T>::value, T> normalize(T p) __NOEXC {
1836+
__SYCL_MARRAY_GEOMETRIC_FUNCTION_OVERLOAD_IMPL(normalize, detail::to_vec(p))
1837+
}
1838+
1839+
template <typename T>
1840+
std::enable_if_t<detail::is_gengeomarrayfloat<T>::value, T>
1841+
fast_normalize(T p) __NOEXC {
1842+
__SYCL_MARRAY_GEOMETRIC_FUNCTION_OVERLOAD_IMPL(fast_normalize,
1843+
detail::to_vec(p))
1844+
}
1845+
1846+
#undef __SYCL_MARRAY_GEOMETRIC_FUNCTION_OVERLOAD_IMPL
1847+
1848+
#define __SYCL_MARRAY_GEOMETRIC_FUNCTION_IS_GENGEOMARRAY_BINOP_OVERLOAD(NAME) \
1849+
template <typename T> \
1850+
std::enable_if_t<detail::is_gengeomarray<T>::value, \
1851+
detail::marray_element_t<T>> \
1852+
NAME(T p0, T p1) __NOEXC { \
1853+
return NAME(detail::to_vec(p0), detail::to_vec(p1)); \
1854+
}
1855+
1856+
// clang-format off
1857+
__SYCL_MARRAY_GEOMETRIC_FUNCTION_IS_GENGEOMARRAY_BINOP_OVERLOAD(dot)
1858+
__SYCL_MARRAY_GEOMETRIC_FUNCTION_IS_GENGEOMARRAY_BINOP_OVERLOAD(distance)
1859+
// clang-format on
1860+
1861+
#undef __SYCL_MARRAY_GEOMETRIC_FUNCTION_IS_GENGEOMARRAY_BINOP_OVERLOAD
1862+
1863+
template <typename T>
1864+
std::enable_if_t<detail::is_gengeomarray<T>::value, detail::marray_element_t<T>>
1865+
length(T p) __NOEXC {
1866+
return __sycl_std::__invoke_length<detail::marray_element_t<T>>(
1867+
detail::to_vec(p));
1868+
}
1869+
1870+
template <typename T>
1871+
std::enable_if_t<detail::is_gengeomarrayfloat<T>::value,
1872+
detail::marray_element_t<T>>
1873+
fast_distance(T p0, T p1) __NOEXC {
1874+
return fast_distance(detail::to_vec(p0), detail::to_vec(p1));
1875+
}
1876+
1877+
template <typename T>
1878+
std::enable_if_t<detail::is_gengeomarrayfloat<T>::value,
1879+
detail::marray_element_t<T>>
1880+
fast_length(T p) __NOEXC {
1881+
return fast_length(detail::to_vec(p));
1882+
}
1883+
18081884
/* SYCL 1.2.1 ---- 4.13.7 Relational functions. -----------------------------*/
18091885
/* SYCL 2020 ---- 4.17.9 Relational functions. -----------------------------*/
18101886

sycl/include/sycl/detail/generic_type_lists.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ using vector_geo_float_list =
108108
using vector_geo_double_list =
109109
type_list<vec<double, 1>, vec<double, 2>, vec<double, 3>, vec<double, 4>>;
110110

111+
using marray_geo_float_list =
112+
type_list<marray<float, 2>, marray<float, 3>, marray<float, 4>>;
113+
114+
using marray_geo_double_list =
115+
type_list<marray<double, 2>, marray<double, 3>, marray<double, 4>>;
116+
111117
using geo_half_list = type_list<scalar_geo_half_list, vector_geo_half_list>;
112118

113119
using geo_float_list = type_list<scalar_geo_float_list, vector_geo_float_list>;
@@ -121,6 +127,9 @@ using scalar_geo_list = type_list<scalar_geo_half_list, scalar_geo_float_list,
121127
using vector_geo_list = type_list<vector_geo_half_list, vector_geo_float_list,
122128
vector_geo_double_list>;
123129

130+
using marray_geo_list =
131+
type_list<marray_geo_float_list, marray_geo_double_list>;
132+
124133
using geo_list = type_list<scalar_geo_list, vector_geo_list>;
125134

126135
// cross floating point types
@@ -133,6 +142,9 @@ using cross_double_list = type_list<vec<double, 3>, vec<double, 4>>;
133142
using cross_floating_list =
134143
type_list<cross_float_list, cross_double_list, cross_half_list>;
135144

145+
using cross_marray_list = type_list<marray<float, 3>, marray<float, 4>,
146+
marray<double, 3>, marray<double, 4>>;
147+
136148
using scalar_default_char_list = type_list<char>;
137149

138150
using vector_default_char_list =

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ using is_gengeofloat = is_contained<T, gtl::geo_float_list>;
7070
template <typename T>
7171
using is_gengeodouble = is_contained<T, gtl::geo_double_list>;
7272

73+
template <typename T>
74+
using is_gengeomarrayfloat = is_contained<T, gtl::marray_geo_float_list>;
75+
76+
template <typename T>
77+
using is_gengeomarray = is_contained<T, gtl::marray_geo_list>;
78+
7379
template <typename T> using is_gengeohalf = is_contained<T, gtl::geo_half_list>;
7480

7581
template <typename T>
@@ -97,6 +103,9 @@ using is_gencrosshalf = is_contained<T, gtl::cross_half_list>;
97103
template <typename T>
98104
using is_gencross = is_contained<T, gtl::cross_floating_list>;
99105

106+
template <typename T>
107+
using is_gencrossmarray = is_contained<T, gtl::cross_marray_list>;
108+
100109
template <typename T>
101110
using is_charn = is_contained<T, gtl::vector_default_char_list>;
102111

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
2+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
3+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
4+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
5+
6+
#include <CL/sycl.hpp>
7+
8+
#define TEST(FUNC, MARRAY_ELEM_TYPE, DIM, EXPECTED, DELTA, ...) \
9+
{ \
10+
{ \
11+
MARRAY_ELEM_TYPE result[DIM]; \
12+
{ \
13+
sycl::buffer<MARRAY_ELEM_TYPE> b(result, sycl::range{DIM}); \
14+
Queue.submit([&](sycl::handler &cgh) { \
15+
sycl::accessor res_access{b, cgh}; \
16+
cgh.single_task([=]() { \
17+
sycl::marray<MARRAY_ELEM_TYPE, DIM> res = FUNC(__VA_ARGS__); \
18+
for (int i = 0; i < DIM; i++) \
19+
res_access[i] = res[i]; \
20+
}); \
21+
}); \
22+
} \
23+
for (int i = 0; i < DIM; i++) { \
24+
assert(abs(result[i] - EXPECTED[i]) <= DELTA); \
25+
} \
26+
} \
27+
}
28+
29+
#define TEST2(FUNC, TYPE, EXPECTED, DELTA, ...) \
30+
{ \
31+
{ \
32+
TYPE result; \
33+
{ \
34+
sycl::buffer<TYPE> b(&result, 1); \
35+
Queue.submit([&](sycl::handler &cgh) { \
36+
sycl::accessor res_access{b, cgh}; \
37+
cgh.single_task([=]() { res_access[0] = FUNC(__VA_ARGS__); }); \
38+
}); \
39+
} \
40+
assert(abs(result - EXPECTED) <= DELTA); \
41+
} \
42+
}
43+
44+
#define EXPECTED(TYPE, ...) ((TYPE[]){__VA_ARGS__})
45+
46+
int main() {
47+
sycl::device Dev;
48+
sycl::queue Queue(Dev);
49+
// clang-format off
50+
sycl::marray<float, 2> MFloatD2 = {1.f, 2.f};
51+
sycl::marray<float, 2> MFloatD2_2 = {3.f, 5.f};
52+
sycl::marray<float, 3> MFloatD3 = {1.f, 2.f, 3.f};
53+
sycl::marray<float, 3> MFloatD3_2 = {1.f, 5.f, 7.f};
54+
sycl::marray<float, 4> MFloatD4 = {1.f, 2.f, 3.f, 4.f};
55+
sycl::marray<float, 4> MFloatD4_2 = {1.f, 5.f, 7.f, 4.f};
56+
57+
sycl::marray<double, 2> MDoubleD2 = {1.0, 2.0};
58+
sycl::marray<double, 2> MDoubleD2_2 = {3.0, 5.0};
59+
sycl::marray<double, 3> MDoubleD3 = {1.0, 2.0, 3.0};
60+
sycl::marray<double, 3> MDoubleD3_2 = {1.0, 5.0, 7.0};
61+
sycl::marray<double, 4> MDoubleD4 = {1.0, 2.0, 3.0, 4.0};
62+
sycl::marray<double, 4> MDoubleD4_2 = {1.0, 5.0, 7.0, 4.0};
63+
// clang-format on
64+
65+
TEST(sycl::cross, float, 3, EXPECTED(float, -1.f, -4.f, 3.f), 0, MFloatD3,
66+
MFloatD3_2);
67+
TEST(sycl::cross, float, 4, EXPECTED(float, -1.f, -4.f, 3.f, 0.f), 0,
68+
MFloatD4, MFloatD4_2);
69+
if (Dev.has(sycl::aspect::fp64)) {
70+
TEST(sycl::cross, double, 3, EXPECTED(double, -1.f, -4.f, 3.f), 0,
71+
MDoubleD3, MDoubleD3_2);
72+
TEST(sycl::cross, double, 4, EXPECTED(double, -1.f, -4.f, 3.f, 0.f), 0,
73+
MDoubleD4, MDoubleD4_2);
74+
}
75+
76+
TEST2(sycl::dot, float, 13.f, 0, MFloatD2, MFloatD2_2);
77+
TEST2(sycl::dot, float, 32.f, 0, MFloatD3, MFloatD3_2);
78+
TEST2(sycl::dot, float, 48.f, 0, MFloatD4, MFloatD4_2);
79+
if (Dev.has(sycl::aspect::fp64)) {
80+
TEST2(sycl::dot, double, 13, 0, MDoubleD2, MDoubleD2_2);
81+
TEST2(sycl::dot, double, 32, 0, MDoubleD3, MDoubleD3_2);
82+
TEST2(sycl::dot, double, 48, 0, MDoubleD4, MDoubleD4_2);
83+
}
84+
85+
TEST2(sycl::length, float, 2.236068f, 1e-6, MFloatD2);
86+
TEST2(sycl::length, float, 3.741657f, 1e-6, MFloatD3);
87+
TEST2(sycl::length, float, 5.477225f, 1e-6, MFloatD4);
88+
if (Dev.has(sycl::aspect::fp64)) {
89+
TEST2(sycl::length, double, 2.236068, 1e-6, MDoubleD2);
90+
TEST2(sycl::length, double, 3.741657, 1e-6, MDoubleD3);
91+
TEST2(sycl::length, double, 5.477225, 1e-6, MDoubleD4);
92+
}
93+
94+
TEST2(sycl::distance, float, 3.605551f, 1e-6, MFloatD2, MFloatD2_2);
95+
TEST2(sycl::distance, float, 5.f, 0, MFloatD3, MFloatD3_2);
96+
TEST2(sycl::distance, float, 5.f, 0, MFloatD4, MFloatD4_2);
97+
if (Dev.has(sycl::aspect::fp64)) {
98+
TEST2(sycl::distance, double, 3.605551, 1e-6, MDoubleD2, MDoubleD2_2);
99+
TEST2(sycl::distance, double, 5.0, 0, MDoubleD3, MDoubleD3_2);
100+
TEST2(sycl::distance, double, 5.0, 0, MDoubleD4, MDoubleD4_2);
101+
}
102+
103+
TEST(sycl::normalize, float, 2, EXPECTED(float, 0.447213f, 0.894427f), 1e-6,
104+
MFloatD2);
105+
TEST(sycl::normalize, float, 3,
106+
EXPECTED(float, 0.267261f, 0.534522f, 0.801784f), 1e-6, MFloatD3);
107+
TEST(sycl::normalize, float, 4,
108+
EXPECTED(float, 0.182574f, 0.365148f, 0.547723f, 0.730297f), 1e-6,
109+
MFloatD4);
110+
if (Dev.has(sycl::aspect::fp64)) {
111+
TEST(sycl::normalize, double, 2, EXPECTED(double, 0.447213, 0.894427), 1e-6,
112+
MDoubleD2);
113+
TEST(sycl::normalize, double, 3,
114+
EXPECTED(double, 0.267261, 0.534522, 0.801784), 1e-6, MDoubleD3);
115+
TEST(sycl::normalize, double, 4,
116+
EXPECTED(double, 0.182574, 0.365148, 0.547723, 0.730297), 1e-6,
117+
MDoubleD4);
118+
}
119+
120+
TEST2(sycl::fast_distance, float, 3.605551f, 1e-6, MFloatD2, MFloatD2_2);
121+
TEST2(sycl::fast_distance, float, 5.f, 0, MFloatD3, MFloatD3_2);
122+
TEST2(sycl::fast_distance, float, 5.f, 0, MFloatD4, MFloatD4_2);
123+
124+
TEST2(sycl::fast_length, float, 2.236068f, 1e-6, MFloatD2);
125+
TEST2(sycl::fast_length, float, 3.741657f, 1e-6, MFloatD3);
126+
TEST2(sycl::fast_length, float, 5.477225f, 1e-6, MFloatD4);
127+
128+
TEST(sycl::fast_normalize, float, 2, EXPECTED(float, 0.447213f, 0.894427f),
129+
1e-3, MFloatD2);
130+
TEST(sycl::fast_normalize, float, 3,
131+
EXPECTED(float, 0.267261f, 0.534522f, 0.801784f), 1e-3, MFloatD3);
132+
TEST(sycl::fast_normalize, float, 4,
133+
EXPECTED(float, 0.182574f, 0.365148f, 0.547723f, 0.730297f), 1e-3,
134+
MFloatD4);
135+
136+
return 0;
137+
}

0 commit comments

Comments
 (0)