diff --git a/CMakeLists.txt b/CMakeLists.txt index 547ab7ddd2b..2dec2de88e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers) set(TVCPP torchvision/csrc) list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops - ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu) + ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda) if(WITH_CUDA) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) endif() diff --git a/setup.py b/setup.py index 86007622715..4cc3d0698a4 100644 --- a/setup.py +++ b/setup.py @@ -315,8 +315,23 @@ def get_extensions(): image_library += [jpeg_lib] image_include += [jpeg_include] + # Locating nvjpeg + # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI + nvjpeg_found = ( + extension is CUDAExtension and + CUDA_HOME is not None and + os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + ) + + print('NVJPEG found: {0}'.format(nvjpeg_found)) + image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] + if nvjpeg_found: + print('Building torchvision with NVJPEG image support') + image_link_flags.append('nvjpeg') + image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + + glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) if png_found or jpeg_found: ext_modules.append(extension( diff --git a/test/common_utils.py b/test/common_utils.py index 2c2fa73cf04..2a4aab5f65b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -24,6 +24,9 @@ PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' +IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None +IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' @contextlib.contextmanager @@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names): def cpu_and_gpu(): import pytest # noqa - # ignore CPU tests in RE as they're already covered by another contbuild - IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None - IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" - CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' + # ignore CPU tests in RE as they're already covered by another contbuild devices = [] if IN_RE_WORKER else ['cpu'] if torch.cuda.is_available(): @@ -427,3 +427,17 @@ def cpu_and_gpu(): devices.append(pytest.param('cuda', marks=cuda_marks)) return devices + + +def needs_cuda(test_func): + import pytest # noqa + + if IN_FBCODE and not IN_RE_WORKER: + # We don't want to skip in fbcode, so we just don't collect + # TODO: slightly more robust way would be to detect if we're in a sandcastle instance + # so that the test will still be collected (and skipped) in the devvms. + return pytest.mark.dont_collect(test_func) + elif torch.cuda.is_available(): + return test_func + else: + return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func) diff --git a/test/test_image.py b/test/test_image.py index ebc9a221f6d..11c8f3d7a03 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -3,10 +3,11 @@ import os import unittest +import pytest import numpy as np import torch from PIL import Image -from common_utils import get_tmp_dir +from common_utils import get_tmp_dir, needs_cuda from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, @@ -278,5 +279,43 @@ def test_write_file_non_ascii(self): os.unlink(fpath) +@needs_cuda +@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize('img_path', get_images(IMAGE_ROOT, ".jpg")) +@pytest.mark.parametrize('scripted', (False, True)) +def test_decode_jpeg_cuda(mode, img_path, scripted): + if 'cmyk' in img_path: + pytest.xfail("Decoding a CMYK jpeg isn't supported") + tester = ImageTester() + data = read_file(img_path) + img = decode_image(data, mode=mode) + f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg + img_nvjpeg = f(data, mode=mode, device='cuda') + + # Some difference expected between jpeg implementations + tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2) + + +@needs_cuda +@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) +def test_decode_jpeg_cuda_device_param(cuda_device): + """Make sure we can pass a string or a torch.device as device param""" + data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) + decode_jpeg(data, device=cuda_device) + + +@needs_cuda +def test_decode_jpeg_cuda_errors(): + data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_jpeg(data.reshape(-1, 1), device='cuda') + with pytest.raises(RuntimeError, match="input tensor must be on CPU"): + decode_jpeg(data.to('cuda'), device='cuda') + with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): + decode_jpeg(data.to(torch.float), device='cuda') + with pytest.raises(RuntimeError, match="Expected a cuda device"): + torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') + + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp new file mode 100644 index 00000000000..68f63ced427 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -0,0 +1,185 @@ +#include "decode_jpeg_cuda.h" + +#include + +#if NVJPEG_FOUND +#include +#include +#include +#endif + +#include + +namespace vision { +namespace image { + +#if !NVJPEG_FOUND + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + TORCH_CHECK( + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +namespace { +static nvjpegHandle_t nvjpeg_handle = nullptr; +} + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !data.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + TORCH_CHECK(device.is_cuda(), "Expected a cuda device") + + at::cuda::CUDAGuard device_guard(device); + + // Create global nvJPEG handle + std::once_flag nvjpeg_handle_creation_flag; + std::call_once(nvjpeg_handle_creation_flag, []() { + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + if (create_status != NVJPEG_STATUS_SUCCESS) { + // Reset handle so that one can still call the function again in the + // same process if there was a failure + free(nvjpeg_handle); + nvjpeg_handle = nullptr; + } + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + }); + + // Create the jpeg state + nvjpegJpegState_t jpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + auto datap = data.data_ptr(); + + // Get the image information + int num_channels; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + nvjpegStatus_t info_status = nvjpegGetImageInfo( + nvjpeg_handle, + datap, + data.numel(), + &num_channels, + &subsampling, + widths, + heights); + + if (info_status != NVJPEG_STATUS_SUCCESS) { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); + } + + if (subsampling == NVJPEG_CSS_UNKNOWN) { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); + } + + int width = widths[0]; + int height = heights[0]; + + nvjpegOutputFormat_t ouput_format; + int num_channels_output; + + switch (mode) { + case IMAGE_READ_MODE_UNCHANGED: + num_channels_output = num_channels; + // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will + // not properly decode RGB images (it's fine for grayscale), so we set + // output_format manually here + if (num_channels == 1) { + ouput_format = NVJPEG_OUTPUT_Y; + } else if (num_channels == 3) { + ouput_format = NVJPEG_OUTPUT_RGB; + } else { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK( + false, + "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); + } + break; + case IMAGE_READ_MODE_GRAY: + ouput_format = NVJPEG_OUTPUT_Y; + num_channels_output = 1; + break; + case IMAGE_READ_MODE_RGB: + ouput_format = NVJPEG_OUTPUT_RGB; + num_channels_output = 3; + break; + default: + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + auto out_tensor = torch::empty( + {int64_t(num_channels_output), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(device)); + + // nvjpegImage_t is a struct with + // - an array of pointers to each channel + // - the pitch for each channel + // which must be filled in manually + nvjpegImage_t out_image; + + for (int c = 0; c < num_channels_output; c++) { + out_image.channel[c] = out_tensor[c].data_ptr(); + out_image.pitch[c] = width; + } + for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) { + out_image.channel[c] = nullptr; + out_image.pitch[c] = 0; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); + + nvjpegStatus_t decode_status = nvjpegDecode( + nvjpeg_handle, + jpeg_state, + datap, + data.numel(), + ouput_format, + &out_image, + stream); + + nvjpegJpegStateDestroy(jpeg_state); + + TORCH_CHECK( + decode_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecode failed: ", + decode_status); + + return out_tensor; +} + +#endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h new file mode 100644 index 00000000000..496b355e9b7 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 51cf9c7ce3e..37d64013cb2 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators() .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) - .op("image::decode_image", &decode_image); + .op("image::decode_image", &decode_image) + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index fb09d6d71b8..05bac44c77d 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -6,3 +6,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" +#include "cuda/decode_jpeg_cuda.h" diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 8310c1eb273..399a6f0d1ac 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -148,7 +148,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, + device: str = 'cpu') -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -156,16 +157,25 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG Args: input (Tensor[1]): a one dimensional uint8 tensor containing - the raw bytes of the JPEG image. + the raw bytes of the JPEG image. This tensor must be on CPU, + regardless of the ``device`` parameter. mode (ImageReadMode): the read mode used for optionally converting the image. Default: `ImageReadMode.UNCHANGED`. See `ImageReadMode` class for more information on various available modes. + device (str or torch.device): The device on which the decoded image will + be stored. If a cuda device is specified, the image will be decoded + with `nvjpeg `_. This is only + supported for CUDA version >= 10.1 Returns: output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_jpeg(input, mode.value) + device = torch.device(device) + if device.type == 'cuda': + output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) + else: + output = torch.ops.image.decode_jpeg(input, mode.value) return output