Skip to content

[SYCL] Make reduction compatible with MSVC host compiler #6601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit (base and head branch names were not captured) on Aug 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 53 additions & 58 deletions sycl/include/sycl/ext/oneapi/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,57 +192,52 @@ struct ReducerTraits<reducer<T, BinaryOperation, Dims, Extent, View, Subst>> {
/// Also, for int32/64 types the atomic_combine() is lowered to
/// sycl::atomic::fetch_add().
template <class Reducer> class combiner {
using T = typename ReducerTraits<Reducer>::type;
using BinaryOperation = typename ReducerTraits<Reducer>::op;
using Ty = typename ReducerTraits<Reducer>::type;
using BinaryOp = typename ReducerTraits<Reducer>::op;
static constexpr int Dims = ReducerTraits<Reducer>::dims;
static constexpr size_t Extent = ReducerTraits<Reducer>::extent;

public:
template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) &&
sycl::detail::IsPlus<_T, BinaryOperation>::value &&
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
sycl::detail::is_geninteger<_T>::value>
operator++() {
static_cast<Reducer *>(this)->combine(static_cast<T>(1));
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) &&
sycl::detail::IsPlus<_T, BinaryOperation>::value &&
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value &&
sycl::detail::is_geninteger<_T>::value>
operator++(int) {
static_cast<Reducer *>(this)->combine(static_cast<T>(1));
static_cast<Reducer *>(this)->combine(static_cast<_T>(1));
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOperation>::value>
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsPlus<_T, BinaryOp>::value>
operator+=(const _T &Partial) {
static_cast<Reducer *>(this)->combine(Partial);
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) &&
sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsMultiplies<_T, BinaryOp>::value>
operator*=(const _T &Partial) {
static_cast<Reducer *>(this)->combine(Partial);
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOperation>::value>
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsBitOR<_T, BinaryOp>::value>
operator|=(const _T &Partial) {
static_cast<Reducer *>(this)->combine(Partial);
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) &&
sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsBitXOR<_T, BinaryOp>::value>
operator^=(const _T &Partial) {
static_cast<Reducer *>(this)->combine(Partial);
}

template <typename _T = T, int _Dims = Dims>
enable_if_t<(_Dims == 0) &&
sycl::detail::IsBitAND<_T, BinaryOperation>::value>
template <typename _T = Ty, int _Dims = Dims>
enable_if_t<(_Dims == 0) && sycl::detail::IsBitAND<_T, BinaryOp>::value>
operator&=(const _T &Partial) {
static_cast<Reducer *>(this)->combine(Partial);
}
Expand All @@ -266,53 +261,53 @@ template <class Reducer> class combiner {
}
}

template <class _T, access::address_space Space, class BinaryOperation>
template <class _T, access::address_space Space, class BinaryOp>
static constexpr bool BasicCheck =
std::is_same<typename remove_AS<_T>::type, T>::value &&
std::is_same<typename remove_AS<_T>::type, Ty>::value &&
(Space == access::address_space::global_space ||
Space == access::address_space::local_space);

public:
/// Atomic ADD operation: *ReduVarPtr += MValue;
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
sycl::detail::IsPlus<T, _BinaryOperation>::value>
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
sycl::detail::IsPlus<_T, _BinaryOperation>::value>
atomic_combine(_T *ReduVarPtr) const {
atomic_combine_impl<Space>(
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_add(Val); });
}

/// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
sycl::detail::IsBitOR<T, _BinaryOperation>::value>
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
sycl::detail::IsBitOR<_T, _BinaryOperation>::value>
atomic_combine(_T *ReduVarPtr) const {
atomic_combine_impl<Space>(
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_or(Val); });
}

/// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
sycl::detail::IsBitXOR<_T, _BinaryOperation>::value>
atomic_combine(_T *ReduVarPtr) const {
atomic_combine_impl<Space>(
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_xor(Val); });
}

/// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
sycl::detail::IsBitAND<T, _BinaryOperation>::value &&
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<std::is_same<typename remove_AS<_T>::type, _T>::value &&
IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value &&
sycl::detail::IsBitAND<_T, _BinaryOperation>::value &&
(Space == access::address_space::global_space ||
Space == access::address_space::local_space)>
atomic_combine(_T *ReduVarPtr) const {
Expand All @@ -322,23 +317,23 @@ template <class Reducer> class combiner {

/// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
sycl::detail::IsMinimum<T, _BinaryOperation>::value>
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
sycl::detail::IsMinimum<_T, _BinaryOperation>::value>
atomic_combine(_T *ReduVarPtr) const {
atomic_combine_impl<Space>(
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_min(Val); });
}

/// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
typename _T = Ty, class _BinaryOperation = BinaryOp>
enable_if_t<BasicCheck<_T, Space, _BinaryOperation> &&
(IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<T, _BinaryOperation>::value) &&
sycl::detail::IsMaximum<T, _BinaryOperation>::value>
(IsReduOptForFastAtomicFetch<_T, _BinaryOperation>::value ||
IsReduOptForAtomic64Op<_T, _BinaryOperation>::value) &&
sycl::detail::IsMaximum<_T, _BinaryOperation>::value>
atomic_combine(_T *ReduVarPtr) const {
atomic_combine_impl<Space>(
ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_max(Val); });
Expand Down Expand Up @@ -928,7 +923,7 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
const range<Dims> &Range,
const nd_range<1> &NDRange,
Reduction &Redu) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
auto GroupSum = Reduction::getReadWriteLocalAcc(NElements, CGH);
using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastAtomics,
Expand Down Expand Up @@ -976,7 +971,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
const range<Dims> &Range,
const nd_range<1> &NDRange, Reduction &Redu) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
size_t WGSize = NDRange.get_local_range().size();
size_t NWorkGroups = NDRange.get_group_range().size();

Expand Down Expand Up @@ -1078,7 +1073,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
const range<Dims> &Range,
const nd_range<1> &NDRange, Reduction &Redu) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
size_t WGSize = NDRange.get_local_range().size();
size_t NWorkGroups = NDRange.get_group_range().size();

Expand Down Expand Up @@ -1230,7 +1225,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
void reduCGFuncForNDRangeBothFastReduceAndAtomics(
handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
Reduction &, typename Reduction::rw_accessor_type Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
using Name = __sycl_reduction_kernel<
reduction::main_krn::NDRangeBothFastReduceAndAtomics, KernelName>;
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
Expand Down Expand Up @@ -1266,7 +1261,7 @@ void reduCGFuncForNDRangeFastAtomicsOnly(
handler &CGH, bool IsPow2WG, KernelType KernelFunc,
const nd_range<Dims> &Range, Reduction &,
typename Reduction::rw_accessor_type Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
size_t WGSize = Range.get_local_range().size();

// Use local memory to reduce elements in work-groups into zero-th element.
Expand Down Expand Up @@ -1345,7 +1340,7 @@ template <typename KernelName, typename KernelType, int Dims, class Reduction>
void reduCGFuncForNDRangeFastReduceOnly(
handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
Reduction &Redu, typename Reduction::rw_accessor_type Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
size_t NWorkGroups = Range.get_group_range().size();
bool IsUpdateOfUserVar =
!Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;
Expand Down Expand Up @@ -1392,7 +1387,7 @@ void reduCGFuncForNDRangeBasic(handler &CGH, bool IsPow2WG,
KernelType KernelFunc,
const nd_range<Dims> &Range, Reduction &Redu,
typename Reduction::rw_accessor_type Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
size_t WGSize = Range.get_local_range().size();
size_t NWorkGroups = Range.get_group_range().size();

Expand Down Expand Up @@ -1477,7 +1472,7 @@ void reduAuxCGFuncFastReduceImpl(handler &CGH, bool UniformWG,
size_t NWorkItems, size_t NWorkGroups,
size_t WGSize, Reduction &Redu, InputT In,
OutputT Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
using Name =
__sycl_reduction_kernel<reduction::aux_krn::FastReduce, KernelName>;
bool IsUpdateOfUserVar =
Expand Down Expand Up @@ -1523,7 +1518,7 @@ void reduAuxCGFuncNoFastReduceNorAtomicImpl(handler &CGH, bool UniformPow2WG,
size_t NWorkGroups, size_t WGSize,
Reduction &Redu, InputT In,
OutputT Out) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
bool IsUpdateOfUserVar =
!Reduction::is_usm && !Redu.initializeToIdentity() && NWorkGroups == 1;

Expand Down Expand Up @@ -1642,7 +1637,7 @@ reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
template <typename KernelName, class Reduction>
std::enable_if_t<Reduction::is_usm>
reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH);
auto UserVarPtr = Redu.getUserRedVar();
bool IsUpdateOfUserVar = !Redu.initializeToIdentity();
Expand Down Expand Up @@ -2120,7 +2115,7 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
static_assert(
Reduction::has_float64_atomics,
"Only suitable for reductions that have FP64 atomic operations.");
constexpr size_t NElements = Reduction::num_elements;
size_t NElements = Reduction::num_elements;
using Name =
__sycl_reduction_kernel<reduction::main_krn::NDRangeAtomic64, KernelName>;
CGH.parallel_for<Name>(Range, [=](nd_item<Dims> NDIt) {
Expand Down