Skip to content

Commit 1f2c15f

Browse files
colesburysoumith
authored andcommitted
Add support for accimage.Image (#153)
It can be enabled by calling torchvision.set_image_backend('accimage')
1 parent 323f529 commit 1f2c15f

File tree

6 files changed

+118
-7
lines changed

6 files changed

+118
-7
lines changed

test/assets/grace_hopper_517x606.jpg

72 KB
Loading

test/preprocess-bench.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tqdm import tqdm
55
import torch
66
import torch.utils.data
7+
import torchvision
78
import torchvision.transforms as transforms
89
import torchvision.datasets as datasets
910

@@ -15,11 +16,17 @@
1516
help='number of data loading threads (default: 2)')
1617
parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
1718
help='mini-batch size (1 = pure stochastic) Default: 256')
19+
parser.add_argument('--accimage', action='store_true',
20+
help='use accimage')
1821

1922

2023
if __name__ == "__main__":
2124
args = parser.parse_args()
2225

26+
if args.accimage:
27+
torchvision.set_image_backend('accimage')
28+
print('Using {}'.format(torchvision.get_image_backend()))
29+
2330
# Data loading code
2431
transform = transforms.Compose([
2532
transforms.RandomSizedCrop(224),
@@ -38,11 +45,13 @@
3845
train_iter = iter(train_loader)
3946

4047
start_time = timer()
41-
batch_count = 100 * args.nThreads
42-
for i in tqdm(xrange(batch_count)):
48+
batch_count = 20 * args.nThreads
49+
for _ in tqdm(range(batch_count)):
4350
batch = next(train_iter)
4451
end_time = timer()
45-
print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format(
46-
dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
47-
batch=(end_time - start_time) / float(batch_count),
48-
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
52+
print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
53+
" {image:.2f} ms/image {rate:.0f} images/sec"
54+
.format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
55+
batch=(end_time - start_time) / float(batch_count) * 1.0e+3,
56+
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3,
57+
rate=(batch_count * args.batchSize) / (end_time - start_time)))

test/test_transforms.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import unittest
44
import random
55
import numpy as np
6+
from PIL import Image
7+
try:
8+
import accimage
9+
except ImportError:
10+
accimage = None
11+
12+
13+
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
614

715

816
class Tester(unittest.TestCase):
@@ -153,6 +161,45 @@ def test_to_tensor(self):
153161
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
154162
assert np.allclose(output.numpy(), expected_output)
155163

164+
@unittest.skipIf(accimage is None, 'accimage not available')
165+
def test_accimage_to_tensor(self):
166+
trans = transforms.ToTensor()
167+
168+
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
169+
output = trans(accimage.Image(GRACE_HOPPER))
170+
171+
self.assertEqual(expected_output.size(), output.size())
172+
assert np.allclose(output.numpy(), expected_output.numpy())
173+
174+
@unittest.skipIf(accimage is None, 'accimage not available')
175+
def test_accimage_resize(self):
176+
trans = transforms.Compose([
177+
transforms.Scale(256, interpolation=Image.LINEAR),
178+
transforms.ToTensor(),
179+
])
180+
181+
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
182+
output = trans(accimage.Image(GRACE_HOPPER))
183+
184+
self.assertEqual(expected_output.size(), output.size())
185+
self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
186+
self.assertLess((expected_output - output).var(), 1e-5)
187+
# note the high absolute tolerance
188+
assert np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2)
189+
190+
@unittest.skipIf(accimage is None, 'accimage not available')
191+
def test_accimage_crop(self):
192+
trans = transforms.Compose([
193+
transforms.CenterCrop(256),
194+
transforms.ToTensor(),
195+
])
196+
197+
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
198+
output = trans(accimage.Image(GRACE_HOPPER))
199+
200+
self.assertEqual(expected_output.size(), output.size())
201+
assert np.allclose(output.numpy(), expected_output.numpy())
202+
156203
def test_tensor_to_pil_image(self):
157204
trans = transforms.ToPILImage()
158205
to_tensor = transforms.ToTensor()

torchvision/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,31 @@
22
from torchvision import datasets
33
from torchvision import transforms
44
from torchvision import utils
5+
6+
7+
_image_backend = 'PIL'
8+
9+
10+
def set_image_backend(backend):
11+
"""
12+
Specifies the package used to load images.
13+
14+
Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
15+
Intel IPP library. It is generally faster than PIL, but does not support as
16+
many operations.
17+
18+
Args:
19+
backend (string): name of the image backend
20+
"""
21+
global _image_backend
22+
if backend not in ['PIL', 'accimage']:
23+
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
24+
.format(backend))
25+
_image_backend = backend
26+
27+
28+
def get_image_backend():
29+
"""
30+
Gets the name of the package used to load images
31+
"""
32+
return _image_backend

torchvision/datasets/folder.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,27 @@ def make_dataset(dir, class_to_idx):
3838
return images
3939

4040

41-
def default_loader(path):
41+
def pil_loader(path):
4242
return Image.open(path).convert('RGB')
4343

4444

45+
def accimage_loader(path):
46+
import accimage
47+
try:
48+
return accimage.Image(path)
49+
except IOError:
50+
# Potentially a decoding problem, fall back to PIL.Image
51+
return pil_loader(path)
52+
53+
54+
def default_loader(path):
55+
from torchvision import get_image_backend
56+
if get_image_backend() == 'accimage':
57+
return accimage_loader(path)
58+
else:
59+
return pil_loader(path)
60+
61+
4562
class ImageFolder(data.Dataset):
4663

4764
def __init__(self, root, transform=None, target_transform=None,

torchvision/transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import math
44
import random
55
from PIL import Image, ImageOps
6+
try:
7+
import accimage
8+
except ImportError:
9+
accimage = None
610
import numpy as np
711
import numbers
812
import types
@@ -42,6 +46,12 @@ def __call__(self, pic):
4246
img = torch.from_numpy(pic.transpose((2, 0, 1)))
4347
# backard compability
4448
return img.float().div(255)
49+
50+
if accimage is not None and isinstance(pic, accimage.Image):
51+
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
52+
pic.copyto(nppic)
53+
return torch.from_numpy(nppic)
54+
4555
# handle PIL Image
4656
if pic.mode == 'I':
4757
img = torch.from_numpy(np.array(pic, np.int32, copy=False))

0 commit comments

Comments
 (0)