Support for decoding jpegs on GPU with nvjpeg #3792
@@ -0,0 +1,185 @@
#include "decode_jpeg_cuda.h"

#include <ATen/ATen.h>

#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif

#include <string>

namespace vision {
namespace image {

#if !NVJPEG_FOUND

torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  TORCH_CHECK(
      false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}

#else

namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;
[Review comment] Consider putting this in an anonymous namespace, as it looks like an internal detail of the implementation. Also, just checking whether this should be released at some point to avoid memory leaks.

[Reply] Yeah, this is a good point. Creating / allocating it at each call has some severe overhead, so it makes sense to declare this as a global variable (related discussion: #2786 (comment)). But this means we never really know when to release it, and the memory will only be freed when the process is killed.

[Reply] For reference: this is thread-safe.

[Reply] Good that it's thread-safe, but it's still unclear to me whether we have to find a way to release it or whether we can leave it be. We don't have such an idiom in TorchVision, but I wonder if there are examples of resources in PyTorch core that are never released. @ezyang, how do you handle situations like this in core?

[Reply (@ezyang)] We leak a bunch of globally, statically initialized objects, as it is the easiest way to avoid destructor ordering problems. If nvjpeg is a very simple library, you might be able to write an RAII class for this object and have it destruct properly on shutdown. Precedent in PyTorch is the cudnn convolution cache; look at that for some inspiration.

[Reply] @ezyang Thanks for the insights. @NicolasHug Given Ed's input, I think we can mark this as resolved if you want.
} // namespace

torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");

  TORCH_CHECK(
      !data.is_cuda(),
      "The input tensor must be on CPU when decoding with nvjpeg")

  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non-empty 1-dimensional tensor");

  TORCH_CHECK(device.is_cuda(), "Expected a cuda device")

  at::cuda::CUDAGuard device_guard(device);

  // Create the global nvJPEG handle. The flag must be static so that all
  // calls share a single one-time initialization.
  static std::once_flag nvjpeg_handle_creation_flag;
  std::call_once(nvjpeg_handle_creation_flag, []() {
    if (nvjpeg_handle == nullptr) {
      nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

      if (create_status != NVJPEG_STATUS_SUCCESS) {
        // Reset handle so that one can still call the function again in the
        // same process if there was a failure
        free(nvjpeg_handle);
[Review comment] Since it's an opaque handle, I think the use of `free` is questionable here. (I would hope that it simply leaves the handle as null if initialisation fails, although I don't see that in the docs; there it just re-inits the handle without any freeing when the hardware backend fails.)

[Reply] Hm, that's a good point. What would you recommend instead of `free`? Before pushing this I did a quick test by inserting

    nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
    create_status = NVJPEG_STATUS_NOT_INITIALIZED; // <- this

and all the tests were failing gracefully with the `nvjpegCreateSimple failed` error.

[Reply] Perhaps `nvjpegDestroy`?

[Reply] I didn't use it because I was wondering whether `nvjpegDestroy` is safe to call on a handle whose creation failed.

[Reply] I guess if you assume anything can happen when initialisation fails, then the handle might end up being an arbitrary value like 0xDEADBEEF and all you can do is reset it to null.
        nvjpeg_handle = nullptr;
      }
      TORCH_CHECK(
          create_status == NVJPEG_STATUS_SUCCESS,
          "nvjpegCreateSimple failed: ",
          create_status);
    }
  });

