Skip to content

[WIP] nvJPEG support #2786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
Expand Down
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,19 @@ def get_extensions():
image_library += [jpeg_lib]
image_include += [jpeg_include]

# Locating nvjpeg
# Should be included in CUDA_HOME
nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))

print('NVJPEG found: {0}'.format(nvjpeg_found))
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
if nvjpeg_found:
print('Building torchvision with NVJPEG image support')
image_link_flags.append('nvjpeg')

image_path = os.path.join(extensions_dir, 'io', 'image')
image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))

if png_found or jpeg_found:
ext_modules.append(extension(
Expand Down
15 changes: 15 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ def test_decode_jpeg(self):
with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_decode_jpeg_cuda(self):
    """Compare ``decode_image`` against the ``decode_jpeg_cuda`` operator.

    Every ``.jpg`` returned by ``get_images(IMAGE_ROOT, ".jpg")`` is decoded
    in each read mode through both paths; the mean absolute difference of
    the two results must stay below 2. Images that PIL reports as CMYK are
    skipped.
    """
    read_modes = (ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB)
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        is_cmyk = Image.open(jpeg_path).mode == 'CMYK'
        if is_cmyk:
            # CMYK input is not supported, so there is nothing to compare.
            continue
        for read_mode in read_modes:
            encoded = read_file(jpeg_path)
            reference_img = decode_image(encoded, mode=read_mode)
            nvjpeg_img = torch.ops.image.decode_jpeg_cuda(encoded, read_mode.value)

            # The two JPEG implementations are not expected to agree
            # exactly, hence a tolerance on the mean absolute error.
            mean_abs_diff = (reference_img.float() - nvjpeg_img.cpu().float()).abs().mean()
            self.assertTrue(mean_abs_diff < 2.)

def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
Expand Down
158 changes: 158 additions & 0 deletions torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#include "readjpeg_cuda.h"

#include <string>

#if !NVJPEG_FOUND

// Fallback compiled when NVJPEG_FOUND is 0: the arguments are ignored and
// every call fails with an error explaining that nvJPEG support is missing.
torch::Tensor decodeJPEG_cuda(
    const torch::Tensor& data,
    ImageReadMode mode) {
  // Neither argument is used by the fallback.
  (void)data;
  (void)mode;

  TORCH_CHECK(
      false,
      "decodeJPEG_cuda: torchvision not compiled with nvJPEG support");
}

#else

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <nvjpeg.h>

static nvjpegHandle_t nvjpeg_handle = nullptr;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there device-specific state associated with the nvjpegHandle? That is, is it safe/optimal to create an nvjpeg_handle on one device, then switch CUDA devices and use the previously created nvjpeg_handle?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs say:

The library handle is used in any consecutive nvJPEG library calls, and should be initialized first. The library handle is thread safe, and can be used by multiple threads simultaneously.

With the links above where the struct is just defined as a global variable, and the fact that they call it a "library handle", I'd say it's fair to assume that it's not device-specific.

From a quick test with 2 GPUs, this passes fine:

                img_nvjpeg = decode_jpeg(data, mode=mode, device='cuda:0')
                img_nvjpeg2 = decode_jpeg(data, mode=mode, device='cuda:1')
                self.assertTrue((img_nvjpeg.cpu().float() - img_nvjpeg2.cpu().float()).abs().mean() < 1e-10)


// Puts `img` into an empty state: each of the NVJPEG_MAX_COMPONENT channel
// slots gets a null buffer pointer and a pitch of zero.
void init_nvjpegImage(nvjpegImage_t& img) {
  int slot = NVJPEG_MAX_COMPONENT;
  while (slot-- > 0) {
    img.pitch[slot] = 0;
    img.channel[slot] = nullptr;
  }
}

torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

auto datap = data.data_ptr<uint8_t>();

// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}

// Create nvJPEG state (should this be persistent or not?)
nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);

TORCH_CHECK(
state_status == NVJPEG_STATUS_SUCCESS,
"nvjpegJpegStateCreate failed: ",
state_status);

// Get the image information
int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];

nvjpegStatus_t info_status = nvjpegGetImageInfo(
nvjpeg_handle,
datap,
data.numel(),
&components,
&subsampling,
widths,
heights);

if (info_status != NVJPEG_STATUS_SUCCESS) {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
}

if (subsampling == NVJPEG_CSS_UNKNOWN) {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
}

int width = widths[0];
int height = heights[0];

nvjpegOutputFormat_t outputFormat;
int outputComponents;

switch (mode) {
case IMAGE_READ_MODE_UNCHANGED:
if (components == 1) {
outputFormat = NVJPEG_OUTPUT_Y;
outputComponents = 1;
} else if (components == 3) {
outputFormat = NVJPEG_OUTPUT_RGB;
outputComponents = 3;
} else {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG files on GPU");
}
break;
case IMAGE_READ_MODE_GRAY:
// This will do 0.299*R + 0.587*G + 0.114*B like opencv
// TODO check if that is the same as libjpeg
outputFormat = NVJPEG_OUTPUT_Y;
outputComponents = 1;
break;
case IMAGE_READ_MODE_RGB:
outputFormat = NVJPEG_OUTPUT_RGB;
outputComponents = 3;
break;
default:
// CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK()
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG files on GPU");
}

// nvjpegImage_t is a struct with
// - an array of pointers to each channel
// - the pitch for each channel
// which must be filled in manually
nvjpegImage_t outImage;
init_nvjpegImage(outImage);

// TODO device selection
auto tensor = torch::empty(
{int64_t(outputComponents), int64_t(height), int64_t(width)},
torch::dtype(torch::kU8).device(torch::kCUDA));

for (int c = 0; c < outputComponents; c++) {
outImage.channel[c] = tensor[c].data_ptr<uint8_t>();
outImage.pitch[c] = width;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
nvjpeg_state,
datap,
data.numel(),
outputFormat,
&outImage,
stream);

// Destroy the state
nvjpegJpegStateDestroy(nvjpeg_state);

TORCH_CHECK(
decode_status == NVJPEG_STATUS_SUCCESS,
"nvjpegDecode failed: ",
decode_status);

return tensor;
}

#endif // NVJPEG_FOUND
8 changes: 8 additions & 0 deletions torchvision/csrc/io/image/cuda/readjpeg_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

C10_EXPORT torch::Tensor decodeJPEG_cuda(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encodeJPEG)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decodeJPEG_cuda);
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
#include "cpu/readpng_cpu.h"
#include "cpu/writejpeg_cpu.h"
#include "cpu/writepng_cpu.h"
#include "cuda/readjpeg_cuda.h"