
Commit 6ee863c

Maratyszcza authored and colesbury committed
Add support for accimage.Image
It can be enabled by setting torchvision.image_backend = 'accimage'
1 parent 323f529 commit 6ee863c
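
For context, a minimal sketch of how the new switch is meant to be used, based on the commit message and the benchmark change below. The dataset path is a placeholder, and accimage has to be installed separately:

    import torchvision
    import torchvision.datasets as datasets

    # 'PIL' stays the default; opt in to accimage before building datasets
    # that go through the default image loader.
    torchvision.image_backend = 'accimage'

    # ImageFolder's default_loader then tries accimage first and falls back
    # to PIL if a file cannot be decoded ('/path/to/images' is hypothetical).
    dataset = datasets.ImageFolder('/path/to/images')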

6 files changed: +101 -6 lines changed


test/assets/grace_hopper_512.jpg

64 KB

test/preprocess-bench.py

Lines changed: 16 additions & 6 deletions
@@ -4,6 +4,7 @@
 from tqdm import tqdm
 import torch
 import torch.utils.data
+import torchvision
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
@@ -15,11 +16,17 @@
                     help='number of data loading threads (default: 2)')
 parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
                     help='mini-batch size (1 = pure stochastic) Default: 256')
+parser.add_argument('--accimage', action='store_true',
+                    help='use accimage')
 
 
 if __name__ == "__main__":
     args = parser.parse_args()
 
+    if args.accimage:
+        torchvision.image_backend = 'accimage'
+    print('Using {}'.format(torchvision.image_backend))
+
     # Data loading code
     transform = transforms.Compose([
         transforms.RandomSizedCrop(224),
@@ -38,11 +45,14 @@
     train_iter = iter(train_loader)
 
     start_time = timer()
-    batch_count = 100 * args.nThreads
-    for i in tqdm(xrange(batch_count)):
+    batch_count = 20 * args.nThreads
+    for _ in tqdm(range(batch_count)):
         batch = next(train_iter)
     end_time = timer()
-    print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format(
-        dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
-        batch=(end_time - start_time) / float(batch_count),
-        image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
+    print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
+          " {image:.2f} ms/image {rate:.0f} images/sec"
+          .format(
+              dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
+              batch=(end_time - start_time) / float(batch_count) * 1.0e+3,
+              image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3,
+              rate=(batch_count * args.batchSize) / (end_time - start_time)))
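
As a rough illustration of what the new summary line reports (the numbers below are made up, not measurements): with --nThreads 2 and --batchSize 256, the loop times batch_count = 20 * 2 = 40 batches, and if that took 50 seconds the script would report 1250.0 ms/batch, about 4.88 ms/image and about 205 images/sec:

    # Illustrative values only; they do not come from an actual run.
    elapsed, batch_count, batch_size = 50.0, 40, 256
    ms_per_batch = elapsed / batch_count * 1.0e+3                  # 1250.0
    ms_per_image = elapsed / (batch_count * batch_size) * 1.0e+3   # ~4.88
    images_per_sec = (batch_count * batch_size) / elapsed          # ~204.8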

test/test_transforms.py

Lines changed: 53 additions & 0 deletions
@@ -3,6 +3,14 @@
 import unittest
 import random
 import numpy as np
+from PIL import Image
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+
+GRACE_HOPPER = 'assets/grace_hopper_512.jpg'
 
 
 class Tester(unittest.TestCase):
@@ -153,6 +161,51 @@ def test_to_tensor(self):
         expected_output = ndarray.transpose((2, 0, 1)) / 255.0
         assert np.allclose(output.numpy(), expected_output)
 
+    @unittest.skipIf(accimage is None, 'accimage not available')
+    def test_accimage_to_tensor(self):
+        trans = transforms.ToTensor()
+
+        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
+        output = trans(accimage.Image(GRACE_HOPPER))
+
+        import visdom
+        vis = visdom.Visdom()
+        vis.image(expected_output.numpy())
+        vis.image(output.numpy())
+        vis.image((expected_output - output).numpy())
+
+        self.assertEqual(expected_output.size(), output.size())
+        assert np.allclose(output.numpy(), expected_output.numpy())
+
+    @unittest.skipIf(accimage is None, 'accimage not available')
+    def test_accimage_resize(self):
+        trans = transforms.Compose([
+            transforms.Scale(256, interpolation=Image.LINEAR),
+            transforms.ToTensor(),
+        ])
+
+        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
+        output = trans(accimage.Image(GRACE_HOPPER))
+
+        self.assertEqual(expected_output.size(), output.size())
+        self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
+        self.assertLess((expected_output - output).var(), 1e-5)
+        # note the high absolute tolerance
+        assert np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2)
+
+    @unittest.skipIf(accimage is None, 'accimage not available')
+    def test_accimage_crop(self):
+        trans = transforms.Compose([
+            transforms.CenterCrop(256),
+            transforms.ToTensor(),
+        ])
+
+        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
+        output = trans(accimage.Image(GRACE_HOPPER))
+
+        self.assertEqual(expected_output.size(), output.size())
+        assert np.allclose(output.numpy(), expected_output.numpy())
+
     def test_tensor_to_pil_image(self):
         trans = transforms.ToPILImage()
         to_tensor = transforms.ToTensor()

torchvision/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -2,3 +2,17 @@
 from torchvision import datasets
 from torchvision import transforms
 from torchvision import utils
+
+
+image_backend = 'PIL'
+"""
+Specifies the package used to load images.
+
+Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
+Intel IPP library. It is generally faster than PIL, but does not support as
+many operations.
+"""
+
+
+def get_image_backend():
+    return image_backend
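
Because get_image_backend() reads the module attribute at call time, code that imports the function (as torchvision/datasets/folder.py does below) always sees the current setting rather than the value at import time. A small sketch of that behaviour, assuming only the code added in this commit:

    import torchvision

    print(torchvision.get_image_backend())   # 'PIL' (the default)
    torchvision.image_backend = 'accimage'
    print(torchvision.get_image_backend())   # 'accimage'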

torchvision/datasets/folder.py

Lines changed: 8 additions & 0 deletions
@@ -39,6 +39,14 @@ def make_dataset(dir, class_to_idx):
 
 
 def default_loader(path):
+    from torchvision import get_image_backend
+    if get_image_backend() == 'accimage':
+        import accimage
+        try:
+            return accimage.Image(path)
+        except IOError:
+            # Potentially a decoding problem, fall back to PIL.Image
+            pass
     return Image.open(path).convert('RGB')
 
 

torchvision/transforms.py

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,10 @@
 import math
 import random
 from PIL import Image, ImageOps
+try:
+    import accimage
+except ImportError:
+    accimage = None
 import numpy as np
 import numbers
 import types
@@ -42,6 +46,12 @@ def __call__(self, pic):
             img = torch.from_numpy(pic.transpose((2, 0, 1)))
             # backard compability
             return img.float().div(255)
+
+        if accimage is not None and isinstance(pic, accimage.Image):
+            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+            pic.copyto(nppic)
+            return torch.from_numpy(nppic)
+
         # handle PIL Image
         if pic.mode == 'I':
             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
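
For reference, a short sketch of what the new ToTensor branch does with an accimage-backed image; it uses the Grace Hopper test asset added in this commit (path relative to the repo root), and per the tests above the result is expected to match the PIL path, i.e. a CHW float tensor with values in [0, 1]:

    import accimage
    import torchvision.transforms as transforms

    pic = accimage.Image('test/assets/grace_hopper_512.jpg')

    # ToTensor allocates a (channels, height, width) float32 buffer and lets
    # accimage copy the decoded pixel data into it.
    tensor = transforms.ToTensor()(pic)
    print(tensor.size())   # (channels, height, width); exact size depends on the asset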