  // Create the jpeg state
  nvjpegJpegState_t jpeg_state;
  nvjpegStatus_t state_status =
      nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);

  TORCH_CHECK(
      state_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegJpegStateCreate failed: ",
      state_status);

  auto datap = data.data_ptr<uint8_t>();
  // Get the image information
  int num_channels;
  nvjpegChromaSubsampling_t subsampling;
  int widths[NVJPEG_MAX_COMPONENT];
  int heights[NVJPEG_MAX_COMPONENT];
  nvjpegStatus_t info_status = nvjpegGetImageInfo(
      nvjpeg_handle,
      datap,
      data.numel(),
      &num_channels,
      &subsampling,
      widths,
      heights);

  if (info_status != NVJPEG_STATUS_SUCCESS) {
    nvjpegJpegStateDestroy(jpeg_state);
    TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
  }

  if (subsampling == NVJPEG_CSS_UNKNOWN) {
    nvjpegJpegStateDestroy(jpeg_state);
    TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
  }

  int width = widths[0];
  int height = heights[0];
  nvjpegOutputFormat_t ouput_format;
  int num_channels_output;

  switch (mode) {
    case IMAGE_READ_MODE_UNCHANGED:
      num_channels_output = num_channels;
      // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
      // not properly decode RGB images (it's fine for grayscale), so we set
      // output_format manually here
      if (num_channels == 1) {
        ouput_format = NVJPEG_OUTPUT_Y;
      } else if (num_channels == 3) {
        ouput_format = NVJPEG_OUTPUT_RGB;
      } else {
        nvjpegJpegStateDestroy(jpeg_state);
        TORCH_CHECK(
            false,
            "When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
      }
      break;
    case IMAGE_READ_MODE_GRAY:
      ouput_format = NVJPEG_OUTPUT_Y;
      num_channels_output = 1;
      break;
    case IMAGE_READ_MODE_RGB:
      ouput_format = NVJPEG_OUTPUT_RGB;
      num_channels_output = 3;
      break;
    default:
      nvjpegJpegStateDestroy(jpeg_state);
      TORCH_CHECK(
          false, "The provided mode is not supported for JPEG decoding on GPU");
  }

  auto out_tensor = torch::empty(
      {int64_t(num_channels_output), int64_t(height), int64_t(width)},
      torch::dtype(torch::kU8).device(device));
  // nvjpegImage_t is a struct with
  // - an array of pointers to each channel
  // - the pitch for each channel
  // which must be filled in manually
  nvjpegImage_t out_image;

  for (int c = 0; c < num_channels_output; c++) {
    out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
[Review comment] Nit: this is fine for now, but it adds extra overhead, as we need to create a full Tensor just to extract the data pointer (and Tensor construction is heavy). Given that we generally only have 3 channels, that shouldn't be much of an issue, but it's still good to keep in mind. An alternative would be to use the raw data pointer directly, adding the correct offsets, like

    uint8_t* out_tensor_ptr = out_tensor.data_ptr<uint8_t>();
    ...
    out_image.channel[c] = out_tensor_ptr + c * height * width;

Also, interesting that …
    out_image.pitch[c] = width;
  }
  for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
    out_image.channel[c] = nullptr;
    out_image.pitch[c] = 0;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());
  nvjpegStatus_t decode_status = nvjpegDecode(
      nvjpeg_handle,
      jpeg_state,
      datap,
      data.numel(),
      ouput_format,
      &out_image,
      stream);

  nvjpegJpegStateDestroy(jpeg_state);

  TORCH_CHECK(
      decode_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegDecode failed: ",
      decode_status);

  return out_tensor;
}

#endif // NVJPEG_FOUND

} // namespace image
} // namespace vision
@@ -0,0 +1,15 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device);

} // namespace image
} // namespace vision
@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
     .op("image::encode_jpeg", &encode_jpeg)
     .op("image::read_file", &read_file)
     .op("image::write_file", &write_file)
-    .op("image::decode_image", &decode_image);
+    .op("image::decode_image", &decode_image)
+    .op("image::decode_jpeg_cuda", &decode_jpeg_cuda);
[Review comment] Not sure if the dispatcher would make sense here. Since this is the first IO method we add for GPU, it might be worth checking the naming conventions (…).

[Reply] I think using the dispatcher would be good, but I'm not sure how it handles constructor functions (like …). Indeed, this function always takes CPU tensors, and it's up to a `device` argument to decide where the decoding happens. @ezyang, do you know if we can use the dispatcher to dispatch on a `device` argument rather than on the input tensor?

[Reply] @fmassa How about reading the data on CPU, since that's needed anyway, and then calling the decoding op on the target device?

[Reply] Still, nvjpeg requires the input data to live on CPU, so we would need to move it back to CPU again within the function, which would be inefficient. I would have preferred if we could pass the tensor directly as a CUDA tensor as well, but I'm not sure this is possible without further overheads.

[Reply] Thanks for the clarifications concerning nvjpeg. I think we can investigate in future PRs how we could do this more elegantly. No need to block this PR.
} // namespace image
} // namespace vision
[Review comment] Do we want to consider the mean difference or the max difference here? What would be the minimum value so that the max tests pass?

[Reply] The max error can be quite high, unfortunately; the minimum threshold for all tests to pass seems to be 52, after which some tests start failing. In `test_decode_jpeg`, we also test for the MAE (with the same threshold=2).

[Reply] Hum, it looks suspicious that the decoding gives such large differences. Something to keep an eye on.