Skip to content

Commit 2d28cd4

Browse files
[SYCL] Fix is_device_copyable with range rounding (#4478)
Suppose we have the following code: ``` template <> struct is_device_copyable<DeviceCopyable> : std::true_type {}; ... DeviceCopyable DevCop(0); Q.submit([=](sycl::handler& cgh){ const sycl::range<2> range(1026, 1026); cgh.parallel_for(range,[=](sycl::item<2> item) { (void)DevCop; }); }); ``` This code doesn't compile because range rounding optimization wraps kernel lambda function, so we have something like this: ``` |- WrapperLambda |-KernelLambda ||-DevCop |-Other Wrapper captures ``` According to the implementation of is_device_copyable, we check whether all the fields of WrapperLambda are device copyable. KernelLambda is not device copyable since there is no corresponding template specialization of is_device_copyable. It's not possible to provide a template specialization for is_device_copyable with kernel lambda since lambda types don't have the name. To fix this issue, we create a functor class RoundedRangeKernel and provide an is_device_copyable trait for this class, which simply forwards check to nested kernel lambda.
1 parent b94f23a commit 2d28cd4

File tree

5 files changed

+85
-14
lines changed

5 files changed

+85
-14
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,42 @@ checkValueRange(const T &V) {
198198
#endif
199199
}
200200

201+
template <typename TransformedArgType, int Dims, typename KernelType>
202+
class RoundedRangeKernel {
203+
public:
204+
RoundedRangeKernel(range<Dims> NumWorkItems, KernelType KernelFunc)
205+
: NumWorkItems(NumWorkItems), KernelFunc(KernelFunc) {}
206+
207+
void operator()(TransformedArgType Arg) const {
208+
if (Arg[0] >= NumWorkItems[0])
209+
return;
210+
Arg.set_allowed_range(NumWorkItems);
211+
KernelFunc(Arg);
212+
}
213+
214+
private:
215+
range<Dims> NumWorkItems;
216+
KernelType KernelFunc;
217+
};
218+
219+
template <typename TransformedArgType, int Dims, typename KernelType>
220+
class RoundedRangeKernelWithKH {
221+
public:
222+
RoundedRangeKernelWithKH(range<Dims> NumWorkItems, KernelType KernelFunc)
223+
: NumWorkItems(NumWorkItems), KernelFunc(KernelFunc) {}
224+
225+
void operator()(TransformedArgType Arg, kernel_handler KH) const {
226+
if (Arg[0] >= NumWorkItems[0])
227+
return;
228+
Arg.set_allowed_range(NumWorkItems);
229+
KernelFunc(Arg, KH);
230+
}
231+
232+
private:
233+
range<Dims> NumWorkItems;
234+
KernelType KernelFunc;
235+
};
236+
201237
} // namespace detail
202238

203239
namespace ext {
@@ -2455,19 +2491,12 @@ class __SYCL_EXPORT handler {
24552491
range<Dims> NumWorkItems) {
24562492
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
24572493
KernelType, TransformedArgType>()) {
2458-
return [=](TransformedArgType Arg, kernel_handler KH) {
2459-
if (Arg[0] >= NumWorkItems[0])
2460-
return;
2461-
Arg.set_allowed_range(NumWorkItems);
2462-
KernelFunc(Arg, KH);
2463-
};
2494+
return detail::RoundedRangeKernelWithKH<TransformedArgType, Dims,
2495+
KernelType>(NumWorkItems,
2496+
KernelFunc);
24642497
} else {
2465-
return [=](TransformedArgType Arg) {
2466-
if (Arg[0] >= NumWorkItems[0])
2467-
return;
2468-
Arg.set_allowed_range(NumWorkItems);
2469-
KernelFunc(Arg);
2470-
};
2498+
return detail::RoundedRangeKernel<TransformedArgType, Dims, KernelType>(
2499+
NumWorkItems, KernelFunc);
24712500
}
24722501
}
24732502
};

