From f878b360399dd82ac3d54a2aba41027c5ed828d3 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 10 Oct 2020 20:17:30 +0000 Subject: [PATCH 01/15] Initial stab at nvJPEG support #2742 --- setup.py | 13 ++ test/test_image.py | 14 +++ torchvision/csrc/cpu/image/image.cpp | 3 +- torchvision/csrc/cpu/image/image.h | 1 + torchvision/csrc/cpu/image/readjpeg_cuda.cpp | 118 +++++++++++++++++++ torchvision/csrc/cpu/image/readjpeg_cuda.h | 5 + 6 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 torchvision/csrc/cpu/image/readjpeg_cuda.cpp create mode 100644 torchvision/csrc/cpu/image/readjpeg_cuda.h diff --git a/setup.py b/setup.py index d6674465405..1c4db8e3f48 100644 --- a/setup.py +++ b/setup.py @@ -315,6 +315,19 @@ def get_extensions(): image_library += [jpeg_lib] image_include += [jpeg_include] + # Locating nvjpeg + (nvjpeg_found, nvjpeg_conda, + nvjpeg_include, nvjpeg_lib) = find_library('nvjpeg', vision_include) + + print('NVJPEG found: {0}'.format(nvjpeg_found)) + image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] + if nvjpeg_found: + print('Building torchvision with NVJPEG image support') + image_link_flags.append('nvjpeg') + if nvjpeg_conda: + image_library += [nvjpeg_lib] + image_include += [nvjpeg_include] + image_path = os.path.join(extensions_dir, 'cpu', 'image') image_src = glob.glob(os.path.join(image_path, '*.cpp')) diff --git a/test/test_image.py b/test/test_image.py index ec4ab532a50..fd6157e47a5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -50,6 +50,20 @@ def test_decode_jpeg(self): with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_decode_jpeg_cuda(self): + for img_path in get_images(IMAGE_ROOT, ".jpg"): + img_pil = torch.load(img_path.replace('jpg', 'pth')) + img_pil = img_pil.permute(2, 0, 1) + data = read_file(img_path) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) + self.assertTrue(img_nvjpeg.is_cuda) + + # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') + # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') + + self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 1.5) + def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp index d9234aceb6e..9bd13bd1e1c 100644 --- a/torchvision/csrc/cpu/image/image.cpp +++ b/torchvision/csrc/cpu/image/image.cpp @@ -19,4 +19,5 @@ static auto registry = torch::RegisterOperators() .op("image::encode_jpeg", &encodeJPEG) .op("image::read_file", &read_file) .op("image::write_file", &write_file) - .op("image::decode_image", &decode_image); + .op("image::decode_image", &decode_image) + .op("image::decode_jpeg_cuda", &decodeJPEG_cuda); diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h index 3a652bef244..a5fb0d798ed 100644 --- a/torchvision/csrc/cpu/image/image.h +++ b/torchvision/csrc/cpu/image/image.h @@ -6,6 +6,7 @@ #include "read_image_cpu.h" #include "read_write_file_cpu.h" #include "readjpeg_cpu.h" +#include "readjpeg_cuda.h" #include "readpng_cpu.h" #include "writejpeg_cpu.h" #include "writepng_cpu.h" diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp new file mode 100644 index 00000000000..fbd313cc362 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp @@ -0,0 +1,118 @@ +#include "readjpeg_cuda.h" + +#include +#include + +#if !NVJPEG_FOUND + +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { + TORCH_CHECK( + false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +#include + +void init_nvjpegImage(nvjpegImage_t& img) { + for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { + img.channel[c] = NULL; + img.pitch[c] = 0; + } +} + +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto datap = data.data_ptr(); + + // Create nvJPEG handle (TODO: only initialise the library once) + nvjpegHandle_t nvjpeg_handle; + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + + // Create nvJPEG state (should this be persistent or not?) + nvjpegJpegState_t nvjpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + // Get the image information + int components; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + + nvjpegStatus_t info_status = nvjpegGetImageInfo( + nvjpeg_handle, + datap, + data.numel(), + &components, + &subsampling, + widths, + heights); + + TORCH_CHECK( + info_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetImageInfo failed: ", + info_status); + + TORCH_CHECK(components == 3, "Only RGB for now"); + + int width = widths[0]; + int height = heights[0]; + + // nvjpegImage_t is a struct with + // - an array of pointers to each channel + // - the pitch for each channel + // which must be filled in manually + nvjpegImage_t outImage; + init_nvjpegImage(outImage); + + // TODO device selection + auto tensor = torch::empty( + {int64_t(components), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(torch::kCUDA)); + + for (int c = 0; c < 3; c++) { + outImage.channel[c] = tensor[c].data_ptr(); + outImage.pitch[c] = width; + } + + // TODO torch cuda stream support + // TODO output besides RGB + nvjpegStatus_t decode_status = nvjpegDecode( + nvjpeg_handle, + nvjpeg_state, + datap, + data.numel(), + NVJPEG_OUTPUT_RGB, + &outImage, + /*stream=*/0); + + TORCH_CHECK( + decode_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecode failed: ", + decode_status); + + // Destroy the state and (for now) library handle + nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegDestroy(nvjpeg_handle); + + return tensor; +} + +#endif // NVJPEG_FOUND \ No newline at end of file diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.h b/torchvision/csrc/cpu/image/readjpeg_cuda.h new file mode 100644 index 00000000000..89bc83f2b44 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); From 5eb6d7340c37c3890acd53134d13f69545323251 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 11 Oct 2020 16:32:58 +0000 Subject: [PATCH 02/15] Init nvjpeg once on first call --- torchvision/csrc/cpu/image/readjpeg_cuda.cpp | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp index fbd313cc362..514b0d0ab1c 100644 --- a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp @@ -14,6 +14,8 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { #include +static nvjpegHandle_t nvjpeg_handle = nullptr; + void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { img.channel[c] = NULL; @@ -31,14 +33,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { auto datap = data.data_ptr(); - // Create nvJPEG handle (TODO: only initialise the library once) - nvjpegHandle_t nvjpeg_handle; - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + // Create nvJPEG handle + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } // Create nvJPEG state (should this be persistent or not?) nvjpegJpegState_t nvjpeg_state; @@ -108,11 +111,10 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { "nvjpegDecode failed: ", decode_status); - // Destroy the state and (for now) library handle + // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state); - nvjpegDestroy(nvjpeg_handle); return tensor; } -#endif // NVJPEG_FOUND \ No newline at end of file +#endif // NVJPEG_FOUND From 8abe4a56ee4d4915c30d1290195b0956270f312e Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Fri, 22 Jan 2021 22:25:42 +0000 Subject: [PATCH 03/15] Add io/image/cuda search path --- CMakeLists.txt | 2 +- setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bf76d97cddf..18c269d79a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,7 +53,7 @@ include(CMakePackageConfigHelpers) set(TVCPP torchvision/csrc) list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops - ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu) + ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda) if(WITH_CUDA) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) endif() diff --git a/setup.py b/setup.py index 5ac96506675..70646063507 100644 --- a/setup.py +++ b/setup.py @@ -329,7 +329,8 @@ def get_extensions(): image_include += [nvjpeg_include] image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + + glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) if png_found or jpeg_found: ext_modules.append(extension( From 8ae07510f73efdea16ad8e6a16d33a42f9e56d74 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Fri, 22 Jan 2021 22:56:57 +0000 Subject: [PATCH 04/15] Update test --- test/test_image.py | 11 +++++++++-- torchvision/csrc/io/image/cuda/readjpeg_cuda.h | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index f702d8cb883..4db9c72c9df 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -78,7 +78,13 @@ def test_decode_jpeg(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_decode_jpeg_cuda(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): - img_pil = torch.load(img_path.replace('jpg', 'pth')) + with Image.open(img_path) as img: + img_pil = torch.from_numpy(np.array(img)) + + if len(img_pil.shape) != 3 or img_pil.shape[2] != 3: + # only RGB supported so far + continue + img_pil = img_pil.permute(2, 0, 1) data = read_file(img_path) img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) @@ -87,7 +93,8 @@ def test_decode_jpeg_cuda(self): # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') - self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 1.5) + # Some difference expected between jpeg implementations + self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h index 89bc83f2b44..25167ddaa21 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h @@ -1,5 +1,5 @@ #pragma once -#include +#include C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); From 9a2510ff00a03f23cb0ac332eac31177024de861 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 00:29:34 +0000 Subject: [PATCH 05/15] Building CUDA ext should mean nvjpeg exists It will be included in the cuda sdk lib and include paths set up by CUDAExtension --- setup.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 70646063507..73ea2120433 100644 --- a/setup.py +++ b/setup.py @@ -316,17 +316,14 @@ def get_extensions(): image_include += [jpeg_include] # Locating nvjpeg - (nvjpeg_found, nvjpeg_conda, - nvjpeg_include, nvjpeg_lib) = find_library('nvjpeg', vision_include) + # Should be included in CUDA_HOME + nvjpeg_found = extension is CUDAExtension print('NVJPEG found: {0}'.format(nvjpeg_found)) image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] if nvjpeg_found: print('Building torchvision with NVJPEG image support') image_link_flags.append('nvjpeg') - if nvjpeg_conda: - image_library += [nvjpeg_lib] - image_include += [nvjpeg_include] image_path = os.path.join(extensions_dir, 'io', 'image') image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) From ac3330b02c661e1aed99a8101c47120711f31aa1 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 00:48:24 +0000 Subject: [PATCH 06/15] Check if nvjpeg.h is actually there --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 73ea2120433..250d56f86a3 100644 --- a/setup.py +++ b/setup.py @@ -317,7 +317,7 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME - nvjpeg_found = extension is CUDAExtension + nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) print('NVJPEG found: {0}'.format(nvjpeg_found)) image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] From e7981578889be7a5f13a7f94298d081f573abada Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:13:43 +0000 Subject: [PATCH 07/15] Add ImageReadMode support for nvjpeg --- test/test_image.py | 26 +++----- .../csrc/io/image/cuda/readjpeg_cuda.cpp | 62 ++++++++++++++----- .../csrc/io/image/cuda/readjpeg_cuda.h | 5 +- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 4db9c72c9df..285d161dfcc 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -77,24 +77,18 @@ def test_decode_jpeg(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_decode_jpeg_cuda(self): + conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] for img_path in get_images(IMAGE_ROOT, ".jpg"): - with Image.open(img_path) as img: - img_pil = torch.from_numpy(np.array(img)) - - if len(img_pil.shape) != 3 or img_pil.shape[2] != 3: - # only RGB supported so far - continue - - img_pil = img_pil.permute(2, 0, 1) - data = read_file(img_path) - img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) - self.assertTrue(img_nvjpeg.is_cuda) - - # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') - # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') + if Image.open(img_path).mode == 'CMYK': + # not supported + continue + for mode in conversion: + data = read_file(img_path) + img_ljpeg = decode_image(data, mode=mode) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value) - # Some difference expected between jpeg implementations - self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) + # Some difference expected between jpeg implementations + self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 514b0d0ab1c..ec9cb21291f 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -1,11 +1,10 @@ #include "readjpeg_cuda.h" -#include #include #if !NVJPEG_FOUND -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK( false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); } @@ -23,7 +22,7 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -68,16 +67,51 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { widths, heights); - TORCH_CHECK( - info_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetImageInfo failed: ", - info_status); + if (info_status != NVJPEG_STATUS_SUCCESS) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); + } - TORCH_CHECK(components == 3, "Only RGB for now"); + if (subsampling == NVJPEG_CSS_UNKNOWN) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); + } int width = widths[0]; int height = heights[0]; + nvjpegOutputFormat_t outputFormat; + int outputComponents; + + switch (mode) { + case IMAGE_READ_MODE_UNCHANGED: + if (components == 1) { + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + } else if (components == 3) { + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + } else { + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + break; + case IMAGE_READ_MODE_GRAY: + // This will do 0.299*R + 0.587*G + 0.114*B like opencv + // TODO check if that is the same as libjpeg + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + break; + case IMAGE_READ_MODE_RGB: + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + break; + default: + // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + // nvjpegImage_t is a struct with // - an array of pointers to each channel // - the pitch for each channel @@ -87,10 +121,10 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { // TODO device selection auto tensor = torch::empty( - {int64_t(components), int64_t(height), int64_t(width)}, + {int64_t(outputComponents), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(torch::kCUDA)); - for (int c = 0; c < 3; c++) { + for (int c = 0; c < outputComponents; c++) { outImage.channel[c] = tensor[c].data_ptr(); outImage.pitch[c] = width; } @@ -102,18 +136,18 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { nvjpeg_state, datap, data.numel(), - NVJPEG_OUTPUT_RGB, + outputFormat, &outImage, /*stream=*/0); + // Destroy the state + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK( decode_status == NVJPEG_STATUS_SUCCESS, "nvjpegDecode failed: ", decode_status); - // Destroy the state - nvjpegJpegStateDestroy(nvjpeg_state); - return tensor; } diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h index 25167ddaa21..aa5194f8652 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h @@ -1,5 +1,8 @@ #pragma once #include +#include "../image_read_mode.h" -C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); +C10_EXPORT torch::Tensor decodeJPEG_cuda( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); From a07f53a8a8cc7197b325afde4e5537b61c712a53 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:19:14 +0000 Subject: [PATCH 08/15] lint --- test/test_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 285d161dfcc..337da6330ee 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -80,8 +80,8 @@ def test_decode_jpeg_cuda(self): conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] for img_path in get_images(IMAGE_ROOT, ".jpg"): if Image.open(img_path).mode == 'CMYK': - # not supported - continue + # not supported + continue for mode in conversion: data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) From e485656d0145952f5b62899139b0511e927dd8df Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:28:21 +0000 Subject: [PATCH 09/15] Call nvjpegJpegStateDestroy when bailing out --- torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index ec9cb21291f..16af9a429d3 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -92,6 +92,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { outputFormat = NVJPEG_OUTPUT_RGB; outputComponents = 3; } else { + nvjpegJpegStateDestroy(nvjpeg_state); TORCH_CHECK( false, "The provided mode is not supported for JPEG files on GPU"); } @@ -108,6 +109,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { break; default: // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() + nvjpegJpegStateDestroy(nvjpeg_state); TORCH_CHECK( false, "The provided mode is not supported for JPEG files on GPU"); } From 1c1e471ee12d3132f543fd49a38d0a43548d6265 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 15:05:52 +0000 Subject: [PATCH 10/15] Use at::cuda::getCurrentCUDAStream() --- torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 16af9a429d3..9b1af0f6fd3 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { #else +#include +#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { - img.channel[c] = NULL; + img.channel[c] = nullptr; img.pitch[c] = 0; } } @@ -131,8 +133,8 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { outImage.pitch[c] = width; } - // TODO torch cuda stream support - // TODO output besides RGB + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, nvjpeg_state, @@ -140,7 +142,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { data.numel(), outputFormat, &outImage, - /*stream=*/0); + stream); // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state); From 5bc5e21f805acd75053b4aaf94973ff7a1b604a3 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 14:34:52 +0000 Subject: [PATCH 11/15] Changes to match #3312 --- .../{readjpeg_cuda.cpp => decode_jpeg_cuda.cpp} | 14 ++++++++++---- .../cuda/{readjpeg_cuda.h => decode_jpeg_cuda.h} | 8 +++++++- torchvision/csrc/io/image/image.cpp | 2 +- torchvision/csrc/io/image/image.h | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) rename torchvision/csrc/io/image/cuda/{readjpeg_cuda.cpp => decode_jpeg_cuda.cpp} (91%) rename torchvision/csrc/io/image/cuda/{readjpeg_cuda.h => decode_jpeg_cuda.h} (55%) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp similarity index 91% rename from torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp rename to torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 9b1af0f6fd3..21a5bd8b518 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,12 +1,15 @@ -#include "readjpeg_cuda.h" +#include "decode_jpeg_cuda.h" #include +namespace vision { +namespace image { + #if !NVJPEG_FOUND -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK( - false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); } #else @@ -24,7 +27,7 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -156,3 +159,6 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { } #endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h similarity index 55% rename from torchvision/csrc/io/image/cuda/readjpeg_cuda.h rename to torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index aa5194f8652..3345dfd770e 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -3,6 +3,12 @@ #include #include "../image_read_mode.h" -C10_EXPORT torch::Tensor decodeJPEG_cuda( +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 3dd4cdcf5c0..37d64013cb2 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -22,7 +22,7 @@ static auto registry = torch::RegisterOperators() .op("image::read_file", &read_file) .op("image::write_file", &write_file) .op("image::decode_image", &decode_image) - .op("image::decode_jpeg_cuda", &decodeJPEG_cuda); + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index cb380d5ebdd..05bac44c77d 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -6,4 +6,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/readjpeg_cuda.h" +#include "cuda/decode_jpeg_cuda.h" From 4d4cd45cf0366c4c081fcf2e5eb051a45c3e0f3e Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 14:58:36 +0000 Subject: [PATCH 12/15] Move includes outside namespace --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 21a5bd8b518..bb028004046 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,6 +1,8 @@ #include "decode_jpeg_cuda.h" #include +#include +#include namespace vision { namespace image { @@ -14,8 +16,6 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { #else -#include -#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; From dd3e445d692702298f76513e220e9362f55e3556 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 15:02:49 +0000 Subject: [PATCH 13/15] Lint --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index bb028004046..541b13bd072 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,8 +1,8 @@ #include "decode_jpeg_cuda.h" -#include #include #include +#include namespace vision { namespace image { From f560eebb42b8f9d87b4633da3ea64e844678fbbf Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 15:10:42 +0000 Subject: [PATCH 14/15] Guard includes so cpu builds work --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 541b13bd072..3f7cf888895 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,7 +1,12 @@ #include "decode_jpeg_cuda.h" #include + +#if NVJPEG_FOUND #include +#include +#endif + #include namespace vision { @@ -16,8 +21,6 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { #else -#include - static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { From ab90893b13ae49f0fd39f31141269648af6ef84f Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 18:44:10 +0000 Subject: [PATCH 15/15] Add device argument --- test/test_image.py | 2 +- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 19 ++++++++++++++++--- .../csrc/io/image/cuda/decode_jpeg_cuda.h | 3 ++- torchvision/io/image.py | 9 +++++++-- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 337da6330ee..712edba80d9 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -85,7 +85,7 @@ def test_decode_jpeg_cuda(self): for mode in conversion: data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) - img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda') # Some difference expected between jpeg implementations self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 3f7cf888895..ed194062dba 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -4,6 +4,7 @@ #if NVJPEG_FOUND #include +#include #include #endif @@ -14,7 +15,10 @@ namespace image { #if !NVJPEG_FOUND -torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { TORCH_CHECK( false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); } @@ -30,7 +34,10 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -38,6 +45,12 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + TORCH_CHECK( + device.is_cuda(), "Expected a cuda device" + ) + + at::cuda::CUDAGuard device_guard(device); + auto datap = data.data_ptr(); // Create nvJPEG handle @@ -132,7 +145,7 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { // TODO device selection auto tensor = torch::empty( {int64_t(outputComponents), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(torch::kCUDA)); + torch::dtype(torch::kU8).device(device)); for (int c = 0; c < outputComponents; c++) { outImage.channel[c] = tensor[c].data_ptr(); diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index 3345dfd770e..496b355e9b7 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -8,7 +8,8 @@ namespace image { C10_EXPORT torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + ImageReadMode mode, + torch::Device device); } // namespace image } // namespace vision diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e193555e447..0c14ae007b4 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, + device: torch.device = 'cpu') -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG Returns: output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_jpeg(input, mode.value) + device = torch.device(device) + if device.type == 'cuda': + output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) + else: + output = torch.ops.image.decode_jpeg(input, mode.value) return output