Skip to content

Add support for accimage.Image #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/assets/grace_hopper_517x606.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 15 additions & 6 deletions test/preprocess-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm import tqdm
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

Expand All @@ -15,11 +16,17 @@
help='number of data loading threads (default: 2)')
parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
help='mini-batch size (1 = pure stochastic) Default: 256')
parser.add_argument('--accimage', action='store_true',
help='use accimage')


if __name__ == "__main__":
args = parser.parse_args()

if args.accimage:
torchvision.set_image_backend('accimage')
print('Using {}'.format(torchvision.get_image_backend()))

# Data loading code
transform = transforms.Compose([
transforms.RandomSizedCrop(224),
Expand All @@ -38,11 +45,13 @@
train_iter = iter(train_loader)

start_time = timer()
batch_count = 100 * args.nThreads
for i in tqdm(xrange(batch_count)):
batch_count = 20 * args.nThreads
for _ in tqdm(range(batch_count)):

This comment was marked as off-topic.

This comment was marked as off-topic.

batch = next(train_iter)
end_time = timer()
print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format(
dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
batch=(end_time - start_time) / float(batch_count),
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
" {image:.2f} ms/image {rate:.0f} images/sec"
.format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
batch=(end_time - start_time) / float(batch_count) * 1.0e+3,
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3,
rate=(batch_count * args.batchSize) / (end_time - start_time)))
47 changes: 47 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
import unittest
import random
import numpy as np
from PIL import Image
try:
import accimage
except ImportError:
accimage = None


GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'


class Tester(unittest.TestCase):
Expand Down Expand Up @@ -153,6 +161,45 @@ def test_to_tensor(self):
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
trans = transforms.ToTensor()

expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))

self.assertEqual(expected_output.size(), output.size())
assert np.allclose(output.numpy(), expected_output.numpy())

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_resize(self):
trans = transforms.Compose([
transforms.Scale(256, interpolation=Image.LINEAR),
transforms.ToTensor(),
])

expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))

self.assertEqual(expected_output.size(), output.size())
self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
self.assertLess((expected_output - output).var(), 1e-5)
# note the high absolute tolerance

This comment was marked as off-topic.

This comment was marked as off-topic.

assert np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_crop(self):
trans = transforms.Compose([
transforms.CenterCrop(256),
transforms.ToTensor(),
])

expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))

self.assertEqual(expected_output.size(), output.size())
assert np.allclose(output.numpy(), expected_output.numpy())

def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
Expand Down
28 changes: 28 additions & 0 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,31 @@
from torchvision import datasets
from torchvision import transforms
from torchvision import utils


_image_backend = 'PIL'


def set_image_backend(backend):
"""
Specifies the package used to load images.

Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
Intel IPP library. It is generally faster than PIL, but does not support as
many operations.

Args:
backend (string): name of the image backend
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
.format(backend))
_image_backend = backend


def get_image_backend():
"""
Gets the name of the package used to load images
"""
return _image_backend
19 changes: 18 additions & 1 deletion torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,27 @@ def make_dataset(dir, class_to_idx):
return images


def default_loader(path):
def pil_loader(path):
return Image.open(path).convert('RGB')


def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)


def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, target_transform=None,
Expand Down
10 changes: 10 additions & 0 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import math
import random
from PIL import Image, ImageOps
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
Expand Down Expand Up @@ -42,6 +46,12 @@ def __call__(self, pic):
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backard compability
return img.float().div(255)

if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)

# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
Expand Down