Skip to content

Commit f82feb3

Browse files
committed
Use at::cuda::getCurrentCUDAStream()
1 parent e485656 commit f82feb3

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
1111

1212
#else
1313

14+
#include <ATen/ATen.h>
15+
#include <ATen/cuda/CUDAContext.h>
1416
#include <nvjpeg.h>
1517

1618
static nvjpegHandle_t nvjpeg_handle = nullptr;
1719

1820
void init_nvjpegImage(nvjpegImage_t& img) {
1921
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
20-
img.channel[c] = NULL;
22+
img.channel[c] = nullptr;
2123
img.pitch[c] = 0;
2224
}
2325
}
@@ -132,15 +134,17 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
132134
}
133135

134136
// TODO torch cuda stream support
135-
// TODO output besides RGB
137+
138+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
139+
136140
nvjpegStatus_t decode_status = nvjpegDecode(
137141
nvjpeg_handle,
138142
nvjpeg_state,
139143
datap,
140144
data.numel(),
141145
outputFormat,
142146
&outImage,
143-
/*stream=*/0);
147+
stream);
144148

145149
// Destroy the state
146150
nvjpegJpegStateDestroy(nvjpeg_state);

0 commit comments

Comments
 (0)