Skip to content

[SYCL] Refactor USM allocator to improve ABI stability #1064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions sycl/include/CL/sycl/usm/usm_allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#pragma once

#include <CL/sycl/context.hpp>
#include <CL/sycl/detail/usm_impl.hpp>
#include <CL/sycl/device.hpp>
#include <CL/sycl/exception.hpp>
#include <CL/sycl/queue.hpp>
Expand All @@ -20,6 +19,11 @@
__SYCL_INLINE namespace cl {
namespace sycl {

// Forward declarations.
void *aligned_alloc(size_t alignment, size_t size, const device &dev,
const context &ctxt, usm::alloc kind);
void free(void *ptr, const context &ctxt);

template <typename T, usm::alloc AllocKind, size_t Alignment = 0>
class usm_allocator {
public:
Expand All @@ -36,14 +40,19 @@ class usm_allocator {

usm_allocator() = delete;
usm_allocator(const context &Ctxt, const device &Dev)
: mContext(Ctxt), mDevice(Dev) {}
: MContext(Ctxt), MDevice(Dev) {}
usm_allocator(const queue &Q)
: mContext(Q.get_context()), mDevice(Q.get_device()) {}
: MContext(Q.get_context()), MDevice(Q.get_device()) {}
usm_allocator(const usm_allocator &Other)
: mContext(Other.mContext), mDevice(Other.mDevice) {}

// Construct an object
// Note: AllocKind == alloc::device is not allowed
: MContext(Other.MContext), MDevice(Other.MDevice) {}

/// Constructs an object on memory pointed by Ptr.
///
/// Note: AllocKind == alloc::device is not allowed.
///
/// @param Ptr is a pointer to memory that will be used to construct the
/// object.
/// @param Val is a value to initialize the newly constructed object.
template <
usm::alloc AllocT = AllocKind,
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
Expand All @@ -59,8 +68,11 @@ class usm_allocator {
"Device pointers do not support construct on host");
}

// Destroy an object
// Note:: AllocKind == alloc::device is not allowed
/// Destroys an object.
///
/// Note:: AllocKind == alloc::device is not allowed
///
/// @param Ptr is a pointer to memory where the object resides.
template <
usm::alloc AllocT = AllocKind,
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
Expand All @@ -76,7 +88,10 @@ class usm_allocator {
"Device pointers do not support destroy on host");
}

// Note:: AllocKind == alloc::device is not allowed
/// Note:: AllocKind == alloc::device is not allowed.
///
/// @param Val is a reference to object.
/// @return an address of the object referenced by Val.
template <
usm::alloc AllocT = AllocKind,
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
Expand Down Expand Up @@ -107,35 +122,27 @@ class usm_allocator {
"Device pointers do not support address on host");
}

// Allocate memory
template <
usm::alloc AllocT = AllocKind,
typename std::enable_if<AllocT == usm::alloc::host, int>::type = 0>
pointer allocate(size_t Size) {
auto Result = reinterpret_cast<pointer>(detail::usm::alignedAllocHost(
getAlignment(), Size * sizeof(value_type), mContext, AllocKind));
if (!Result) {
throw memory_allocation_error();
}
return Result;
}
/// Allocates memory.
///
/// @param NumberOfElements is a count of elements to allocate memory for.
pointer allocate(size_t NumberOfElements) {

template <usm::alloc AllocT = AllocKind,
typename std::enable_if<AllocT != usm::alloc::host, int>::type = 0>
pointer allocate(size_t Size) {
auto Result = reinterpret_cast<pointer>(
detail::usm::alignedAlloc(getAlignment(), Size * sizeof(value_type),
mContext, mDevice, AllocKind));
aligned_alloc(getAlignment(), NumberOfElements * sizeof(value_type),
MDevice, MContext, AllocKind));
if (!Result) {
throw memory_allocation_error();
}
return Result;
}

// Deallocate memory
void deallocate(pointer Ptr, size_t size) {
/// Deallocates memory.
///
/// @param Ptr is a pointer to memory being deallocated.
/// @param Size is a number of elements previously passed to allocate.
void deallocate(pointer Ptr, size_t Size) {
if (Ptr) {
detail::usm::free(Ptr, mContext);
free(Ptr, MContext);
}
}

Expand All @@ -151,8 +158,8 @@ class usm_allocator {
return Alignment;
}

const context mContext;
const device mDevice;
const context MContext;
const device MDevice;
};

} // namespace sycl
Expand Down