sycl/include/CL/sycl/id.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
__SYCL_INLINE_NAMESPACE(cl) {
1818
namespace sycl {
19+
// Forward declarations
20+
namespace detail {
21+
template <typename TransformedArgType, int Dims, typename KernelType>
22+
class RoundedRangeKernel;
23+
template <typename TransformedArgType, int Dims, typename KernelType>
24+
class RoundedRangeKernelWithKH;
25+
} // namespace detail
1926
template <int dimensions> class range;
2027
template <int dimensions, bool with_offset> class item;
2128

@@ -241,7 +248,10 @@ template <int dimensions = 1> class id : public detail::array<dimensions> {
241248
#undef __SYCL_GEN_OPT
242249

243250
private:
244-
friend class handler;
251+
// Friend to get access to private method set_allowed_range().
252+
template <typename, int, typename> friend class detail::RoundedRangeKernel;
253+
template <typename, int, typename>
254+
friend class detail::RoundedRangeKernelWithKH;
245255
void set_allowed_range(range<dimensions> rnwi) { (void)rnwi[0]; }
246256
};
247257

sycl/include/CL/sycl/item.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ __SYCL_INLINE_NAMESPACE(cl) {
2121
namespace sycl {
2222
namespace detail {
2323
class Builder;
24+
template <typename TransformedArgType, int Dims, typename KernelType>
25+
class RoundedRangeKernel;
26+
template <typename TransformedArgType, int Dims, typename KernelType>
27+
class RoundedRangeKernelWithKH;
2428
}
2529
template <int dimensions> class id;
2630
template <int dimensions> class range;
@@ -120,7 +124,10 @@ template <int dimensions = 1, bool with_offset = true> class item {
120124
friend class detail::Builder;
121125

122126
private:
123-
friend class handler;
127+
// Friend to get access to private method set_allowed_range().
128+
template <typename, int, typename> friend class detail::RoundedRangeKernel;
129+
template <typename, int, typename>
130+
friend class detail::RoundedRangeKernelWithKH;
124131
void set_allowed_range(const range<dimensions> rnwi) { MImpl.MExtent = rnwi; }
125132

126133
detail::ItemBase<dimensions, with_offset> MImpl;

sycl/include/CL/sycl/types.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,12 @@ convertImpl(T Value) {
503503

504504
#endif // __SYCL_DEVICE_ONLY__
505505

506+
// Forward declarations
507+
template <typename TransformedArgType, int Dims, typename KernelType>
508+
class RoundedRangeKernel;
509+
template <typename TransformedArgType, int Dims, typename KernelType>
510+
class RoundedRangeKernelWithKH;
511+
506512
} // namespace detail
507513

508514
#if defined(_WIN32) && (_MSC_VER)
@@ -2395,6 +2401,19 @@ template <typename FuncT>
23952401
struct CheckDeviceCopyable
23962402
: CheckFieldsAreDeviceCopyable<FuncT, __builtin_num_fields(FuncT)>,
23972403
CheckBasesAreDeviceCopyable<FuncT, __builtin_num_bases(FuncT)> {};
2404+
2405+
// Below are two specializations for CheckDeviceCopyable when a kernel lambda
2406+
// is wrapped after range rounding optimization.
2407+
template <typename TransformedArgType, int Dims, typename KernelType>
2408+
struct CheckDeviceCopyable<
2409+
RoundedRangeKernel<TransformedArgType, Dims, KernelType>>
2410+
: CheckDeviceCopyable<KernelType> {};
2411+
2412+
template <typename TransformedArgType, int Dims, typename KernelType>
2413+
struct CheckDeviceCopyable<
2414+
RoundedRangeKernelWithKH<TransformedArgType, Dims, KernelType>>
2415+
: CheckDeviceCopyable<KernelType> {};
2416+
23982417
#endif // __SYCL_DEVICE_ONLY__
23992418
} // namespace detail
24002419

sycl/test/basic_tests/is_device_copyable.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,10 @@ void test() {
8787

8888
Q.single_task<class TestB>(FunctorA{});
8989
Q.single_task<class TestC>(FunctorB{});
90+
91+
Q.submit([=](sycl::handler &cgh) {
92+
const sycl::range<2> range(1026, 1026);
93+
cgh.parallel_for(range,
94+
[=](sycl::item<2> item) { int A = IamBadButCopyable.i; });
95+
});
9096
}

0 commit comments

Comments
 (0)