diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp
index 8199dbfb3df6d..442abcb63951c 100644
--- a/sycl/source/detail/scheduler/commands.cpp
+++ b/sycl/source/detail/scheduler/commands.cpp
@@ -828,9 +828,11 @@ void ReleaseCommand::printDot(std::ostream &Stream) const {
 }
 
 MapMemObject::MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req,
-                           void **DstPtr, QueueImplPtr Queue)
+                           void **DstPtr, QueueImplPtr Queue,
+                           access::mode MapMode)
     : Command(CommandType::MAP_MEM_OBJ, std::move(Queue)),
-      MSrcAllocaCmd(SrcAllocaCmd), MSrcReq(std::move(Req)), MDstPtr(DstPtr) {
+      MSrcAllocaCmd(SrcAllocaCmd), MSrcReq(std::move(Req)), MDstPtr(DstPtr),
+      MMapMode(MapMode) {
   emitInstrumentationDataProxy();
 }
 
@@ -861,9 +863,8 @@ cl_int MapMemObject::enqueueImp() {
   RT::PiEvent &Event = MEvent->getHandleRef();
   *MDstPtr = MemoryManager::map(
       MSrcAllocaCmd->getSYCLMemObj(), MSrcAllocaCmd->getMemAllocation(), MQueue,
-      MSrcReq.MAccessMode, MSrcReq.MDims, MSrcReq.MMemoryRange,
-      MSrcReq.MAccessRange, MSrcReq.MOffset, MSrcReq.MElemSize,
-      std::move(RawEvents), Event);
+      MMapMode, MSrcReq.MDims, MSrcReq.MMemoryRange, MSrcReq.MAccessRange,
+      MSrcReq.MOffset, MSrcReq.MElemSize, std::move(RawEvents), Event);
 
   return CL_SUCCESS;
 }
diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp
index 13763cc77a8d3..7e19d281250ca 100644
--- a/sycl/source/detail/scheduler/commands.hpp
+++ b/sycl/source/detail/scheduler/commands.hpp
@@ -321,7 +321,7 @@ class AllocaSubBufCommand : public AllocaCommandBase {
 class MapMemObject : public Command {
 public:
   MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req, void **DstPtr,
-               QueueImplPtr Queue);
+               QueueImplPtr Queue, access::mode MapMode);
 
   void printDot(std::ostream &Stream) const final;
   const Requirement *getRequirement() const final { return &MSrcReq; }
@@ -333,6 +333,7 @@ class MapMemObject : public Command {
   AllocaCommandBase *MSrcAllocaCmd = nullptr;
   Requirement MSrcReq;
   void **MDstPtr = nullptr;
+  access::mode MMapMode;
 };
 
 class UnMapMemObject : public Command {
diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp
index 62db768e40dfc..38cefea5f867d 100644
--- a/sycl/source/detail/scheduler/graph_builder.cpp
+++ b/sycl/source/detail/scheduler/graph_builder.cpp
@@ -48,6 +48,24 @@ static bool IsSuitableSubReq(const Requirement *Req) {
   return Req->MIsSubBuffer;
 }
 
+// Checks if the required access mode is allowed under the current one
+static bool isAccessModeAllowed(access::mode Required, access::mode Current) {
+  switch (Current) {
+  case access::mode::read:
+    return (Required == Current);
+  case access::mode::write:
+    assert(false && "Write only access is expected to be mapped as read_write");
+    return (Required == Current || Required == access::mode::discard_write);
+  case access::mode::read_write:
+  case access::mode::atomic:
+  case access::mode::discard_write:
+  case access::mode::discard_read_write:
+    return true;
+  }
+  assert(false);
+  return false;
+}
+
 Scheduler::GraphBuilder::GraphBuilder() {
   if (const char *EnvVarCStr = SYCLConfig<SYCL_PRINT_EXECUTION_GRAPH>::get()) {
     std::string GraphPrintOpts(EnvVarCStr);
@@ -199,7 +217,8 @@ UpdateHostRequirementCommand *Scheduler::GraphBuilder::insertUpdateHostReqCmd(
 // Takes linked alloca commands. Makes AllocaCmdDst command active using map
 // or unmap operation.
 static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
-                                            AllocaCommandBase *AllocaCmdDst) {
+                                            AllocaCommandBase *AllocaCmdDst,
+                                            access::mode MapMode) {
   assert(AllocaCmdSrc->MLinkedAllocaCmd == AllocaCmdDst &&
          "Expected linked alloca commands");
   assert(AllocaCmdSrc->MIsActive &&
@@ -215,9 +234,9 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
     return UnMapCmd;
   }
 
-  MapMemObject *MapCmd =
-      new MapMemObject(AllocaCmdSrc, *AllocaCmdSrc->getRequirement(),
-                       &AllocaCmdDst->MMemAllocation, AllocaCmdSrc->getQueue());
+  MapMemObject *MapCmd = new MapMemObject(
+      AllocaCmdSrc, *AllocaCmdSrc->getRequirement(),
+      &AllocaCmdDst->MMemAllocation, AllocaCmdSrc->getQueue(), MapMode);
 
   std::swap(AllocaCmdSrc->MIsActive, AllocaCmdDst->MIsActive);
 
@@ -277,7 +296,12 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
   Command *NewCmd = nullptr;
 
   if (AllocaCmdSrc->MLinkedAllocaCmd == AllocaCmdDst) {
-    NewCmd = insertMapUnmapForLinkedCmds(AllocaCmdSrc, AllocaCmdDst);
+    // Map write only as read-write
+    access::mode MapMode = Req->MAccessMode;
+    if (MapMode == access::mode::write)
+      MapMode = access::mode::read_write;
+    NewCmd = insertMapUnmapForLinkedCmds(AllocaCmdSrc, AllocaCmdDst, MapMode);
+    Record->MHostAccess = MapMode;
   } else {
 
     // Full copy of buffer is needed to avoid loss of data that may be caused
@@ -298,6 +322,43 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
   return NewCmd;
 }
 
+Command *Scheduler::GraphBuilder::remapMemoryObject(
+    MemObjRecord *Record, Requirement *Req, AllocaCommandBase *HostAllocaCmd) {
+  assert(HostAllocaCmd->getQueue()->is_host() &&
+         "Host alloca command expected");
+  assert(HostAllocaCmd->MIsActive && "Active alloca command expected");
+
+  AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd;
+  assert(LinkedAllocaCmd && "Linked alloca command expected");
+
+  std::set<Command *> Deps = findDepsForReq(Record, Req, Record->MCurContext);
+
+  UnMapMemObject *UnMapCmd = new UnMapMemObject(
+      LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
+      &HostAllocaCmd->MMemAllocation, LinkedAllocaCmd->getQueue());
+
+  // Map write only as read-write
+  access::mode MapMode = Req->MAccessMode;
+  if (MapMode == access::mode::write)
+    MapMode = access::mode::read_write;
+  MapMemObject *MapCmd = new MapMemObject(
+      LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
+      &HostAllocaCmd->MMemAllocation, LinkedAllocaCmd->getQueue(), MapMode);
+
+  for (Command *Dep : Deps) {
+    UnMapCmd->addDep(DepDesc{Dep, UnMapCmd->getRequirement(), LinkedAllocaCmd});
+    Dep->addUser(UnMapCmd);
+  }
+
+  MapCmd->addDep(DepDesc{UnMapCmd, MapCmd->getRequirement(), HostAllocaCmd});
+  UnMapCmd->addUser(MapCmd);
+
+  updateLeaves(Deps, Record, access::mode::read_write);
+  addNodeToLeaves(Record, MapCmd, access::mode::read_write);
+  Record->MHostAccess = MapMode;
+  return MapCmd;
+}
+
 // The function adds copy operation of the up to date'st memory to the memory
 // pointed by Req.
 Command *Scheduler::GraphBuilder::addCopyBack(Requirement *Req) {
@@ -352,8 +413,11 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
   AllocaCommandBase *HostAllocaCmd =
       getOrCreateAllocaForReq(Record, Req, HostQueue);
 
-  if (!sameCtx(HostAllocaCmd->getQueue()->getContextImplPtr(),
-               Record->MCurContext))
+  if (sameCtx(HostAllocaCmd->getQueue()->getContextImplPtr(),
+              Record->MCurContext)) {
+    if (!isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess))
+      remapMemoryObject(Record, Req, HostAllocaCmd);
+  } else
     insertMemoryMove(Record, Req, HostQueue);
 
   Command *UpdateHostAccCmd = insertUpdateHostReqCmd(Record, Req, HostQueue);
@@ -603,7 +667,13 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
     AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue);
     // If there is alloca command we need to check if the latest memory is in
    // required context.
-    if (!sameCtx(Queue->getContextImplPtr(), Record->MCurContext)) {
+    if (sameCtx(Queue->getContextImplPtr(), Record->MCurContext)) {
+      // If the memory is already in the required host context, check if the
+      // required access mode is valid, remap if not.
+      if (Record->MCurContext->is_host() &&
+          !isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess))
+        remapMemoryObject(Record, Req, AllocaCmd);
+    } else {
       // Cannot directly copy memory from OpenCL device to OpenCL device -
       // create two copies: device->host and host->device.
       if (!Queue->is_host() && !Record->MCurContext->is_host())
diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp
index 90000f6ab558c..19c79ea5e4e18 100644
--- a/sycl/source/detail/scheduler/scheduler.hpp
+++ b/sycl/source/detail/scheduler/scheduler.hpp
@@ -50,6 +50,10 @@ struct MemObjRecord {
   // The context which has the latest state of the memory object.
   ContextImplPtr MCurContext;
 
+  // The mode this object can be accessed with from the host context.
+  // Valid only if the current context is host.
+  access::mode MHostAccess = access::mode::read_write;
+
   // The flag indicates that the content of the memory object was/will be
   // modified. Used while deciding if copy back needed.
   bool MMemModified = false;
@@ -171,6 +175,11 @@ class Scheduler {
   Command *insertMemoryMove(MemObjRecord *Record, Requirement *Req,
                             const QueueImplPtr &Queue);
 
+  // Inserts commands required to remap the memory object to its current host
+  // context so that the required access mode becomes valid.
+  Command *remapMemoryObject(MemObjRecord *Record, Requirement *Req,
+                             AllocaCommandBase *HostAllocaCmd);
+
   UpdateHostRequirementCommand *
   insertUpdateHostReqCmd(MemObjRecord *Record, Requirement *Req,
                          const QueueImplPtr &Queue);
diff --git a/sycl/test/scheduler/MemObjRemapping.cpp b/sycl/test/scheduler/MemObjRemapping.cpp
new file mode 100644
index 0000000000000..e0b49f6b94b62
--- /dev/null
+++ b/sycl/test/scheduler/MemObjRemapping.cpp
@@ -0,0 +1,83 @@
+// RUN: %clangxx -fsycl %s -o %t.out
+// RUN: env SYCL_PI_TRACE=1 %CPU_RUN_PLACEHOLDER %t.out 2>&1 %CPU_CHECK_PLACEHOLDER
+#include <CL/sycl.hpp>
+#include <cassert>
+#include <cstddef>
+
+using namespace cl::sycl;
+
+class Foo;
+class Bar;
+
+// This test checks that memory objects are remapped on requesting an access
+// mode incompatible with the current mapping. Write access is mapped as
+// read-write.
+int main() {
+  queue Q;
+
+  std::size_t Size = 64;
+  range<1> Range{Size};
+  buffer<std::size_t, 1> BufA{Range};
+  buffer<std::size_t, 1> BufB{Range};
+
+  Q.submit([&](handler &Cgh) {
+    auto AccA = BufA.get_access<access::mode::write>(Cgh);
+    auto AccB = BufB.get_access<access::mode::write>(Cgh);
+    Cgh.parallel_for<Foo>(Range, [=](id<1> Idx) {
+      AccA[Idx] = Idx[0];
+      AccB[Idx] = Idx[0];
+    });
+  });
+
+  {
+    // Check access mode flags
+    // CHECK: piEnqueueMemBufferMap
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> : 1
+    // CHECK: piEnqueueMemBufferMap
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> : 1
+    auto AccA = BufA.get_access<access::mode::read>();
+    auto AccB = BufB.get_access<access::mode::read>();
+    for (std::size_t I = 0; I < Size; ++I) {
+      assert(AccA[I] == I);
+      assert(AccB[I] == I);
+    }
+  }
+  {
+    // CHECK: piEnqueueMemUnmap
+    // CHECK: piEnqueueMemBufferMap
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> :
+    // CHECK-NEXT: <unknown> : 3
+    auto AccA = BufA.get_access<access::mode::write>();
+    for (std::size_t I = 0; I < Size; ++I)
+      AccA[I] = 2 * I;
+  }
+
+  queue HostQ{host_selector()};
+  // CHECK: piEnqueueMemUnmap
+  // CHECK: piEnqueueMemBufferMap
+  // CHECK-NEXT: <unknown> :
+  // CHECK-NEXT: <unknown> :
+  // CHECK-NEXT: <unknown> :
+  // CHECK-NEXT: <unknown> : 3
+  HostQ.submit([&](handler &Cgh) {
+    auto AccB = BufB.get_access<access::mode::write>(Cgh);
+    Cgh.parallel_for<Bar>(Range,
+                          [=](id<1> Idx) { AccB[Idx] = 2 * Idx[0]; });
+  });
+
+  // CHECK-NOT: piEnqueueMemBufferMap
+  auto AccA = BufA.get_access<access::mode::read>();
+  auto AccB = BufB.get_access<access::mode::read>();
+  for (std::size_t I = 0; I < Size; ++I) {
+    assert(AccA[I] == 2 * I);
+    assert(AccB[I] == 2 * I);
+  }
+}
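
For reference, the compatibility rule that isAccessModeAllowed encodes can be illustrated with a standalone sketch. The local Mode enum and isCompatible function below are hypothetical stand-ins for cl::sycl::access::mode and the patch's isAccessModeAllowed, so the snippet compiles outside the SYCL runtime; the flag values in the comments assume PI map flags mirror OpenCL's CL_MAP_READ (1) and CL_MAP_WRITE (2).

#include <cassert>

// Hypothetical stand-in for cl::sycl::access::mode, local to this sketch.
enum class Mode {
  read, write, read_write, discard_write, discard_read_write, atomic
};

// Mirrors the isAccessModeAllowed rule from the patch: a requested access is
// compatible with an existing host mapping only if the mapping already grants
// every capability the request needs.
static bool isCompatible(Mode Required, Mode Current) {
  switch (Current) {
  case Mode::read:
    // A read-only mapping satisfies only read requests.
    return Required == Mode::read;
  case Mode::write:
    // Unreachable in the scheduler: write-only access is promoted to
    // read_write before mapping (see insertMemoryMove/remapMemoryObject).
    return Required == Mode::write || Required == Mode::discard_write;
  default:
    // read_write, atomic, discard_write and discard_read_write mappings
    // all grant full read-write access to the mapped memory.
    return true;
  }
}

int main() {
  // A buffer currently mapped for reading cannot serve a write request, so
  // the scheduler must unmap and remap it read-write. That is the
  // piEnqueueMemUnmap followed by piEnqueueMemBufferMap with flags 3
  // (CL_MAP_READ | CL_MAP_WRITE) that MemObjRemapping.cpp checks for.
  assert(!isCompatible(Mode::write, Mode::read));
  // Once mapped read-write, both reads and writes are allowed without a new
  // map operation, hence the final CHECK-NOT: piEnqueueMemBufferMap.
  assert(isCompatible(Mode::read, Mode::read_write));
  assert(isCompatible(Mode::write, Mode::read_write));
  return 0;
}

Note the asymmetry the sketch makes explicit: read requests never force a remap of a read-write mapping, while the first write access to a read-mapped object always pays one unmap/map round trip.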