From f878b360399dd82ac3d54a2aba41027c5ed828d3 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 10 Oct 2020 20:17:30 +0000 Subject: [PATCH 01/28] Initial stab at nvJPEG support #2742 --- setup.py | 13 ++ test/test_image.py | 14 +++ torchvision/csrc/cpu/image/image.cpp | 3 +- torchvision/csrc/cpu/image/image.h | 1 + torchvision/csrc/cpu/image/readjpeg_cuda.cpp | 118 +++++++++++++++++++ torchvision/csrc/cpu/image/readjpeg_cuda.h | 5 + 6 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 torchvision/csrc/cpu/image/readjpeg_cuda.cpp create mode 100644 torchvision/csrc/cpu/image/readjpeg_cuda.h diff --git a/setup.py b/setup.py index d6674465405..1c4db8e3f48 100644 --- a/setup.py +++ b/setup.py @@ -315,6 +315,19 @@ def get_extensions(): image_library += [jpeg_lib] image_include += [jpeg_include] + # Locating nvjpeg + (nvjpeg_found, nvjpeg_conda, + nvjpeg_include, nvjpeg_lib) = find_library('nvjpeg', vision_include) + + print('NVJPEG found: {0}'.format(nvjpeg_found)) + image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] + if nvjpeg_found: + print('Building torchvision with NVJPEG image support') + image_link_flags.append('nvjpeg') + if nvjpeg_conda: + image_library += [nvjpeg_lib] + image_include += [nvjpeg_include] + image_path = os.path.join(extensions_dir, 'cpu', 'image') image_src = glob.glob(os.path.join(image_path, '*.cpp')) diff --git a/test/test_image.py b/test/test_image.py index ec4ab532a50..fd6157e47a5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -50,6 +50,20 @@ def test_decode_jpeg(self): with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_decode_jpeg_cuda(self): + for img_path in get_images(IMAGE_ROOT, ".jpg"): + img_pil = torch.load(img_path.replace('jpg', 'pth')) + img_pil = img_pil.permute(2, 0, 1) + data = read_file(img_path) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) + self.assertTrue(img_nvjpeg.is_cuda) + + # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') + # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') + + self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 1.5) + def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp index d9234aceb6e..9bd13bd1e1c 100644 --- a/torchvision/csrc/cpu/image/image.cpp +++ b/torchvision/csrc/cpu/image/image.cpp @@ -19,4 +19,5 @@ static auto registry = torch::RegisterOperators() .op("image::encode_jpeg", &encodeJPEG) .op("image::read_file", &read_file) .op("image::write_file", &write_file) - .op("image::decode_image", &decode_image); + .op("image::decode_image", &decode_image) + .op("image::decode_jpeg_cuda", &decodeJPEG_cuda); diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h index 3a652bef244..a5fb0d798ed 100644 --- a/torchvision/csrc/cpu/image/image.h +++ b/torchvision/csrc/cpu/image/image.h @@ -6,6 +6,7 @@ #include "read_image_cpu.h" #include "read_write_file_cpu.h" #include "readjpeg_cpu.h" +#include "readjpeg_cuda.h" #include "readpng_cpu.h" #include "writejpeg_cpu.h" #include "writepng_cpu.h" diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp new file mode 100644 index 00000000000..fbd313cc362 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp @@ -0,0 +1,118 @@ +#include "readjpeg_cuda.h" + +#include +#include + +#if !NVJPEG_FOUND + +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { + TORCH_CHECK( + false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +#include + +void init_nvjpegImage(nvjpegImage_t& img) { + for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { + img.channel[c] = NULL; + img.pitch[c] = 0; + } +} + +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto datap = data.data_ptr(); + + // Create nvJPEG handle (TODO: only initialise the library once) + nvjpegHandle_t nvjpeg_handle; + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + + // Create nvJPEG state (should this be persistent or not?) + nvjpegJpegState_t nvjpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + // Get the image information + int components; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + + nvjpegStatus_t info_status = nvjpegGetImageInfo( + nvjpeg_handle, + datap, + data.numel(), + &components, + &subsampling, + widths, + heights); + + TORCH_CHECK( + info_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetImageInfo failed: ", + info_status); + + TORCH_CHECK(components == 3, "Only RGB for now"); + + int width = widths[0]; + int height = heights[0]; + + // nvjpegImage_t is a struct with + // - an array of pointers to each channel + // - the pitch for each channel + // which must be filled in manually + nvjpegImage_t outImage; + init_nvjpegImage(outImage); + + // TODO device selection + auto tensor = torch::empty( + {int64_t(components), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(torch::kCUDA)); + + for (int c = 0; c < 3; c++) { + outImage.channel[c] = tensor[c].data_ptr(); + outImage.pitch[c] = width; + } + + // TODO torch cuda stream support + // TODO output besides RGB + nvjpegStatus_t decode_status = nvjpegDecode( + nvjpeg_handle, + nvjpeg_state, + datap, + data.numel(), + NVJPEG_OUTPUT_RGB, + &outImage, + /*stream=*/0); + + TORCH_CHECK( + decode_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecode failed: ", + decode_status); + + // Destroy the state and (for now) library handle + nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegDestroy(nvjpeg_handle); + + return tensor; +} + +#endif // NVJPEG_FOUND \ No newline at end of file diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.h b/torchvision/csrc/cpu/image/readjpeg_cuda.h new file mode 100644 index 00000000000..89bc83f2b44 --- /dev/null +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); From 5eb6d7340c37c3890acd53134d13f69545323251 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 11 Oct 2020 16:32:58 +0000 Subject: [PATCH 02/28] Init nvjpeg once on first call --- torchvision/csrc/cpu/image/readjpeg_cuda.cpp | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp index fbd313cc362..514b0d0ab1c 100644 --- a/torchvision/csrc/cpu/image/readjpeg_cuda.cpp +++ b/torchvision/csrc/cpu/image/readjpeg_cuda.cpp @@ -14,6 +14,8 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { #include +static nvjpegHandle_t nvjpeg_handle = nullptr; + void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { img.channel[c] = NULL; @@ -31,14 +33,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { auto datap = data.data_ptr(); - // Create nvJPEG handle (TODO: only initialise the library once) - nvjpegHandle_t nvjpeg_handle; - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + // Create nvJPEG handle + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } // Create nvJPEG state (should this be persistent or not?) nvjpegJpegState_t nvjpeg_state; @@ -108,11 +111,10 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { "nvjpegDecode failed: ", decode_status); - // Destroy the state and (for now) library handle + // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state); - nvjpegDestroy(nvjpeg_handle); return tensor; } -#endif // NVJPEG_FOUND \ No newline at end of file +#endif // NVJPEG_FOUND From 8abe4a56ee4d4915c30d1290195b0956270f312e Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Fri, 22 Jan 2021 22:25:42 +0000 Subject: [PATCH 03/28] Add io/image/cuda search path --- CMakeLists.txt | 2 +- setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bf76d97cddf..18c269d79a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,7 +53,7 @@ include(CMakePackageConfigHelpers) set(TVCPP torchvision/csrc) list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops - ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu) + ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda) if(WITH_CUDA) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) endif() diff --git a/setup.py b/setup.py index 5ac96506675..70646063507 100644 --- a/setup.py +++ b/setup.py @@ -329,7 +329,8 @@ def get_extensions(): image_include += [nvjpeg_include] image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + + glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) if png_found or jpeg_found: ext_modules.append(extension( From 8ae07510f73efdea16ad8e6a16d33a42f9e56d74 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Fri, 22 Jan 2021 22:56:57 +0000 Subject: [PATCH 04/28] Update test --- test/test_image.py | 11 +++++++++-- torchvision/csrc/io/image/cuda/readjpeg_cuda.h | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index f702d8cb883..4db9c72c9df 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -78,7 +78,13 @@ def test_decode_jpeg(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_decode_jpeg_cuda(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): - img_pil = torch.load(img_path.replace('jpg', 'pth')) + with Image.open(img_path) as img: + img_pil = torch.from_numpy(np.array(img)) + + if len(img_pil.shape) != 3 or img_pil.shape[2] != 3: + # only RGB supported so far + continue + img_pil = img_pil.permute(2, 0, 1) data = read_file(img_path) img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) @@ -87,7 +93,8 @@ def test_decode_jpeg_cuda(self): # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') - self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 1.5) + # Some difference expected between jpeg implementations + self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h index 89bc83f2b44..25167ddaa21 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h @@ -1,5 +1,5 @@ #pragma once -#include +#include C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); From 9a2510ff00a03f23cb0ac332eac31177024de861 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 00:29:34 +0000 Subject: [PATCH 05/28] Building CUDA ext should mean nvjpeg exists It will be included in the cuda sdk lib and include paths set up by CUDAExtension --- setup.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 70646063507..73ea2120433 100644 --- a/setup.py +++ b/setup.py @@ -316,17 +316,14 @@ def get_extensions(): image_include += [jpeg_include] # Locating nvjpeg - (nvjpeg_found, nvjpeg_conda, - nvjpeg_include, nvjpeg_lib) = find_library('nvjpeg', vision_include) + # Should be included in CUDA_HOME + nvjpeg_found = extension is CUDAExtension print('NVJPEG found: {0}'.format(nvjpeg_found)) image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] if nvjpeg_found: print('Building torchvision with NVJPEG image support') image_link_flags.append('nvjpeg') - if nvjpeg_conda: - image_library += [nvjpeg_lib] - image_include += [nvjpeg_include] image_path = os.path.join(extensions_dir, 'io', 'image') image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) From ac3330b02c661e1aed99a8101c47120711f31aa1 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 00:48:24 +0000 Subject: [PATCH 06/28] Check if nvjpeg.h is actually there --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 73ea2120433..250d56f86a3 100644 --- a/setup.py +++ b/setup.py @@ -317,7 +317,7 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME - nvjpeg_found = extension is CUDAExtension + nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) print('NVJPEG found: {0}'.format(nvjpeg_found)) image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] From e7981578889be7a5f13a7f94298d081f573abada Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:13:43 +0000 Subject: [PATCH 07/28] Add ImageReadMode support for nvjpeg --- test/test_image.py | 26 +++----- .../csrc/io/image/cuda/readjpeg_cuda.cpp | 62 ++++++++++++++----- .../csrc/io/image/cuda/readjpeg_cuda.h | 5 +- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 4db9c72c9df..285d161dfcc 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -77,24 +77,18 @@ def test_decode_jpeg(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_decode_jpeg_cuda(self): + conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] for img_path in get_images(IMAGE_ROOT, ".jpg"): - with Image.open(img_path) as img: - img_pil = torch.from_numpy(np.array(img)) - - if len(img_pil.shape) != 3 or img_pil.shape[2] != 3: - # only RGB supported so far - continue - - img_pil = img_pil.permute(2, 0, 1) - data = read_file(img_path) - img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data) - self.assertTrue(img_nvjpeg.is_cuda) - - # Image.fromarray(img_nvjpeg.permute(1,2,0).cpu().numpy()).save('/tmp/im.png') - # Image.fromarray(img_pil.permute(1,2,0).cpu().numpy()).save('/tmp/impil.png') + if Image.open(img_path).mode == 'CMYK': + # not supported + continue + for mode in conversion: + data = read_file(img_path) + img_ljpeg = decode_image(data, mode=mode) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value) - # Some difference expected between jpeg implementations - self.assertTrue((img_pil.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) + # Some difference expected between jpeg implementations + self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 514b0d0ab1c..ec9cb21291f 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -1,11 +1,10 @@ #include "readjpeg_cuda.h" -#include #include #if !NVJPEG_FOUND -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK( false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); } @@ -23,7 +22,7 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { +torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -68,16 +67,51 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { widths, heights); - TORCH_CHECK( - info_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetImageInfo failed: ", - info_status); + if (info_status != NVJPEG_STATUS_SUCCESS) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); + } - TORCH_CHECK(components == 3, "Only RGB for now"); + if (subsampling == NVJPEG_CSS_UNKNOWN) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); + } int width = widths[0]; int height = heights[0]; + nvjpegOutputFormat_t outputFormat; + int outputComponents; + + switch (mode) { + case IMAGE_READ_MODE_UNCHANGED: + if (components == 1) { + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + } else if (components == 3) { + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + } else { + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + break; + case IMAGE_READ_MODE_GRAY: + // This will do 0.299*R + 0.587*G + 0.114*B like opencv + // TODO check if that is the same as libjpeg + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + break; + case IMAGE_READ_MODE_RGB: + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + break; + default: + // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + // nvjpegImage_t is a struct with // - an array of pointers to each channel // - the pitch for each channel @@ -87,10 +121,10 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { // TODO device selection auto tensor = torch::empty( - {int64_t(components), int64_t(height), int64_t(width)}, + {int64_t(outputComponents), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(torch::kCUDA)); - for (int c = 0; c < 3; c++) { + for (int c = 0; c < outputComponents; c++) { outImage.channel[c] = tensor[c].data_ptr(); outImage.pitch[c] = width; } @@ -102,18 +136,18 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data) { nvjpeg_state, datap, data.numel(), - NVJPEG_OUTPUT_RGB, + outputFormat, &outImage, /*stream=*/0); + // Destroy the state + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK( decode_status == NVJPEG_STATUS_SUCCESS, "nvjpegDecode failed: ", decode_status); - // Destroy the state - nvjpegJpegStateDestroy(nvjpeg_state); - return tensor; } diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h index 25167ddaa21..aa5194f8652 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.h @@ -1,5 +1,8 @@ #pragma once #include +#include "../image_read_mode.h" -C10_EXPORT torch::Tensor decodeJPEG_cuda(const torch::Tensor& data); +C10_EXPORT torch::Tensor decodeJPEG_cuda( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); From a07f53a8a8cc7197b325afde4e5537b61c712a53 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:19:14 +0000 Subject: [PATCH 08/28] lint --- test/test_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 285d161dfcc..337da6330ee 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -80,8 +80,8 @@ def test_decode_jpeg_cuda(self): conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] for img_path in get_images(IMAGE_ROOT, ".jpg"): if Image.open(img_path).mode == 'CMYK': - # not supported - continue + # not supported + continue for mode in conversion: data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) From e485656d0145952f5b62899139b0511e927dd8df Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 14:28:21 +0000 Subject: [PATCH 09/28] Call nvjpegJpegStateDestroy when bailing out --- torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index ec9cb21291f..16af9a429d3 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -92,6 +92,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { outputFormat = NVJPEG_OUTPUT_RGB; outputComponents = 3; } else { + nvjpegJpegStateDestroy(nvjpeg_state); TORCH_CHECK( false, "The provided mode is not supported for JPEG files on GPU"); } @@ -108,6 +109,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { break; default: // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() + nvjpegJpegStateDestroy(nvjpeg_state); TORCH_CHECK( false, "The provided mode is not supported for JPEG files on GPU"); } From 1c1e471ee12d3132f543fd49a38d0a43548d6265 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 15:05:52 +0000 Subject: [PATCH 10/28] Use at::cuda::getCurrentCUDAStream() --- torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 16af9a429d3..9b1af0f6fd3 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { #else +#include +#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { - img.channel[c] = NULL; + img.channel[c] = nullptr; img.pitch[c] = 0; } } @@ -131,8 +133,8 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { outImage.pitch[c] = width; } - // TODO torch cuda stream support - // TODO output besides RGB + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, nvjpeg_state, @@ -140,7 +142,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { data.numel(), outputFormat, &outImage, - /*stream=*/0); + stream); // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state); From 5bc5e21f805acd75053b4aaf94973ff7a1b604a3 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 14:34:52 +0000 Subject: [PATCH 11/28] Changes to match #3312 --- .../{readjpeg_cuda.cpp => decode_jpeg_cuda.cpp} | 14 ++++++++++---- .../cuda/{readjpeg_cuda.h => decode_jpeg_cuda.h} | 8 +++++++- torchvision/csrc/io/image/image.cpp | 2 +- torchvision/csrc/io/image/image.h | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) rename torchvision/csrc/io/image/cuda/{readjpeg_cuda.cpp => decode_jpeg_cuda.cpp} (91%) rename torchvision/csrc/io/image/cuda/{readjpeg_cuda.h => decode_jpeg_cuda.h} (55%) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp similarity index 91% rename from torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp rename to torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 9b1af0f6fd3..21a5bd8b518 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,12 +1,15 @@ -#include "readjpeg_cuda.h" +#include "decode_jpeg_cuda.h" #include +namespace vision { +namespace image { + #if !NVJPEG_FOUND -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK( - false, "decodeJPEG_cuda: torchvision not compiled with nvJPEG support"); + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); } #else @@ -24,7 +27,7 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -156,3 +159,6 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { } #endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h similarity index 55% rename from torchvision/csrc/io/image/cuda/readjpeg_cuda.h rename to torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index aa5194f8652..3345dfd770e 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -3,6 +3,12 @@ #include #include "../image_read_mode.h" -C10_EXPORT torch::Tensor decodeJPEG_cuda( +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 3dd4cdcf5c0..37d64013cb2 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -22,7 +22,7 @@ static auto registry = torch::RegisterOperators() .op("image::read_file", &read_file) .op("image::write_file", &write_file) .op("image::decode_image", &decode_image) - .op("image::decode_jpeg_cuda", &decodeJPEG_cuda); + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index cb380d5ebdd..05bac44c77d 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -6,4 +6,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/readjpeg_cuda.h" +#include "cuda/decode_jpeg_cuda.h" From 4d4cd45cf0366c4c081fcf2e5eb051a45c3e0f3e Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 14:58:36 +0000 Subject: [PATCH 12/28] Move includes outside namespace --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 21a5bd8b518..bb028004046 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,6 +1,8 @@ #include "decode_jpeg_cuda.h" #include +#include +#include namespace vision { namespace image { @@ -14,8 +16,6 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { #else -#include -#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; From dd3e445d692702298f76513e220e9362f55e3556 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 15:02:49 +0000 Subject: [PATCH 13/28] Lint --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index bb028004046..541b13bd072 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,8 +1,8 @@ #include "decode_jpeg_cuda.h" -#include #include #include +#include namespace vision { namespace image { From f560eebb42b8f9d87b4633da3ea64e844678fbbf Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 15:10:42 +0000 Subject: [PATCH 14/28] Guard includes so cpu builds work --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 541b13bd072..3f7cf888895 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -1,7 +1,12 @@ #include "decode_jpeg_cuda.h" #include + +#if NVJPEG_FOUND #include +#include +#endif + #include namespace vision { @@ -16,8 +21,6 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { #else -#include - static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { From ab90893b13ae49f0fd39f31141269648af6ef84f Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 7 Feb 2021 18:44:10 +0000 Subject: [PATCH 15/28] Add device argument --- test/test_image.py | 2 +- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 19 ++++++++++++++++--- .../csrc/io/image/cuda/decode_jpeg_cuda.h | 3 ++- torchvision/io/image.py | 9 +++++++-- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 337da6330ee..712edba80d9 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -85,7 +85,7 @@ def test_decode_jpeg_cuda(self): for mode in conversion: data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) - img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda') # Some difference expected between jpeg implementations self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 3f7cf888895..ed194062dba 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -4,6 +4,7 @@ #if NVJPEG_FOUND #include +#include #include #endif @@ -14,7 +15,10 @@ namespace image { #if !NVJPEG_FOUND -torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { TORCH_CHECK( false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); } @@ -30,7 +34,10 @@ void init_nvjpegImage(nvjpegImage_t& img) { } } -torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -38,6 +45,12 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + TORCH_CHECK( + device.is_cuda(), "Expected a cuda device" + ) + + at::cuda::CUDAGuard device_guard(device); + auto datap = data.data_ptr(); // Create nvJPEG handle @@ -132,7 +145,7 @@ torch::Tensor decode_jpeg_cuda(const torch::Tensor& data, ImageReadMode mode) { // TODO device selection auto tensor = torch::empty( {int64_t(outputComponents), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(torch::kCUDA)); + torch::dtype(torch::kU8).device(device)); for (int c = 0; c < outputComponents; c++) { outImage.channel[c] = tensor[c].data_ptr(); diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index 3345dfd770e..496b355e9b7 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -8,7 +8,8 @@ namespace image { C10_EXPORT torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + ImageReadMode mode, + torch::Device device); } // namespace image } // namespace vision diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e193555e447..0c14ae007b4 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, + device: torch.device = 'cpu') -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG Returns: output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_jpeg(input, mode.value) + device = torch.device(device) + if device.type == 'cuda': + output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) + else: + output = torch.ops.image.decode_jpeg(input, mode.value) return output From 540eaa4ef5857ea5f4ec3cf4975489e607413e9d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 27 Apr 2021 07:18:47 -0700 Subject: [PATCH 16/28] WIP --- setup.py | 2 ++ test/test_image.py | 1 + 2 files changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 499a8d664d5..53148d83cb0 100644 --- a/setup.py +++ b/setup.py @@ -318,8 +318,10 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + nvjpeg_found = True print('NVJPEG found: {0}'.format(nvjpeg_found)) + print(f"{CUDA_HOME}") image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] if nvjpeg_found: print('Building torchvision with NVJPEG image support') diff --git a/test/test_image.py b/test/test_image.py index 712edba80d9..58a3fc21b10 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -86,6 +86,7 @@ def test_decode_jpeg_cuda(self): data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda') + # img_nvjpeg = decode_jpeg(data, mode=mode.value, device='cuda') # Some difference expected between jpeg implementations self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) From 7b6eadfdb5e7e63204befe8caa05b00834e9ae44 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 28 Apr 2021 16:42:07 +0000 Subject: [PATCH 17/28] WIP --- test/test_image.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 58a3fc21b10..b9c1d9891b2 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -85,8 +85,10 @@ def test_decode_jpeg_cuda(self): for mode in conversion: data = read_file(img_path) img_ljpeg = decode_image(data, mode=mode) - img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda') - # img_nvjpeg = decode_jpeg(data, mode=mode.value, device='cuda') + img_nvjpeg = decode_jpeg(data, mode=mode, device='cuda:0') + + img_nvjpeg2 = decode_jpeg(data, mode=mode, device='cuda:1') + self.assertTrue((img_nvjpeg.cpu().float() - img_nvjpeg2.cpu().float()).abs().mean() < 1e-10) # Some difference expected between jpeg implementations self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) From 380d5b5b563a55d26a69e66bdb1c9076b521031c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 4 May 2021 10:13:55 +0000 Subject: [PATCH 18/28] WIP --- nvjpeg_bench.py | 47 +++++++++ .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 97 +++++++++++++++++++ .../csrc/io/image/cuda/decode_jpeg_cuda.h | 9 ++ torchvision/csrc/io/image/image.cpp | 3 +- 4 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 nvjpeg_bench.py diff --git a/nvjpeg_bench.py b/nvjpeg_bench.py new file mode 100644 index 00000000000..4b1453684e6 --- /dev/null +++ b/nvjpeg_bench.py @@ -0,0 +1,47 @@ +import torch +from torch.utils.benchmark import Timer +from torchvision.io.image import decode_jpeg, read_file, ImageReadMode, write_jpeg, encode_jpeg +from torchvision import transforms as T + +img_path = 'big_2kx2k.jpg' +img_path = 'test/assets/encode_jpeg/grace_hopper_517x606.jpg' +data = read_file(img_path) +batch_size = 32 +batch_data = torch.cat([data] * batch_size, dim=0) +img = decode_jpeg(data) + +def sumup(name, mean, median, throughput, fps): + print( + f"{name:<10} mean: {mean:.3f} ms, median: {median:.3f} ms, " + f"Throughput = {throughput:.3f} Megapixel / sec, " + f"{fps:.3f} fps" + ) + +print(f"{img.shape = }") +print(f"{data.shape = }") +height, width = img.shape[-2:] + +num_pixels = height * width +num_runs = 30 + +stmt = "decode_jpeg(data, device='{}')" +setup = 'from torchvision.io.image import decode_jpeg' +globals = {'data': data} + +for device in ('cpu', 'cuda'): + t = Timer(stmt=stmt.format(device), setup=setup, globals=globals).timeit(num_runs) + sumup(device, t.mean * 1000, t.median * 1000, num_pixels / 1e6 / t.median, 1 / t.median) + +# Benchmark batch +stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)" +setup = 'import torch' +globals = { + 'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size, + 'height': height, 'width': width +} +t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) + +sumup("BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) + +out = torch.ops.image.decode_jpeg_batch_cuda(batch_data, ImageReadMode.UNCHANGED.value, torch.device('cuda'), batch_size, height, width) +write_jpeg(out[0].to('cpu'), 'saved_imgs/first.jpg') \ No newline at end of file diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index ed194062dba..995f36ca788 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -174,6 +174,103 @@ torch::Tensor decode_jpeg_cuda( return tensor; } + +torch::Tensor decode_jpeg_batch_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device, + int64_t batch_size, + int64_t height, + int64_t width) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + TORCH_CHECK( + device.is_cuda(), "Expected a cuda device" + ) + + at::cuda::CUDAGuard device_guard(device); + + // Create nvJPEG handle + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + + // Create nvJPEG state (should this be persistent or not?) + nvjpegJpegState_t nvjpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + nvjpegOutputFormat_t output_format = NVJPEG_OUTPUT_RGB; + int outputComponents = 3; + int max_cpu_threads = 1; + + nvjpegStatus_t decode_batched_initialize_state = + nvjpegDecodeBatchedInitialize( + nvjpeg_handle, nvjpeg_state, batch_size, max_cpu_threads, output_format); + + TORCH_CHECK( + decode_batched_initialize_state == NVJPEG_STATUS_SUCCESS, + "nvjpegDecodeBatchedInitialize failed: ", + decode_batched_initialize_state); + + std::vector iout(batch_size); + std::vector lengths(batch_size); + + auto datap = data.data_ptr(); + std::vector batched_bitstreams; + int single_encoded_jpeg_size = data.numel() / batch_size; // TODO change this. This assume all the images in the batch are the same. + + auto tensor = torch::empty( + {batch_size, int64_t(outputComponents), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(device)); + + for (size_t img_idx = 0; img_idx < batch_size; img_idx++) { + init_nvjpegImage(iout[img_idx]); + lengths[img_idx] = single_encoded_jpeg_size; + batched_bitstreams.push_back((const unsigned char*)(datap + (img_idx * single_encoded_jpeg_size))); + for (int c = 0; c < outputComponents; c++) { + iout[img_idx].channel[c] = tensor[img_idx][c].data_ptr(); + iout[img_idx].pitch[c] = width; + } + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + nvjpegStatus_t decode_batched_status = nvjpegDecodeBatched( + nvjpeg_handle, + nvjpeg_state, + batched_bitstreams.data(), + lengths.data(), + iout.data(), + stream + ); + + // Destroy the state + nvjpegJpegStateDestroy(nvjpeg_state); + + TORCH_CHECK( + decode_batched_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecodeBatched failed: ", + decode_batched_status); + + return tensor; +} + #endif // NVJPEG_FOUND } // namespace image diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index 496b355e9b7..a2f7698d868 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -11,5 +11,14 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda( ImageReadMode mode, torch::Device device); +C10_EXPORT torch::Tensor decode_jpeg_batch_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device, + int64_t batch_size, + int64_t height, + int64_t width + ); + } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 37d64013cb2..ce9bd043557 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -22,7 +22,8 @@ static auto registry = torch::RegisterOperators() .op("image::read_file", &read_file) .op("image::write_file", &write_file) .op("image::decode_image", &decode_image) - .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) + .op("image::decode_jpeg_batch_cuda", &decode_jpeg_batch_cuda); } // namespace image } // namespace vision From ef4e8ce14cfbcce9dbb859363cc21fff107165f3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 13:07:20 +0000 Subject: [PATCH 19/28] clean up --- .circleci/config.yml.in | 10 +- nvjpeg_bench.py | 59 ++--- test/common_utils.py | 22 +- test/test_image.py | 58 +++-- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 211 +++++------------- torchvision/csrc/io/image/image.cpp | 3 +- torchvision/io/image.py | 9 +- 7 files changed, 156 insertions(+), 216 deletions(-) diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 0984b8bb961..fe811d75dbe 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -337,7 +337,7 @@ jobs: binary_macos_wheel: <<: *binary_common macos: - xcode: "12.0" + xcode: "9.4.1" steps: - checkout_merge - designate_upload_channel @@ -397,7 +397,7 @@ jobs: binary_macos_conda: <<: *binary_common macos: - xcode: "12.0" + xcode: "9.4.1" steps: - checkout_merge - designate_upload_channel @@ -648,7 +648,7 @@ jobs: command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux/scripts/install.sh - run: name: Run tests - command: docker run -e CIRCLECI -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh - run: name: Post Process command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/post_process.sh @@ -739,7 +739,7 @@ jobs: unittest_macos_cpu: <<: *binary_common macos: - xcode: "12.0" + xcode: "9.4.1" resource_class: large steps: - checkout @@ -815,7 +815,7 @@ jobs: cmake_macos_cpu: <<: *binary_common macos: - xcode: "12.0" + xcode: "9.4.1" steps: - checkout_merge - designate_upload_channel diff --git a/nvjpeg_bench.py b/nvjpeg_bench.py index 4b1453684e6..22e15d59286 100644 --- a/nvjpeg_bench.py +++ b/nvjpeg_bench.py @@ -2,46 +2,53 @@ from torch.utils.benchmark import Timer from torchvision.io.image import decode_jpeg, read_file, ImageReadMode, write_jpeg, encode_jpeg from torchvision import transforms as T +import sys -img_path = 'big_2kx2k.jpg' -img_path = 'test/assets/encode_jpeg/grace_hopper_517x606.jpg' +img_path = sys.argv[1] data = read_file(img_path) -batch_size = 32 -batch_data = torch.cat([data] * batch_size, dim=0) img = decode_jpeg(data) +write_jpeg(T.Resize((300, 300))(img), 'lol.jpg') + def sumup(name, mean, median, throughput, fps): print( - f"{name:<10} mean: {mean:.3f} ms, median: {median:.3f} ms, " - f"Throughput = {throughput:.3f} Megapixel / sec, " - f"{fps:.3f} fps" + f"{name:<20} - mean: {mean:<7.2f} ms, median: {median:<7.2f} ms, " + f"Throughput = {throughput:<7.1f} Megapixel / sec, " + f"{fps:<7.1f} fps" ) -print(f"{img.shape = }") -print(f"{data.shape = }") +print(f"Using {img_path}") +print(f"{img.shape = }, {data.shape = }") height, width = img.shape[-2:] num_pixels = height * width num_runs = 30 -stmt = "decode_jpeg(data, device='{}')" -setup = 'from torchvision.io.image import decode_jpeg' -globals = {'data': data} -for device in ('cpu', 'cuda'): - t = Timer(stmt=stmt.format(device), setup=setup, globals=globals).timeit(num_runs) - sumup(device, t.mean * 1000, t.median * 1000, num_pixels / 1e6 / t.median, 1 / t.median) +for batch_size in (1, 4, 16, 32, 64): + print(f"{batch_size = }") + + # non-batch implem + for device in ('cpu', 'cuda'): + if batch_size >= 32 and height >= 1000 and device == 'cuda': + print(f"skipping for-loop for {batch_size = } and {device = }") + continue + stmt = f"for _ in range(batch_size): decode_jpeg(data, device='{device}')" + setup = 'from torchvision.io.image import decode_jpeg' + globals = {'data': data, 'batch_size': batch_size} + + t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) + sumup(f"for-loop {device}", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) -# Benchmark batch -stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)" -setup = 'import torch' -globals = { - 'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size, - 'height': height, 'width': width -} -t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) + # # Batch implem + # stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)" + # setup = 'import torch' + # batch_data = torch.cat([data] * batch_size, dim=0) + # globals = { + # 'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size, + # 'height': height, 'width': width + # } + # t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) -sumup("BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) + # sumup(f"BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) -out = torch.ops.image.decode_jpeg_batch_cuda(batch_data, ImageReadMode.UNCHANGED.value, torch.device('cuda'), batch_size, height, width) -write_jpeg(out[0].to('cpu'), 'saved_imgs/first.jpg') \ No newline at end of file diff --git a/test/common_utils.py b/test/common_utils.py index 2c2fa73cf04..2a4aab5f65b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -24,6 +24,9 @@ PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' +IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None +IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' @contextlib.contextmanager @@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names): def cpu_and_gpu(): import pytest # noqa - # ignore CPU tests in RE as they're already covered by another contbuild - IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None - IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" - CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' + # ignore CPU tests in RE as they're already covered by another contbuild devices = [] if IN_RE_WORKER else ['cpu'] if torch.cuda.is_available(): @@ -427,3 +427,17 @@ def cpu_and_gpu(): devices.append(pytest.param('cuda', marks=cuda_marks)) return devices + + +def needs_cuda(test_func): + import pytest # noqa + + if IN_FBCODE and not IN_RE_WORKER: + # We don't want to skip in fbcode, so we just don't collect + # TODO: slightly more robust way would be to detect if we're in a sandcastle instance + # so that the test will still be collected (and skipped) in the devvms. + return pytest.mark.dont_collect(test_func) + elif torch.cuda.is_available(): + return test_func + else: + return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func) diff --git a/test/test_image.py b/test/test_image.py index b9c1d9891b2..67287e71f42 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -3,10 +3,11 @@ import os import unittest +import pytest import numpy as np import torch from PIL import Image -from common_utils import get_tmp_dir +from common_utils import get_tmp_dir, needs_cuda from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, @@ -75,24 +76,6 @@ def test_decode_jpeg(self): with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_decode_jpeg_cuda(self): - conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] - for img_path in get_images(IMAGE_ROOT, ".jpg"): - if Image.open(img_path).mode == 'CMYK': - # not supported - continue - for mode in conversion: - data = read_file(img_path) - img_ljpeg = decode_image(data, mode=mode) - img_nvjpeg = decode_jpeg(data, mode=mode, device='cuda:0') - - img_nvjpeg2 = decode_jpeg(data, mode=mode, device='cuda:1') - self.assertTrue((img_nvjpeg.cpu().float() - img_nvjpeg2.cpu().float()).abs().mean() < 1e-10) - - # Some difference expected between jpeg implementations - self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) - def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) @@ -296,5 +279,42 @@ def test_write_file_non_ascii(self): os.unlink(fpath) +@needs_cuda +@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize('img_path', get_images(IMAGE_ROOT, ".jpg")) +@pytest.mark.parametrize('scripted', (False, True)) +def test_decode_jpeg_cuda(mode, img_path, scripted): + if 'cmyk' in img_path: + pytest.xfail("Decoding a CMYK jpeg isn't supported") + tester = ImageTester() + data = read_file(img_path) + img = decode_image(data, mode=mode) + f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg + img_nvjpeg = f(data, mode=mode, device='cuda') + + # Some difference expected between jpeg implementations + tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) + + +@needs_cuda +@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) +def test_decode_jpeg_cuda_device_param(cuda_device): + """Make sure we can pass a string or a torch.device as device param""" + data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) + decode_jpeg(data, device=cuda_device) + + +@needs_cuda +def test_decode_jpeg_cuda_errors(): + data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_jpeg(data.reshape(-1, 1), device='cuda') + with pytest.raises(RuntimeError, match="input tensor must be on CPU"): + decode_jpeg(data.to('cuda'), device='cuda') + with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): + decode_jpeg(data.to(torch.float), device='cuda') + with pytest.raises(RuntimeError, match="Expected a cuda device"): + torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 995f36ca788..6b6d9d0c451 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -27,33 +27,23 @@ torch::Tensor decode_jpeg_cuda( static nvjpegHandle_t nvjpeg_handle = nullptr; -void init_nvjpegImage(nvjpegImage_t& img) { - for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { - img.channel[c] = nullptr; - img.pitch[c] = 0; - } -} - torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, ImageReadMode mode, torch::Device device) { - // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional + TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + !data.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") TORCH_CHECK( - device.is_cuda(), "Expected a cuda device" - ) + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); - at::cuda::CUDAGuard device_guard(device); + TORCH_CHECK(device.is_cuda(), "Expected a cuda device") - auto datap = data.data_ptr(); - - // Create nvJPEG handle + // Create global nvJPEG handle if (nvjpeg_handle == nullptr) { nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); @@ -63,212 +53,117 @@ torch::Tensor decode_jpeg_cuda( create_status); } - // Create nvJPEG state (should this be persistent or not?) - nvjpegJpegState_t nvjpeg_state; + // Create the jpeg state + nvjpegJpegState_t jpeg_state; nvjpegStatus_t state_status = - nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state); TORCH_CHECK( state_status == NVJPEG_STATUS_SUCCESS, "nvjpegJpegStateCreate failed: ", state_status); + auto datap = data.data_ptr(); + // Get the image information - int components; + int num_channels; nvjpegChromaSubsampling_t subsampling; int widths[NVJPEG_MAX_COMPONENT]; int heights[NVJPEG_MAX_COMPONENT]; - nvjpegStatus_t info_status = nvjpegGetImageInfo( nvjpeg_handle, datap, data.numel(), - &components, + &num_channels, &subsampling, widths, heights); if (info_status != NVJPEG_STATUS_SUCCESS) { - nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); } if (subsampling == NVJPEG_CSS_UNKNOWN) { - nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); } int width = widths[0]; int height = heights[0]; - nvjpegOutputFormat_t outputFormat; - int outputComponents; + nvjpegOutputFormat_t ouput_format; + int num_channels_output; switch (mode) { case IMAGE_READ_MODE_UNCHANGED: - if (components == 1) { - outputFormat = NVJPEG_OUTPUT_Y; - outputComponents = 1; - } else if (components == 3) { - outputFormat = NVJPEG_OUTPUT_RGB; - outputComponents = 3; - } else { - nvjpegJpegStateDestroy(nvjpeg_state); + num_channels_output = num_channels; + // For some reason, setting out_format to NVJPEG_OUTPUT_UNCHANGED will not + // properly decode RGB images (it's fine for grayscale), so we set + // output_format manually here + if (num_channels == 1) { + ouput_format = NVJPEG_OUTPUT_Y; + } else if (num_channels == 3) { + ouput_format = NVJPEG_OUTPUT_RGB; + } + else { + nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK( - false, "The provided mode is not supported for JPEG files on GPU"); + false, "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); } break; case IMAGE_READ_MODE_GRAY: - // This will do 0.299*R + 0.587*G + 0.114*B like opencv - // TODO check if that is the same as libjpeg - outputFormat = NVJPEG_OUTPUT_Y; - outputComponents = 1; + ouput_format = NVJPEG_OUTPUT_Y; + num_channels_output = 1; break; case IMAGE_READ_MODE_RGB: - outputFormat = NVJPEG_OUTPUT_RGB; - outputComponents = 3; + ouput_format = NVJPEG_OUTPUT_RGB; + num_channels_output = 3; break; default: - // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() - nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK( - false, "The provided mode is not supported for JPEG files on GPU"); + false, "The provided mode is not supported for JPEG decoding on GPU"); } + auto out_tensor = torch::empty( + {int64_t(num_channels_output), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(device)); + // nvjpegImage_t is a struct with // - an array of pointers to each channel // - the pitch for each channel // which must be filled in manually - nvjpegImage_t outImage; - init_nvjpegImage(outImage); + nvjpegImage_t out_image; - // TODO device selection - auto tensor = torch::empty( - {int64_t(outputComponents), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(device)); - - for (int c = 0; c < outputComponents; c++) { - outImage.channel[c] = tensor[c].data_ptr(); - outImage.pitch[c] = width; + for (int c = 0; c < num_channels_output; c++) { + out_image.channel[c] = out_tensor[c].data_ptr(); + out_image.pitch[c] = width; + } + for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) { + out_image.channel[c] = nullptr; + out_image.pitch[c] = 0; } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, - nvjpeg_state, + jpeg_state, datap, data.numel(), - outputFormat, - &outImage, + ouput_format, + &out_image, stream); - // Destroy the state - nvjpegJpegStateDestroy(nvjpeg_state); + nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK( decode_status == NVJPEG_STATUS_SUCCESS, "nvjpegDecode failed: ", decode_status); - return tensor; -} - - -torch::Tensor decode_jpeg_batch_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device, - int64_t batch_size, - int64_t height, - int64_t width) { - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - TORCH_CHECK( - device.is_cuda(), "Expected a cuda device" - ) - - at::cuda::CUDAGuard device_guard(device); - - // Create nvJPEG handle - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } - - // Create nvJPEG state (should this be persistent or not?) - nvjpegJpegState_t nvjpeg_state; - nvjpegStatus_t state_status = - nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); - - TORCH_CHECK( - state_status == NVJPEG_STATUS_SUCCESS, - "nvjpegJpegStateCreate failed: ", - state_status); - - nvjpegOutputFormat_t output_format = NVJPEG_OUTPUT_RGB; - int outputComponents = 3; - int max_cpu_threads = 1; - - nvjpegStatus_t decode_batched_initialize_state = - nvjpegDecodeBatchedInitialize( - nvjpeg_handle, nvjpeg_state, batch_size, max_cpu_threads, output_format); - - TORCH_CHECK( - decode_batched_initialize_state == NVJPEG_STATUS_SUCCESS, - "nvjpegDecodeBatchedInitialize failed: ", - decode_batched_initialize_state); - - std::vector iout(batch_size); - std::vector lengths(batch_size); - - auto datap = data.data_ptr(); - std::vector batched_bitstreams; - int single_encoded_jpeg_size = data.numel() / batch_size; // TODO change this. This assume all the images in the batch are the same. - - auto tensor = torch::empty( - {batch_size, int64_t(outputComponents), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(device)); - - for (size_t img_idx = 0; img_idx < batch_size; img_idx++) { - init_nvjpegImage(iout[img_idx]); - lengths[img_idx] = single_encoded_jpeg_size; - batched_bitstreams.push_back((const unsigned char*)(datap + (img_idx * single_encoded_jpeg_size))); - for (int c = 0; c < outputComponents; c++) { - iout[img_idx].channel[c] = tensor[img_idx][c].data_ptr(); - iout[img_idx].pitch[c] = width; - } - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - nvjpegStatus_t decode_batched_status = nvjpegDecodeBatched( - nvjpeg_handle, - nvjpeg_state, - batched_bitstreams.data(), - lengths.data(), - iout.data(), - stream - ); - - // Destroy the state - nvjpegJpegStateDestroy(nvjpeg_state); - - TORCH_CHECK( - decode_batched_status == NVJPEG_STATUS_SUCCESS, - "nvjpegDecodeBatched failed: ", - decode_batched_status); - - return tensor; + return out_tensor; } #endif // NVJPEG_FOUND diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index ce9bd043557..37d64013cb2 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -22,8 +22,7 @@ static auto registry = torch::RegisterOperators() .op("image::read_file", &read_file) .op("image::write_file", &write_file) .op("image::decode_image", &decode_image) - .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) - .op("image::decode_jpeg_batch_cuda", &decode_jpeg_batch_cuda); + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); } // namespace image } // namespace vision diff --git a/torchvision/io/image.py b/torchvision/io/image.py index c9cb946187a..627eabe166a 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -149,7 +149,7 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, - device: torch.device = 'cpu') -> torch.Tensor: + device: str = 'cpu') -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -157,11 +157,16 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG Args: input (Tensor[1]): a one dimensional uint8 tensor containing - the raw bytes of the JPEG image. + the raw bytes of the JPEG image. This tensor must be on CPU, + regardless of the ``device`` parameter. mode (ImageReadMode): the read mode used for optionally converting the image. Default: `ImageReadMode.UNCHANGED`. See `ImageReadMode` class for more information on various available modes. + device (str or torch.device): The device on which the decoded image will + be stored. If a cuda device is specified, the image will be decoded + with `nvjpeg `_. This is only + supported for CUDA version >= TODO Returns: output (Tensor[image_channels, image_height, image_width]) From 8de3a4a1ee9218bec68cf670bc73a40eeae07de6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 13:08:06 +0000 Subject: [PATCH 20/28] remove bench --- nvjpeg_bench.py | 54 ------------------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 nvjpeg_bench.py diff --git a/nvjpeg_bench.py b/nvjpeg_bench.py deleted file mode 100644 index 22e15d59286..00000000000 --- a/nvjpeg_bench.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from torch.utils.benchmark import Timer -from torchvision.io.image import decode_jpeg, read_file, ImageReadMode, write_jpeg, encode_jpeg -from torchvision import transforms as T -import sys - -img_path = sys.argv[1] -data = read_file(img_path) -img = decode_jpeg(data) -write_jpeg(T.Resize((300, 300))(img), 'lol.jpg') - - -def sumup(name, mean, median, throughput, fps): - print( - f"{name:<20} - mean: {mean:<7.2f} ms, median: {median:<7.2f} ms, " - f"Throughput = {throughput:<7.1f} Megapixel / sec, " - f"{fps:<7.1f} fps" - ) - -print(f"Using {img_path}") -print(f"{img.shape = }, {data.shape = }") -height, width = img.shape[-2:] - -num_pixels = height * width -num_runs = 30 - - -for batch_size in (1, 4, 16, 32, 64): - print(f"{batch_size = }") - - # non-batch implem - for device in ('cpu', 'cuda'): - if batch_size >= 32 and height >= 1000 and device == 'cuda': - print(f"skipping for-loop for {batch_size = } and {device = }") - continue - stmt = f"for _ in range(batch_size): decode_jpeg(data, device='{device}')" - setup = 'from torchvision.io.image import decode_jpeg' - globals = {'data': data, 'batch_size': batch_size} - - t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) - sumup(f"for-loop {device}", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) - - # # Batch implem - # stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)" - # setup = 'import torch' - # batch_data = torch.cat([data] * batch_size, dim=0) - # globals = { - # 'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size, - # 'height': height, 'width': width - # } - # t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs) - - # sumup(f"BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median) - From be5526e8498198700a8adb1019cfb919e7bd8fb5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 13:11:47 +0000 Subject: [PATCH 21/28] cleanup --- .circleci/config.yml.in | 10 +++++----- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h | 9 --------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index fe811d75dbe..0984b8bb961 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -337,7 +337,7 @@ jobs: binary_macos_wheel: <<: *binary_common macos: - xcode: "9.4.1" + xcode: "12.0" steps: - checkout_merge - designate_upload_channel @@ -397,7 +397,7 @@ jobs: binary_macos_conda: <<: *binary_common macos: - xcode: "9.4.1" + xcode: "12.0" steps: - checkout_merge - designate_upload_channel @@ -648,7 +648,7 @@ jobs: command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux/scripts/install.sh - run: name: Run tests - command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + command: docker run -e CIRCLECI -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh - run: name: Post Process command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/post_process.sh @@ -739,7 +739,7 @@ jobs: unittest_macos_cpu: <<: *binary_common macos: - xcode: "9.4.1" + xcode: "12.0" resource_class: large steps: - checkout @@ -815,7 +815,7 @@ jobs: cmake_macos_cpu: <<: *binary_common macos: - xcode: "9.4.1" + xcode: "12.0" steps: - checkout_merge - designate_upload_channel diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h index a2f7698d868..496b355e9b7 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -11,14 +11,5 @@ C10_EXPORT torch::Tensor decode_jpeg_cuda( ImageReadMode mode, torch::Device device); -C10_EXPORT torch::Tensor decode_jpeg_batch_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device, - int64_t batch_size, - int64_t height, - int64_t width - ); - } // namespace image } // namespace vision From 7b5c09f1a683c9e6c123642245646eee9610329e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 13:15:03 +0000 Subject: [PATCH 22/28] linting --- test/test_image.py | 1 + torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 67287e71f42..11c8f3d7a03 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -316,5 +316,6 @@ def test_decode_jpeg_cuda_errors(): with pytest.raises(RuntimeError, match="Expected a cuda device"): torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 6b6d9d0c451..8962baf33ee 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -105,11 +105,11 @@ torch::Tensor decode_jpeg_cuda( ouput_format = NVJPEG_OUTPUT_Y; } else if (num_channels == 3) { ouput_format = NVJPEG_OUTPUT_RGB; - } - else { + } else { nvjpegJpegStateDestroy(jpeg_state); TORCH_CHECK( - false, "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); + false, + "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); } break; case IMAGE_READ_MODE_GRAY: From 67aad819a2ef20db9262d8b14a0af6cdabe8af13 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 13:49:14 +0000 Subject: [PATCH 23/28] fix setup.py --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 13f2eb7fce5..766e01adbc0 100644 --- a/setup.py +++ b/setup.py @@ -318,7 +318,6 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) - nvjpeg_found = True print('NVJPEG found: {0}'.format(nvjpeg_found)) print(f"{CUDA_HOME}") From 7be3c110b67b47049c44d2562d58c9d3e1b38972 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 14:38:19 +0000 Subject: [PATCH 24/28] rocm --- setup.py | 9 ++++++--- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 4 ++-- torchvision/io/image.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 766e01adbc0..4cc3d0698a4 100644 --- a/setup.py +++ b/setup.py @@ -316,11 +316,14 @@ def get_extensions(): image_include += [jpeg_include] # Locating nvjpeg - # Should be included in CUDA_HOME - nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI + nvjpeg_found = ( + extension is CUDAExtension and + CUDA_HOME is not None and + os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + ) print('NVJPEG found: {0}'.format(nvjpeg_found)) - print(f"{CUDA_HOME}") image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] if nvjpeg_found: print('Building torchvision with NVJPEG image support') diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 8962baf33ee..5ea5a508d4b 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -98,8 +98,8 @@ torch::Tensor decode_jpeg_cuda( switch (mode) { case IMAGE_READ_MODE_UNCHANGED: num_channels_output = num_channels; - // For some reason, setting out_format to NVJPEG_OUTPUT_UNCHANGED will not - // properly decode RGB images (it's fine for grayscale), so we set + // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will + // not properly decode RGB images (it's fine for grayscale), so we set // output_format manually here if (num_channels == 1) { ouput_format = NVJPEG_OUTPUT_Y; diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 627eabe166a..399a6f0d1ac 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -166,7 +166,7 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG device (str or torch.device): The device on which the decoded image will be stored. If a cuda device is specified, the image will be decoded with `nvjpeg `_. This is only - supported for CUDA version >= TODO + supported for CUDA version >= 10.1 Returns: output (Tensor[image_channels, image_height, image_width]) From fbb4511989293fd2f4e91ed0078bc579c6864a9b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 16:15:38 +0000 Subject: [PATCH 25/28] use proper device for stream --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 5ea5a508d4b..fc440e48511 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -4,7 +4,6 @@ #if NVJPEG_FOUND #include -#include #include #endif @@ -145,7 +144,7 @@ torch::Tensor decode_jpeg_cuda( out_image.pitch[c] = 0; } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, From 45f451582b697be4bcfce80526e63a5ed3ed490c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 7 May 2021 16:25:14 +0000 Subject: [PATCH 26/28] put back device guard --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index fc440e48511..d649c9aa57d 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -4,6 +4,7 @@ #if NVJPEG_FOUND #include +#include #include #endif @@ -42,6 +43,8 @@ torch::Tensor decode_jpeg_cuda( TORCH_CHECK(device.is_cuda(), "Expected a cuda device") + at::cuda::CUDAGuard device_guard(device); + // Create global nvJPEG handle if (nvjpeg_handle == nullptr) { nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); From ba63e1ef13f076d3e176c76d58faf58c3190f9ec Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 May 2021 11:17:24 +0000 Subject: [PATCH 27/28] Add unnamed namespace and use call_once to safely create handle --- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index d649c9aa57d..01b45d2e1b6 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -25,7 +25,9 @@ torch::Tensor decode_jpeg_cuda( #else +namespace { static nvjpegHandle_t nvjpeg_handle = nullptr; +} torch::Tensor decode_jpeg_cuda( const torch::Tensor& data, @@ -46,14 +48,23 @@ torch::Tensor decode_jpeg_cuda( at::cuda::CUDAGuard device_guard(device); // Create global nvJPEG handle - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } + static std::once_flag nvjpeg_handle_creation_flag; + std::call_once(nvjpeg_handle_creation_flag, []() { + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + if (create_status != NVJPEG_STATUS_SUCCESS) { + // Reset handle so that one can still call the function again in the + // same process if there was a failure + free(nvjpeg_handle); + nvjpeg_handle = nullptr; + } + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + }); // Create the jpeg state nvjpegJpegState_t jpeg_state; From f1676022fdc48e4e2406ab441a53a6a708cb3e7d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 May 2021 11:52:31 +0000 Subject: [PATCH 28/28] once_flag shouldn't be static --- torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp index 01b45d2e1b6..68f63ced427 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -48,7 +48,7 @@ torch::Tensor decode_jpeg_cuda( at::cuda::CUDAGuard device_guard(device); // Create global nvJPEG handle - static std::once_flag nvjpeg_handle_creation_flag; + std::once_flag nvjpeg_handle_creation_flag; std::call_once(nvjpeg_handle_creation_flag, []() { if (nvjpeg_handle == nullptr) { nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);