diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index f1adee687eb..ee988f13934 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -80,7 +80,7 @@ def __init__( # this makes it inconsistent with several loss functions # which expect the class labels to be in the range [0, C-1] np.place(self.labels, self.labels == 10, 0) - self.data = np.transpose(self.data, (3, 2, 0, 1)) + self.data = np.transpose(self.data, (3, 0, 1, 2)) # convert to HWC def __getitem__(self, index: int) -> Tuple[Any, Any]: """ @@ -94,7 +94,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(np.transpose(img, (1, 2, 0))) + img = Image.fromarray(img) if self.transform is not None: img = self.transform(img)