diff --git a/torchvision/utils.py b/torchvision/utils.py index 5bb4451a43f..589cb34d575 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -16,9 +16,10 @@ def make_grid(tensor, nrow=8, padding=2): for i in range(numImages): tensor[i].copy_(tensorlist[i]) if tensor.dim() == 2: # single image H x W - tensor = torch.view(1, tensor.size(0), tensor.size(1)) - tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.view(1, tensor.size(0), tensor.size(1)) if tensor.dim() == 3: # single image + if tensor.size(0) == 1: + tensor = torch.cat((tensor, tensor, tensor), 0) return tensor if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images tensor = torch.cat((tensor, tensor, tensor), 1)