diff --git a/sycl/unittests/pi/cuda/CudaUtils.hpp b/sycl/unittests/pi/cuda/CudaUtils.hpp new file mode 100644 index 0000000000000..f7cb8b40492d3 --- /dev/null +++ b/sycl/unittests/pi/cuda/CudaUtils.hpp @@ -0,0 +1,20 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#pragma once + +#include + +namespace pi { + +// utility function to clear the CUDA context stack +inline void clearCudaContext() { + CUcontext ctxt = nullptr; + do { + cuCtxSetCurrent(nullptr); + cuCtxGetCurrent(&ctxt); + } while (ctxt != nullptr); +} + +} // namespace pi diff --git a/sycl/unittests/pi/cuda/test_commands.cpp b/sycl/unittests/pi/cuda/test_commands.cpp index d3d9ad4baf31e..9de39910aa199 100644 --- a/sycl/unittests/pi/cuda/test_commands.cpp +++ b/sycl/unittests/pi/cuda/test_commands.cpp @@ -10,6 +10,7 @@ #include +#include "CudaUtils.hpp" #include "TestGetPlugin.hpp" #include #include @@ -34,7 +35,7 @@ struct CudaCommandsTest : public ::testing::Test { GTEST_SKIP(); } - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); pi_uint32 numPlatforms = 0; ASSERT_EQ(plugin->getBackend(), backend::cuda); diff --git a/sycl/unittests/pi/cuda/test_contexts.cpp b/sycl/unittests/pi/cuda/test_contexts.cpp index 4007341f94839..405ea3c136d41 100644 --- a/sycl/unittests/pi/cuda/test_contexts.cpp +++ b/sycl/unittests/pi/cuda/test_contexts.cpp @@ -14,6 +14,7 @@ #include +#include "CudaUtils.hpp" #include "TestGetPlugin.hpp" #include #include @@ -63,7 +64,7 @@ struct CudaContextsTest : public ::testing::Test { TEST_F(CudaContextsTest, ContextLifetime) { // start with no active context - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); // create a context pi_context context; @@ -149,7 +150,7 @@ TEST_F(CudaContextsTest, ContextLifetimeExisting) { // still able to work correctly in that thread. TEST_F(CudaContextsTest, ContextThread) { // start with no active context - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); // create two PI contexts pi_context context1; diff --git a/sycl/unittests/pi/cuda/test_mem_obj.cpp b/sycl/unittests/pi/cuda/test_mem_obj.cpp index b3d85682279fc..352420b57c044 100644 --- a/sycl/unittests/pi/cuda/test_mem_obj.cpp +++ b/sycl/unittests/pi/cuda/test_mem_obj.cpp @@ -10,6 +10,7 @@ #include +#include "CudaUtils.hpp" #include "TestGetPlugin.hpp" #include #include @@ -34,7 +35,7 @@ struct CudaTestMemObj : public ::testing::Test { GTEST_SKIP(); } - cuCtxSetCurrent(nullptr); + pi::clearCudaContext(); pi_uint32 numPlatforms = 0; ASSERT_EQ(plugin->getBackend(), backend::cuda);