Skip to content

Commit 7e3cca4

Browse files
authored
[SYCL] Added support of rounding modes for non-host devices (#1463)
Signed-off-by: Aleksander Fadeev <aleksander.fadeev@intel.com>
1 parent de1c363 commit 7e3cca4

File tree

3 files changed

+69
-14
lines changed

3 files changed

+69
-14
lines changed

clang/lib/Sema/SPIRVBuiltins.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ class ConstOCLSPVBuiltin<string _Name, list<Type> _Signature> :
286286

287287
// OpenCL v1.0/1.2/2.0 s6.1.1: Built-in Scalar Data Types.
288288
def Bool : IntType<"bool", QualType<"BoolTy">, 1>;
289-
def TrueChar : IntType<"char", QualType<"CharTy", 0, 1>, 8>;
290-
def Char : IntType<"schar", QualType<"SignedCharTy", 0, 1>, 8>;
289+
def TrueChar : IntType<"_char", QualType<"CharTy", 0, 1>, 8>;
290+
def Char : IntType<"char", QualType<"SignedCharTy", 0, 1>, 8>;
291291
def SChar : IntType<"schar", QualType<"SignedCharTy", 0, 1>, 8>;
292292
def UChar : UIntType<"uchar", QualType<"UnsignedCharTy">, 8>;
293293
def Short : IntType<"short", QualType<"ShortTy", 0, 1>, 16>;
@@ -713,8 +713,10 @@ foreach name = ["GenericPtrMemSemantics"] in {
713713

714714
foreach IType = [UChar, UShort, UInt, ULong] in {
715715
foreach FType = [Float, Double, Half] in {
716-
def : SPVBuiltin<"ConvertFToU_R" # IType.Name, [IType, FType], Attr.Const>;
717716
def : SPVBuiltin<"ConvertUToF_R" # FType.Name, [FType, IType], Attr.Const>;
717+
foreach rnd = ["", "_rte", "_rtz", "_rtp", "_rtn"] in {
718+
def : SPVBuiltin<"ConvertFToU_R" # IType.Name # rnd, [IType, FType], Attr.Const>;
719+
}
718720
foreach v = [2, 3, 4, 8, 16] in {
719721
def : SPVBuiltin<"ConvertFToU_R" # IType.Name # v,
720722
[VectorType<IType, v>, VectorType<FType, v>],
@@ -728,8 +730,10 @@ foreach IType = [UChar, UShort, UInt, ULong] in {
728730

729731
foreach IType = [Char, Short, Int, Long] in {
730732
foreach FType = [Float, Double, Half] in {
731-
def : SPVBuiltin<"ConvertFToS_R" # IType.Name, [IType, FType], Attr.Const>;
732733
def : SPVBuiltin<"ConvertSToF_R" # FType.Name, [FType, IType], Attr.Const>;
734+
foreach rnd = ["", "_rte", "_rtz", "_rtp", "_rtn"] in {
735+
def : SPVBuiltin<"ConvertFToS_R" # IType.Name # rnd, [IType, FType], Attr.Const>;
736+
}
733737
foreach v = [2, 3, 4, 8, 16] in {
734738
def : SPVBuiltin<"ConvertFToS_R" # IType.Name # v,
735739
[VectorType<IType, v>, VectorType<FType, v>],

sycl/include/CL/sycl/types.hpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ convertImpl(T Value) {
231231
return static_cast<R>(Value);
232232
}
233233

234+
#ifndef __SYCL_DEVICE_ONLY__
234235
// float to int
235236
template <typename T, typename R, rounding_mode roundingMode>
236237
detail::enable_if_t<is_float_to_int<T, R>::value, R> convertImpl(T Value) {
237-
#ifndef __SYCL_DEVICE_ONLY__
238238
switch (roundingMode) {
239239
// Round to nearest even is default rounding mode for floating-point types
240240
case rounding_mode::automatic:
@@ -264,11 +264,62 @@ detail::enable_if_t<is_float_to_int<T, R>::value, R> convertImpl(T Value) {
264264
assert(!"Unsupported rounding mode!");
265265
return static_cast<R>(Value);
266266
};
267-
#else
268-
// TODO implement device side conversion.
269-
return static_cast<R>(Value);
270-
#endif
271267
}
268+
#else
269+
270+
template <rounding_mode Mode>
271+
using RteOrAutomatic = detail::bool_constant<Mode == rounding_mode::automatic ||
272+
Mode == rounding_mode::rte>;
273+
274+
template <rounding_mode Mode>
275+
using Rtz = detail::bool_constant<Mode == rounding_mode::rtz>;
276+
277+
template <rounding_mode Mode>
278+
using Rtp = detail::bool_constant<Mode == rounding_mode::rtp>;
279+
280+
template <rounding_mode Mode>
281+
using Rtn = detail::bool_constant<Mode == rounding_mode::rtn>;
282+
283+
// Convert floating-point type to integer type
284+
#define __SYCL_GENERATE_CONVERT_IMPL(SPIRVOp, DestType, RoundingMode, \
285+
RoundingModeCondition) \
286+
template <typename T, typename R, rounding_mode roundingMode> \
287+
detail::enable_if_t<is_float_to_int<T, R>::value && \
288+
std::is_same<R, DestType>::value && \
289+
RoundingModeCondition<roundingMode>::value, \
290+
R> \
291+
convertImpl(T Value) { \
292+
using OpenCLT = cl::sycl::detail::ConvertToOpenCLType_t<T>; \
293+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
294+
return __spirv_Convert##SPIRVOp##_R##DestType##_##RoundingMode(OpValue); \
295+
}
296+
297+
#define __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(RoundingMode, \
298+
RoundingModeCondition) \
299+
__SYCL_GENERATE_CONVERT_IMPL(FToS, int, RoundingMode, RoundingModeCondition) \
300+
__SYCL_GENERATE_CONVERT_IMPL(FToS, char, RoundingMode, \
301+
RoundingModeCondition) \
302+
__SYCL_GENERATE_CONVERT_IMPL(FToS, short, RoundingMode, \
303+
RoundingModeCondition) \
304+
__SYCL_GENERATE_CONVERT_IMPL(FToS, long, RoundingMode, \
305+
RoundingModeCondition) \
306+
__SYCL_GENERATE_CONVERT_IMPL(FToU, uint, RoundingMode, \
307+
RoundingModeCondition) \
308+
__SYCL_GENERATE_CONVERT_IMPL(FToU, uchar, RoundingMode, \
309+
RoundingModeCondition) \
310+
__SYCL_GENERATE_CONVERT_IMPL(FToU, ushort, RoundingMode, \
311+
RoundingModeCondition) \
312+
__SYCL_GENERATE_CONVERT_IMPL(FToU, ulong, RoundingMode, RoundingModeCondition)
313+
314+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rte, RteOrAutomatic)
315+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtz, Rtz)
316+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtp, Rtp)
317+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtn, Rtn)
318+
319+
#undef __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE
320+
#undef __SYCL_GENERATE_CONVERT_IMPL
321+
322+
#endif // __SYCL_DEVICE_ONLY__
272323

273324
} // namespace detail
274325

sycl/test/basic_tests/vec_convert.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
// XFAIL: cuda
12
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
23
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
3-
// RUNx: %CPU_RUN_PLACEHOLDER %t.out
4-
// RUNx: %GPU_RUN_PLACEHOLDER %t.out
5-
// RUNx: %ACC_RUN_PLACEHOLDER %t.out
4+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
5+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
6+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
67
//==------------ vec_convert.cpp - SYCL vec class convert method test ------==//
78
//
89
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -15,8 +16,7 @@
1516

1617
#include <cassert>
1718

18-
// TODO uncomment run lines on non-host devices when the rounding modes will
19-
// be implemented.
19+
// TODO make the test to pass on cuda
2020

2121
using namespace cl::sycl;
2222

0 commit comments

Comments
 (0)