
Support for decoding jpegs on GPU with nvjpeg #3792


Merged
merged 37 commits into master from nvjpeg_bis on May 11, 2021

Commits
f878b36
Initial stab at nvJPEG support #2742
jamt9000 Oct 10, 2020
5eb6d73
Init nvjpeg once on first call
jamt9000 Oct 11, 2020
afd4a2e
Merge remote-tracking branch 'origin' into nvjpeg
jamt9000 Jan 22, 2021
8abe4a5
Add io/image/cuda search path
jamt9000 Jan 22, 2021
8ae0751
Update test
jamt9000 Jan 22, 2021
9a2510f
Building CUDA ext should mean nvjpeg exists
jamt9000 Jan 23, 2021
ac3330b
Check if nvjpeg.h is actually there
jamt9000 Jan 23, 2021
e798157
Add ImageReadMode support for nvjpeg
jamt9000 Jan 23, 2021
a07f53a
lint
jamt9000 Jan 23, 2021
e485656
Call nvjpegJpegStateDestroy when bailing out
jamt9000 Jan 23, 2021
1c1e471
Use at::cuda::getCurrentCUDAStream()
jamt9000 Jan 23, 2021
3e7486e
Merge branch 'master' into nvjpeg
jamt9000 Feb 7, 2021
5bc5e21
Changes to match #3312
jamt9000 Feb 7, 2021
4d4cd45
Move includes outside namespace
jamt9000 Feb 7, 2021
dd3e445
Lint
jamt9000 Feb 7, 2021
f560eeb
Guard includes so cpu builds work
jamt9000 Feb 7, 2021
ab90893
Add device argument
jamt9000 Feb 7, 2021
0992fa4
Merge branch 'master' into nvjpeg_bis
NicolasHug Apr 27, 2021
540eaa4
WIP
NicolasHug Apr 27, 2021
785ba98
Merge branch 'master' of github.com:pytorch/vision into nvjpeg_bis
NicolasHug Apr 28, 2021
7b6eadf
WIP
NicolasHug Apr 28, 2021
380d5b5
WIP
NicolasHug May 4, 2021
3c73ac9
Merge branch 'master' of github.com:pytorch/vision into nvjpeg_bis
NicolasHug May 7, 2021
ef4e8ce
clean up
NicolasHug May 7, 2021
8de3a4a
remove bench
NicolasHug May 7, 2021
be5526e
cleanup
NicolasHug May 7, 2021
7b5c09f
linting
NicolasHug May 7, 2021
67aad81
fix setup.py
NicolasHug May 7, 2021
7be3c11
rocm
NicolasHug May 7, 2021
422d2f3
Merge branch 'master' into nvjpeg_bis
NicolasHug May 7, 2021
fbb4511
use proper device for stream
NicolasHug May 7, 2021
45f4515
put back device guard
NicolasHug May 7, 2021
7d9c55e
Merge branch 'master' of github.com:pytorch/vision into nvjpeg_bis
NicolasHug May 7, 2021
83124df
Merge branch 'nvjpeg_bis' of github.com:NicolasHug/vision into nvjpeg…
NicolasHug May 7, 2021
1c3c4bd
Merge branch 'master' of github.com:pytorch/vision into nvjpeg_bis
NicolasHug May 11, 2021
ba63e1e
Add unnamed namespace and use call_once to safely create handle
NicolasHug May 11, 2021
f167602
once_flag shouldn't be static
NicolasHug May 11, 2021
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
- ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
+ ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
17 changes: 16 additions & 1 deletion setup.py
@@ -315,8 +315,23 @@ def get_extensions():
image_library += [jpeg_lib]
image_include += [jpeg_include]

# Locating nvjpeg
# Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
nvjpeg_found = (
extension is CUDAExtension and
CUDA_HOME is not None and
os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
)

print('NVJPEG found: {0}'.format(nvjpeg_found))
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
if nvjpeg_found:
print('Building torchvision with NVJPEG image support')
image_link_flags.append('nvjpeg')

image_path = os.path.join(extensions_dir, 'io', 'image')
- image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+ image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+              + glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))

if png_found or jpeg_found:
ext_modules.append(extension(
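
As an aside, here is a minimal sketch (not part of the diff) for checking whether the CUDA toolkit visible to PyTorch ships the nvjpeg header that the detection logic above looks for, assuming the standard CUDA_HOME layout:

import os

from torch.utils.cpp_extension import CUDA_HOME

# CUDA_HOME is None when no CUDA toolkit is detected by PyTorch.
nvjpeg_header = os.path.join(CUDA_HOME, 'include', 'nvjpeg.h') if CUDA_HOME else None
print('CUDA_HOME:', CUDA_HOME)
print('nvjpeg.h present:', bool(nvjpeg_header) and os.path.exists(nvjpeg_header))
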
22 changes: 18 additions & 4 deletions test/common_utils.py
@@ -24,6 +24,9 @@
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'


@contextlib.contextmanager
@@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):

def cpu_and_gpu():
import pytest # noqa
- # ignore CPU tests in RE as they're already covered by another contbuild
- IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
- IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
- CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'

+ # ignore CPU tests in RE as they're already covered by another contbuild
devices = [] if IN_RE_WORKER else ['cpu']

if torch.cuda.is_available():
@@ -427,3 +427,17 @@ def cpu_and_gpu():
devices.append(pytest.param('cuda', marks=cuda_marks))

return devices


def needs_cuda(test_func):
import pytest # noqa

if IN_FBCODE and not IN_RE_WORKER:
# We don't want to skip in fbcode, so we just don't collect
# TODO: slightly more robust way would be to detect if we're in a sandcastle instance
# so that the test will still be collected (and skipped) in the devvms.
return pytest.mark.dont_collect(test_func)
elif torch.cuda.is_available():
return test_func
else:
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)
41 changes: 40 additions & 1 deletion test/test_image.py
@@ -3,10 +3,11 @@
import os
import unittest

import pytest
import numpy as np
import torch
from PIL import Image
- from common_utils import get_tmp_dir
+ from common_utils import get_tmp_dir, needs_cuda

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -278,5 +279,43 @@ def test_write_file_non_ascii(self):
os.unlink(fpath)


@needs_cuda
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('img_path', get_images(IMAGE_ROOT, ".jpg"))
@pytest.mark.parametrize('scripted', (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported")
tester = ImageTester()
data = read_file(img_path)
img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda')

# Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2)

Member:
Do we want to consider the mean difference or the max difference here? What would be the minimum threshold for the tests to pass if we used the max difference?

Member Author (@NicolasHug, May 11, 2021):
The max error can unfortunately be quite high: the minimum threshold for all tests to pass seems to be 52; below that, some tests start failing.
In test_decode_jpeg, we also test for the MAE (with the same threshold=2).

Member:
Hmm, it looks suspicious that the decoding gives such large differences. Something to keep an eye on.
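
To make the mean-vs-max discussion above concrete, here is a small illustrative sketch (not from the PR) contrasting the mean absolute error asserted in the test with the maximum absolute difference, using synthetic stand-ins for the libjpeg and nvjpeg outputs:

import torch

# Synthetic stand-ins for the CPU (libjpeg) and GPU (nvjpeg) decodes of the same image.
img_cpu = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
noise = torch.randint(-2, 3, img_cpu.shape)  # small per-pixel disagreement between decoders
img_gpu = (img_cpu.int() + noise).clamp(0, 255).to(torch.uint8)

diff = (img_cpu.float() - img_gpu.float()).abs()
mae = diff.mean().item()     # what the test asserts on, with threshold 2
max_err = diff.max().item()  # for the real decoders this reportedly needs a threshold around 52

print(f'MAE={mae:.2f}, max abs diff={max_err:.0f}')
assert mae < 2  # mirrors the assertion in test_decode_jpeg_cuda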



@needs_cuda
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda')))
def test_decode_jpeg_cuda_device_param(cuda_device):
"""Make sure we can pass a string or a torch.device as device param"""
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(data.reshape(-1, 1), device='cuda')
with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
decode_jpeg(data.to('cuda'), device='cuda')
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(data.to(torch.float), device='cuda')
with pytest.raises(RuntimeError, match="Expected a cuda device"):
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')


if __name__ == '__main__':
unittest.main()
185 changes: 185 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
@@ -0,0 +1,185 @@
#include "decode_jpeg_cuda.h"

#include <ATen/ATen.h>

#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif

#include <string>

namespace vision {
namespace image {

#if !NVJPEG_FOUND

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
TORCH_CHECK(
false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}

#else

namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;

Contributor:
Consider adding this to an anonymous namespace, as it looks like an internal detail of the implementation. Also just checking whether this should be released at any point to avoid memory leaks.

Member Author:
> Also just checking whether this should be released at any point to avoid memory leaks

Yeah this is a good point. Creating / allocating it at each call has some severe overhead so it makes sense to declare this as a global variable (related discussion: #2786 (comment)). But this means we never really know when to release it, and the memory will only be freed when the process is killed.

Member:
For reference: this is thread-safe

Contributor:
Good that it's thread-safe, but it's still unclear to me whether we have to find a way to release it or whether we can leave it be. We don't have such an idiom in TorchVision, but I wonder if there are examples of resources in PyTorch core that are never released.

@ezyang how do you handle situations like this on core?

Contributor:
We leak a bunch of global statically initialized objects, as it is the easiest way to avoid destructor ordering problems. If nvjpeg is a very simple library, you might be able to write an RAII class for this object and have it destruct properly on shutdown.

Precedent in PyTorch is the cuDNN convolution cache; look at that for some inspiration.

Contributor:
@ezyang Thanks for the insights.

@NicolasHug Given Ed's input, I think we can mark this as resolved if you want.

}

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");

TORCH_CHECK(
!data.is_cuda(),
"The input tensor must be on CPU when decoding with nvjpeg")

TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

TORCH_CHECK(device.is_cuda(), "Expected a cuda device")

at::cuda::CUDAGuard device_guard(device);

// Create global nvJPEG handle
std::once_flag nvjpeg_handle_creation_flag;
std::call_once(nvjpeg_handle_creation_flag, []() {
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

if (create_status != NVJPEG_STATUS_SUCCESS) {
// Reset handle so that one can still call the function again in the
// same process if there was a failure
free(nvjpeg_handle);

Contributor:
Since it's an opaque handle I think the use of free() may not be correct unless it's documented as being supported.

(I would hope that it simply leaves the handle as null if initialisation fails, although I don't see that in the docs - here it just reinits the handle without any freeing when hw backend fails though)

Member Author:
Hmm, that's a good point. What would you recommend instead of free()?

Before pushing this I did a quick test by inserting

       nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
       create_status = NVJPEG_STATUS_NOT_INITIALIZED;   // <- this

and all the tests were failing gracefully with E RuntimeError: nvjpegCreateSimple failed: 1. Since I was running the tests with pytest test/test_image.py -k cuda they were all in the same process and pytest was just catching the RuntimeErrors, so I assumed it was OK.

Contributor: (pointed to nvjpegDestroy)

Member Author:
I didn't use it because I was wondering whether nvjpegDestroy would properly work with a bad handle?

Contributor:
I guess if you assume anything can happen if initialisation fails then it might end up being an arbitrary value like 0xDEADBEEF and all you can do is reset it to null.

nvjpeg_handle = nullptr;
}
TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}
});

// Create the jpeg state
nvjpegJpegState_t jpeg_state;
nvjpegStatus_t state_status =
nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);

TORCH_CHECK(
state_status == NVJPEG_STATUS_SUCCESS,
"nvjpegJpegStateCreate failed: ",
state_status);

auto datap = data.data_ptr<uint8_t>();

// Get the image information
int num_channels;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
nvjpegStatus_t info_status = nvjpegGetImageInfo(
nvjpeg_handle,
datap,
data.numel(),
&num_channels,
&subsampling,
widths,
heights);

if (info_status != NVJPEG_STATUS_SUCCESS) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
}

if (subsampling == NVJPEG_CSS_UNKNOWN) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
}

int width = widths[0];
int height = heights[0];

nvjpegOutputFormat_t ouput_format;
int num_channels_output;

switch (mode) {
case IMAGE_READ_MODE_UNCHANGED:
num_channels_output = num_channels;
// For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
// not properly decode RGB images (it's fine for grayscale), so we set
// output_format manually here
if (num_channels == 1) {
ouput_format = NVJPEG_OUTPUT_Y;
} else if (num_channels == 3) {
ouput_format = NVJPEG_OUTPUT_RGB;
} else {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false,
"When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
}
break;
case IMAGE_READ_MODE_GRAY:
ouput_format = NVJPEG_OUTPUT_Y;
num_channels_output = 1;
break;
case IMAGE_READ_MODE_RGB:
ouput_format = NVJPEG_OUTPUT_RGB;
num_channels_output = 3;
break;
default:
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG decoding on GPU");
}

auto out_tensor = torch::empty(
{int64_t(num_channels_output), int64_t(height), int64_t(width)},
torch::dtype(torch::kU8).device(device));

// nvjpegImage_t is a struct with
// - an array of pointers to each channel
// - the pitch for each channel
// which must be filled in manually
nvjpegImage_t out_image;

for (int c = 0; c < num_channels_output; c++) {
out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();

Member:
nit: this is fine for now, but it adds extra overhead, as we need to create a full Tensor just to extract the data pointer (and Tensor construction is heavy). Given that we generally only have 3 channels, that shouldn't be much of an issue, but it's still good to keep in mind.

An alternative would be to use the raw data_ptr directly, adding the correct offsets, like:

uint8_t * out_tensor_ptr = out_tensor.data_ptr<uint8_t>();
...
out_image.channel[c] = out_tensor_ptr + c * height * width;

Also, it's interesting that nvjpeg accepts decoding images in both CHW and HWC formats -- I wonder if there are any performance implications of decoding in CHW?

out_image.pitch[c] = width;
}
for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
out_image.channel[c] = nullptr;
out_image.pitch[c] = 0;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());

nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
jpeg_state,
datap,
data.numel(),
ouput_format,
&out_image,
stream);

nvjpegJpegStateDestroy(jpeg_state);

TORCH_CHECK(
decode_status == NVJPEG_STATUS_SUCCESS,
"nvjpegDecode failed: ",
decode_status);

return out_tensor;
}

#endif // NVJPEG_FOUND

} // namespace image
} // namespace vision
15 changes: 15 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h
@@ -0,0 +1,15 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device);

} // namespace image
} // namespace vision
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);

Contributor:
Not sure if the dispatcher would make sense here. Since this is the first IO method we add for GPU, it might be worth checking the naming convention (_cuda), as it will be reproduced in other places in the near future. Thoughts, @fmassa?

Member:
I think using the dispatcher would be good, but I'm not sure how it handles constructor functions (like torch.empty / torch.rand).

Indeed, this function always takes CPU tensors, and it's up to a device argument to decide if we should dispatch to the CPU or the CUDA version.

@ezyang do you know if we can use the dispatcher to dispatch taking a torch.device into account, knowing that all tensors live on the CPU?

Contributor:
@fmassa How about reading the data on CPU, since that's needed anyway, and then calling to() to move it to the right device? This can happen on the Python side of things and remain hidden. Then, with the binary data living on the GPU, the dispatcher can be used as normal to decide whether the decoding should happen on the GPU or the CPU. Thoughts on this?

Member:
Still, nvjpeg requires the input data to live on CPU, so we would need to move it back to CPU within the function, which would be inefficient. I would have preferred to be able to pass the tensor directly as a CUDA tensor as well, but I'm not sure this is possible without further overhead.

Contributor:
Thanks for the clarifications concerning nvjpeg. I think we can investigate in future PRs how to do this more elegantly. No need to block this PR.


} // namespace image
} // namespace vision
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
@@ -6,3 +6,4 @@
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
#include "cuda/decode_jpeg_cuda.h"
16 changes: 13 additions & 3 deletions torchvision/io/image.py
@@ -148,24 +148,34 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output)


- def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
+                 device: str = 'cpu') -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.

Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
- the raw bytes of the JPEG image.
+ the raw bytes of the JPEG image. This tensor must be on CPU,
+ regardless of the ``device`` parameter.
mode (ImageReadMode): the read mode used for optionally
converting the image. Default: `ImageReadMode.UNCHANGED`.
See `ImageReadMode` class for more information on various
available modes.
device (str or torch.device): The device on which the decoded image will
be stored. If a cuda device is specified, the image will be decoded
with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
supported for CUDA version >= 10.1

Returns:
output (Tensor[image_channels, image_height, image_width])
"""
- output = torch.ops.image.decode_jpeg(input, mode.value)
+ device = torch.device(device)
+ if device.type == 'cuda':
+     output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
+ else:
+     output = torch.ops.image.decode_jpeg(input, mode.value)
return output


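For illustration, a short usage sketch of the new device argument (the image path is a placeholder; the CUDA branch assumes torchvision was built with nvjpeg support, as described in the docstring above):

import torch
from torchvision.io.image import read_file, decode_jpeg, ImageReadMode

data = read_file('path/to/image.jpg')  # raw JPEG bytes as a 1-D uint8 CPU tensor

if torch.cuda.is_available():
    # Decoded with nvjpeg; the returned uint8 tensor lives on the GPU.
    img = decode_jpeg(data, mode=ImageReadMode.RGB, device='cuda')
else:
    # Falls back to the existing libjpeg-based CPU decoder.
    img = decode_jpeg(data, mode=ImageReadMode.RGB)

print(img.shape, img.dtype, img.device)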