From 3fcf7753cc8b56156141924659597acad275b31d Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Wed, 14 Dec 2016 14:41:37 -0500 Subject: [PATCH 1/3] Minimally support accimage.Image --- test/preprocess-bench.py | 5 +++++ torchvision/datasets/folder.py | 9 ++++++++- torchvision/transforms.py | 10 +++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/test/preprocess-bench.py b/test/preprocess-bench.py index 85599362a73..9acf9e71e54 100644 --- a/test/preprocess-bench.py +++ b/test/preprocess-bench.py @@ -6,6 +6,11 @@ import torch.utils.data import torchvision.transforms as transforms import torchvision.datasets as datasets +try: + import accimage + print("Using accimage.Image") +except ImportError: + print("Using PIL.Image") parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 5eb3126ae96..fce90452bbb 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,6 +1,10 @@ import torch.utils.data as data from PIL import Image +try: + import accimage +except ImportError: + accimage = None import os import os.path @@ -47,7 +51,10 @@ def __init__(self, root, transform=None, target_transform=None): def __getitem__(self, index): path, target = self.imgs[index] - img = Image.open(os.path.join(self.root, path)).convert('RGB') + if accimage is None: + img = Image.open(os.path.join(self.root, path)).convert('RGB') + else: + img = accimage.Image(os.path.join(self.root, path)) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 48be812569b..90aa71478e8 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -3,6 +3,10 @@ import math import random from PIL import Image, ImageOps +try: + import accimage +except ImportError: + accimage = None import numpy as np import numbers import types @@ -28,7 +32,11 @@ class ToTensor(object): """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ def __call__(self, pic): - if isinstance(pic, np.ndarray): + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.empty([pic.channels, pic.height, pic.width]) + pic.copyto(nppic) + img = torch.from_numpy(np.transpose(nppic, axes=(1, 2, 0))) + elif isinstance(pic, np.ndarray): # handle numpy array img = torch.from_numpy(pic) else: From 2021a4549b421fe531eb64df792429c75c012811 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Wed, 21 Dec 2016 04:43:09 -0500 Subject: [PATCH 2/3] Fix tensor format when using accimage --- torchvision/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 90aa71478e8..9b42f30bb1d 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -35,7 +35,7 @@ def __call__(self, pic): if accimage is not None and isinstance(pic, accimage.Image): nppic = np.empty([pic.channels, pic.height, pic.width]) pic.copyto(nppic) - img = torch.from_numpy(np.transpose(nppic, axes=(1, 2, 0))) + img = torch.from_numpy(nppic) elif isinstance(pic, np.ndarray): # handle numpy array img = torch.from_numpy(pic) From 8f9df3c4b8baaccdb3cb96fbc5a3eebe5a7b8186 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Wed, 21 Dec 2016 04:44:02 -0500 Subject: [PATCH 3/3] Fallback to PIL if accimage fails in decoding --- torchvision/datasets/folder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index fce90452bbb..9fb528f1614 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -54,7 +54,11 @@ def __getitem__(self, index): if accimage is None: img = Image.open(os.path.join(self.root, path)).convert('RGB') else: - img = accimage.Image(os.path.join(self.root, path)) + try: + img = accimage.Image(os.path.join(self.root, path)) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + img = Image.open(os.path.join(self.root, path)).convert('RGB') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: