diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index dcc9c5ffa57..8a6daf2dd9f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -115,9 +115,9 @@ def to_pil_image(pic, mode=None): npimg = npimg[:, :, 0] if npimg.dtype == np.uint8: expected_mode = 'L' - if npimg.dtype == np.int16: + elif npimg.dtype == np.int16: expected_mode = 'I;16' - if npimg.dtype == np.int32: + elif npimg.dtype == np.int32: expected_mode = 'I' elif npimg.dtype == np.float32: expected_mode = 'F'