From 8dab361b694d47606a5dd5b1d630f1f4215cc105 Mon Sep 17 00:00:00 2001 From: Nicolas Miller Date: Thu, 2 Sep 2021 17:42:21 +0100 Subject: [PATCH 1/2] [SYCL][CUDA] Fix context clearing in PiCuda tests `cuCtxSetCurrent(nullptr)` will only discard the top of the context stack so the current context may still not be `nullptr` after this. To fix this, this patch introduces a small utility function to pop the entire context stack when we're trying to reset it in the tests. --- sycl/unittests/pi/cuda/CudaUtils.hpp | 20 ++++++++++++++++++++ sycl/unittests/pi/cuda/test_commands.cpp | 3 ++- sycl/unittests/pi/cuda/test_contexts.cpp | 5 +++-- sycl/unittests/pi/cuda/test_mem_obj.cpp | 3 ++- 4 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 sycl/unittests/pi/cuda/CudaUtils.hpp diff --git a/sycl/unittests/pi/cuda/CudaUtils.hpp b/sycl/unittests/pi/cuda/CudaUtils.hpp new file mode 100644 index 0000000000000..f7cb8b40492d3 --- /dev/null +++ b/sycl/unittests/pi/cuda/CudaUtils.hpp @@ -0,0 +1,20 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include + +namespace pi { + +// utility function to clear the CUDA context stack +inline void clearCudaContext() { + CUcontext ctxt = nullptr; + do { + cuCtxSetCurrent(nullptr); + cuCtxGetCurrent(&ctxt); + } while (ctxt != nullptr); +} + +} // namespace pi diff --git a/sycl/unittests/pi/cuda/test_commands.cpp b/sycl/unittests/pi/cuda/test_commands.cpp index d3d9ad4baf31e..996a649524dfa 100644 --- a/sycl/unittests/pi/cuda/test_commands.cpp +++ b/sycl/unittests/pi/cuda/test_commands.cpp @@ -11,6 +11,7 @@ #include #include "TestGetPlugin.hpp" +#include "CudaUtils.hpp" #include #include #include @@ -34,7 +35,7 @@ struct CudaCommandsTest : public ::testing::Test { GTEST_SKIP(); } - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); pi_uint32 numPlatforms = 0; ASSERT_EQ(plugin->getBackend(), backend::cuda); diff --git a/sycl/unittests/pi/cuda/test_contexts.cpp b/sycl/unittests/pi/cuda/test_contexts.cpp index 4007341f94839..5ad24e6836131 100644 --- a/sycl/unittests/pi/cuda/test_contexts.cpp +++ b/sycl/unittests/pi/cuda/test_contexts.cpp @@ -15,6 +15,7 @@ #include #include "TestGetPlugin.hpp" +#include "CudaUtils.hpp" #include #include #include @@ -63,7 +64,7 @@ struct CudaContextsTest : public ::testing::Test { TEST_F(CudaContextsTest, ContextLifetime) { // start with no active context - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); // create a context pi_context context; @@ -149,7 +150,7 @@ TEST_F(CudaContextsTest, ContextLifetimeExisting) { // still able to work correctly in that thread. TEST_F(CudaContextsTest, ContextThread) { // start with no active context - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); // create two PI contexts pi_context context1; diff --git a/sycl/unittests/pi/cuda/test_mem_obj.cpp b/sycl/unittests/pi/cuda/test_mem_obj.cpp index b3d85682279fc..0122e4cc74174 100644 --- a/sycl/unittests/pi/cuda/test_mem_obj.cpp +++ b/sycl/unittests/pi/cuda/test_mem_obj.cpp @@ -11,6 +11,7 @@ #include #include "TestGetPlugin.hpp" +#include "CudaUtils.hpp" #include #include #include @@ -34,7 +35,7 @@ struct CudaTestMemObj : public ::testing::Test { GTEST_SKIP(); } - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); pi_uint32 numPlatforms = 0; ASSERT_EQ(plugin->getBackend(), backend::cuda); From dbe979af4dcbfbad18fd4564824fb4358cc1177d Mon Sep 17 00:00:00 2001 From: Nicolas Miller Date: Fri, 3 Sep 2021 13:49:08 +0100 Subject: [PATCH 2/2] [SYCL][CUDA] Fix header ordering format --- sycl/unittests/pi/cuda/test_commands.cpp | 2 +- sycl/unittests/pi/cuda/test_contexts.cpp | 2 +- sycl/unittests/pi/cuda/test_mem_obj.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/unittests/pi/cuda/test_commands.cpp b/sycl/unittests/pi/cuda/test_commands.cpp index 996a649524dfa..9de39910aa199 100644 --- a/sycl/unittests/pi/cuda/test_commands.cpp +++ b/sycl/unittests/pi/cuda/test_commands.cpp @@ -10,8 +10,8 @@ #include -#include "TestGetPlugin.hpp" #include "CudaUtils.hpp" +#include "TestGetPlugin.hpp" #include #include #include diff --git a/sycl/unittests/pi/cuda/test_contexts.cpp b/sycl/unittests/pi/cuda/test_contexts.cpp index 5ad24e6836131..405ea3c136d41 100644 --- a/sycl/unittests/pi/cuda/test_contexts.cpp +++ b/sycl/unittests/pi/cuda/test_contexts.cpp @@ -14,8 +14,8 @@ #include -#include "TestGetPlugin.hpp" #include "CudaUtils.hpp" +#include "TestGetPlugin.hpp" #include #include #include diff --git a/sycl/unittests/pi/cuda/test_mem_obj.cpp b/sycl/unittests/pi/cuda/test_mem_obj.cpp index 0122e4cc74174..352420b57c044 100644 --- a/sycl/unittests/pi/cuda/test_mem_obj.cpp +++ b/sycl/unittests/pi/cuda/test_mem_obj.cpp @@ -10,8 +10,8 @@ #include -#include "TestGetPlugin.hpp" #include "CudaUtils.hpp" +#include "TestGetPlugin.hpp" #include #include #include