diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 50f23c6f686..d181dd94be2 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -4,16 +4,19 @@ import math import numpy as np -from PIL.Image import NEAREST, BILINEAR, BICUBIC import torch import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional as F +from torchvision.transforms import InterpolationModes from common_utils import TransformsTester +NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC + + class Tester(TransformsTester): def setUp(self): @@ -365,7 +368,7 @@ def test_adjust_gamma(self): ) def test_resize(self): - script_fn = torch.jit.script(F_t.resize) + script_fn = torch.jit.script(F.resize) tensor, pil_img = self._create_data(26, 36, device=self.device) batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) @@ -382,14 +385,14 @@ def test_resize(self): for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]: for interpolation in [BILINEAR, BICUBIC, NEAREST]: - resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation) - resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation) + resized_tensor = F.resize(tensor, size=size, interpolation=interpolation) + resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) self.assertEqual( resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation) ) - if interpolation != NEAREST: + if interpolation not in [NEAREST, ]: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] @@ -407,6 +410,7 @@ def test_resize(self): script_size = [size, ] else: script_size = size + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) @@ -414,17 +418,24 @@ def test_resize(self): batch_tensors, F.resize, size=script_size, interpolation=interpolation ) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.resize(tensor, size=32, interpolation=2) + res2 = F.resize(tensor, size=32, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + def test_resized_crop(self): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity tensor, _ = self._create_data(26, 36, device=self.device) - for i in [0, 2, 3]: - out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i) + + for mode in [NEAREST, BILINEAR, BICUBIC]: + out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) # 2) resize by half and crop a TL corner tensor, _ = self._create_data(26, 36, device=self.device) - out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0) + out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST) expected_out_tensor = tensor[:, :20:2, :30:2] self.assertTrue( expected_out_tensor.equal(out_tensor), @@ -433,17 +444,19 @@ def test_resized_crop(self): batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) self._test_fn_on_batch( - batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0 + batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST ) def _test_affine_identity_map(self, tensor, scripted_affine): # 1) identity map - out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) - out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = scripted_affine( + tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST + ) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) @@ -461,13 +474,13 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine): ] for a, true_tensor in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) if true_tensor is not None: self.assertTrue( @@ -496,13 +509,13 @@ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine): for a in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ).cpu() if out_tensor.dtype != torch.uint8: @@ -526,10 +539,10 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine): ] for t in test_configs: - out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -550,13 +563,13 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): (-45, [-10, -10], 1.2, [4.0, 5.0]), (-90, [0, 0], 1.0, [0.0, 0.0]), ] - for r in [0, ]: + for r in [NEAREST, ]: for a, t, s, sh in test_configs: - out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu() + out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -605,18 +618,36 @@ def test_affine(self): batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] ) + tensor, pil_img = data[0] + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + + with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): + res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10) + res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10) + self.assertEqual(res1, res2) + def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): img_size = pil_img.size dt = tensor.dtype - for r in [0, ]: + for r in [NEAREST, ]: for a in range(-180, 180, 17): for e in [True, False]: for c in centers: - out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) + out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.rotate, scripted_rotate]: - out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu() + out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -673,12 +704,24 @@ def test_rotate(self): center = (20, 22) self._test_fn_on_batch( - batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center + batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center ) + tensor, pil_img = data[0] + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + res1 = F.rotate(tensor, 45, resample=2) + res2 = F.rotate(tensor, 45, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.rotate(tensor, 45, interpolation=2) + res2 = F.rotate(tensor, 45, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): dt = tensor.dtype - for r in [0, ]: + for r in [NEAREST, ]: for spoints, epoints in test_configs: out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) @@ -739,9 +782,17 @@ def test_perspective(self): for spoints, epoints in test_configs: self._test_fn_on_batch( - batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 + batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST ) + # assert changed type warning + spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] + epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2) + res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + def test_gaussian_blur(self): small_image_tensor = torch.from_numpy( np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) diff --git a/test/test_transforms.py b/test/test_transforms.py index d6651816cd2..f113f9ee653 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1492,11 +1492,21 @@ def test_random_rotation(self): t = transforms.RandomRotation((-10, 10)) angle = t.get_params(t.degrees) - self.assertTrue(angle > -10 and angle < 10) + self.assertTrue(-10 < angle < 10) # Checking if RandomRotation can be printed as string t.__repr__() + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + t = transforms.RandomRotation((-10, 10), resample=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + t = transforms.RandomRotation((-10, 10), interpolation=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + def test_random_affine(self): with self.assertRaises(ValueError): @@ -1537,8 +1547,22 @@ def test_random_affine(self): # Checking if RandomAffine can be printed as string t.__repr__() - t = transforms.RandomAffine(10, resample=Image.BILINEAR) - self.assertIn("Image.BILINEAR", t.__repr__()) + t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR) + self.assertIn("bilinear", t.__repr__()) + + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + t = transforms.RandomAffine(10, resample=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + + with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): + t = transforms.RandomAffine(10, fillcolor=10) + self.assertEqual(t.fill, 10) + + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + t = transforms.RandomAffine(10, interpolation=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) def test_to_grayscale(self): """Unit tests for grayscale transform""" diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index aafd862d351..ad5d303d36d 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -2,8 +2,7 @@ import torch from torchvision import transforms as T from torchvision.transforms import functional as F - -from PIL.Image import NEAREST, BILINEAR, BICUBIC +from torchvision.transforms import InterpolationModes import numpy as np @@ -12,6 +11,9 @@ from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes +NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC + + class Tester(TransformsTester): def setUp(self): @@ -349,7 +351,7 @@ def test_random_affine(self): for interpolation in [NEAREST, BILINEAR]: transform = T.RandomAffine( degrees=degrees, translate=translate, - scale=scale, shear=shear, resample=interpolation + scale=scale, shear=shear, interpolation=interpolation ) s_transform = torch.jit.script(transform) @@ -368,7 +370,7 @@ def test_random_rotate(self): for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: for interpolation in [NEAREST, BILINEAR]: transform = T.RandomRotation( - degrees=degrees, resample=interpolation, expand=expand, center=center + degrees=degrees, interpolation=interpolation, expand=expand, center=center ) s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 4d8a0c09e34..b1c2bc187ed 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,6 +1,7 @@ import math import numbers import warnings +from enum import Enum from typing import Any, Optional import numpy as np @@ -19,6 +20,41 @@ from . import functional_tensor as F_t +class InterpolationModes(Enum): + """Interpolation modes + """ + NEAREST = "nearest" + BILINEAR = "bilinear" + BICUBIC = "bicubic" + # For PIL compatibility + BOX = "box" + HAMMING = "hamming" + LANCZOS = "lanczos" + + +# TODO: Once torchscript supports Enums with staticmethod +# this can be put into InterpolationModes as staticmethod +def _interpolation_modes_from_int(i: int) -> InterpolationModes: + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + return inverse_modes_mapping[i] + + +pil_modes_mapping = { + InterpolationModes.NEAREST: 0, + InterpolationModes.BILINEAR: 2, + InterpolationModes.BICUBIC: 3, + InterpolationModes.BOX: 4, + InterpolationModes.HAMMING: 5, + InterpolationModes.LANCZOS: 1, +} + _is_pil_image = F_pil._is_pil_image _parse_fill = F_pil._parse_fill @@ -293,7 +329,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return tensor -def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> Tensor: +def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = InterpolationModes.BILINEAR) -> Tensor: r"""Resize the input image to the given size. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -307,17 +343,31 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode size as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. + Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, + ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. Returns: PIL Image or Tensor: Resized image. """ + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - return F_pil.resize(img, size=size, interpolation=interpolation) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.resize(img, size=size, interpolation=pil_interpolation) - return F_t.resize(img, size=size, interpolation=interpolation) + return F_t.resize(img, size=size, interpolation=interpolation.value) def scale(*args, **kwargs): @@ -424,7 +474,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: def resized_crop( - img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR + img: Tensor, top: int, left: int, height: int, width: int, size: List[int], + interpolation: InterpolationModes = InterpolationModes.BILINEAR ) -> Tensor: """Crop the given image and resize it to desired size. The image can be a PIL Image or a Tensor, in which case it is expected @@ -439,9 +490,12 @@ def resized_crop( height (int): Height of the crop box. width (int): Width of the crop box. size (sequence or int): Desired output size. Same semantics as ``resize``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. + Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, + ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + Returns: PIL Image or Tensor: Cropped image. """ @@ -502,7 +556,7 @@ def perspective( img: Tensor, startpoints: List[List[int]], endpoints: List[List[int]], - interpolation: int = 2, + interpolation: InterpolationModes = InterpolationModes.BILINEAR, fill: Optional[int] = None ) -> Tensor: """Perform perspective transform of the given image. @@ -515,8 +569,10 @@ def perspective( ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor @@ -528,10 +584,22 @@ def perspective( coeffs = _get_perspective_coeffs(startpoints, endpoints) + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill) - return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill) + return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill) def vflip(img: Tensor) -> Tensor: @@ -801,8 +869,9 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, angle: float, resample: int = 0, expand: bool = False, - center: Optional[List[int]] = None, fill: Optional[int] = None + img: Tensor, angle: float, interpolation: InterpolationModes = InterpolationModes.NEAREST, + expand: bool = False, center: Optional[List[int]] = None, + fill: Optional[int] = None, resample: Optional[int] = None ) -> Tensor: """Rotate the image by angle. The image can be a PIL Image or a Tensor, in which case it is expected @@ -811,9 +880,10 @@ def rotate( Args: img (PIL Image or Tensor): image to be rotated. angle (float or int): rotation angle value in degrees, counter-clockwise. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -825,6 +895,8 @@ def rotate( Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Rotated image. @@ -832,14 +904,32 @@ def rotate( .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ + if resample is not None: + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") if center is not None and not isinstance(center, (list, tuple)): raise TypeError("Argument center should be a sequence") + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) center_f = [0.0, 0.0] if center is not None: @@ -850,12 +940,13 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) - return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill) + return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) def affine( img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], - resample: int = 0, fillcolor: Optional[int] = None + interpolation: InterpolationModes = InterpolationModes.NEAREST, fill: Optional[int] = None, + resample: Optional[int] = None, fillcolor: Optional[int] = None ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. The image can be a PIL Image or a Tensor, in which case it is expected @@ -869,17 +960,41 @@ def affine( shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:fill: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Transformed image. """ + if resample is not None: + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + + if fillcolor is not None: + warnings.warn( + "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" + ) + fill = fillcolor + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -895,6 +1010,9 @@ def affine( if not isinstance(shear, (numbers.Number, (list, tuple))): raise TypeError("Shear should be either a single value or a sequence of two values") + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if isinstance(angle, int): angle = float(angle) @@ -920,12 +1038,12 @@ def affine( # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine center = [img_size[0] * 0.5, img_size[1] * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) - - return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) - return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) + return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) @torch.jit.unused diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index d76bc7a0027..7e3989f0288 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -474,7 +474,7 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"): @torch.jit.unused -def affine(img, matrix, resample=0, fillcolor=None): +def affine(img, matrix, interpolation=0, fill=None): """PRIVATE METHOD. Apply affine transformation on the PIL Image keeping image center invariant. .. warning:: @@ -485,11 +485,11 @@ def affine(img, matrix, resample=0, fillcolor=None): Args: img (PIL Image): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + fill (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) Returns: PIL Image: Transformed image. @@ -498,12 +498,12 @@ def affine(img, matrix, resample=0, fillcolor=None): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) output_size = img.size - opts = _parse_fill(fillcolor, img, '5.0.0') - return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) + opts = _parse_fill(fill, img, '5.0.0') + return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) @torch.jit.unused -def rotate(img, angle, resample=0, expand=False, center=None, fill=None): +def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None): """PRIVATE METHOD. Rotate PIL image by angle. .. warning:: @@ -514,7 +514,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): Args: img (PIL Image): image to be rotated. angle (float or int): rotation angle value in degrees, counter-clockwise. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. expand (bool, optional): Optional expansion flag. @@ -538,7 +538,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): raise TypeError("img should be PIL Image. Got {}".format(type(img))) opts = _parse_fill(fill, img, '5.2.0') - return img.rotate(angle, resample, expand, center, **opts) + return img.rotate(angle, interpolation, expand, center, **opts) @torch.jit.unused diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 42a686bb726..4f3e72a62ce 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -757,7 +757,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con return img -def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: +def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor: r"""PRIVATE METHOD. Resize the input Tensor to the given size. .. warning:: @@ -774,8 +774,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode padding as a single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values: - nearest(=0) and bicubic(=3). + interpolation (str): Desired interpolation. Default is "bilinear". Other supported values: + "nearest" and "bicubic". Returns: Tensor: Resized image. @@ -785,16 +785,10 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: if not isinstance(size, (int, tuple, list)): raise TypeError("Got inappropriate size arg") - if not isinstance(interpolation, int): + if not isinstance(interpolation, str): raise TypeError("Got inappropriate interpolation arg") - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - 3: "bicubic", - } - - if interpolation not in _interpolation_modes: + if interpolation not in ["nearest", "bilinear", "bicubic"]: raise ValueError("This interpolation mode is unsupported with Tensor input") if isinstance(size, tuple): @@ -822,16 +816,14 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: if (w <= h and w == size_w) or (h <= w and h == size_h): return img - mode = _interpolation_modes[interpolation] - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) # Define align_corners to avoid warnings - align_corners = False if mode in ["bilinear", "bicubic"] else None + align_corners = False if interpolation in ["bilinear", "bicubic"] else None - img = interpolate(img, size=[size_h, size_w], mode=mode, align_corners=align_corners) + img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners) - if mode == "bicubic" and out_dtype == torch.uint8: + if interpolation == "bicubic" and out_dtype == torch.uint8: img = img.clamp(min=0, max=255) img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) @@ -842,9 +834,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def _assert_grid_transform_inputs( img: Tensor, matrix: Optional[List[float]], - resample: int, - fillcolor: Optional[int], - _interpolation_modes: Dict[int, str], + interpolation: str, + fill: Optional[int], + supported_interpolation_modes: List[str], coeffs: Optional[List[float]] = None, ): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): @@ -859,11 +851,11 @@ def _assert_grid_transform_inputs( if coeffs is not None and len(coeffs) != 8: raise ValueError("Argument coeffs should have 8 float values") - if fillcolor is not None: - warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") + if fill is not None and not (isinstance(fill, (int, float)) and fill == 0): + warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero") - if resample not in _interpolation_modes: - raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample)) + if interpolation not in supported_interpolation_modes: + raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation)) def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]: @@ -931,7 +923,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. @@ -943,28 +935,21 @@ def affine( Args: img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: - bilinear(=2). - fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the + interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". + fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. Returns: Tensor: Transformed image. """ - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - - _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) - mode = _interpolation_modes[resample] - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: @@ -993,7 +978,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( - img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", + expand: bool = False, fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Rotate the Tensor image by angle. @@ -1006,8 +992,7 @@ def rotate( img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation. Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates. - resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: - bilinear(=2). + interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1021,21 +1006,14 @@ def rotate( .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - - _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) - mode = _interpolation_modes[resample] - - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device): @@ -1072,7 +1050,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, def perspective( - img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None + img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Perform perspective transform of the given Tensor image. @@ -1084,7 +1062,7 @@ def perspective( Args: img (Tensor): Image to be transformed. perspective_coeffs (list of float): perspective transformation coefficients. - interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``. + interpolation (str): Interpolation type. Default, "bilinear". fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -1094,26 +1072,19 @@ def perspective( if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): raise TypeError('Input img should be Tensor Image') - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - _assert_grid_transform_inputs( img, matrix=None, - resample=interpolation, - fillcolor=fill, - _interpolation_modes=_interpolation_modes, + interpolation=interpolation, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], coeffs=perspective_coeffs ) ow, oh = img.shape[-1], img.shape[-2] dtype = img.dtype if torch.is_floating_point(img) else torch.float32 grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device) - mode = _interpolation_modes[interpolation] - - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index af74a3188f3..40198ee2cc5 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -6,7 +6,6 @@ from typing import Tuple, List, Optional import torch -from PIL import Image from torch import Tensor try: @@ -15,21 +14,14 @@ accimage = None from . import functional as F +from .functional import InterpolationModes, _interpolation_modes_from_int + __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur"] - -_pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', -} + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationModes"] class Compose: @@ -242,18 +234,30 @@ class Resize(torch.nn.Module): (size * height / width, size). In torchscript mode padding as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and + ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + """ - def __init__(self, size, interpolation=Image.BILINEAR): + def __init__(self, size, interpolation=InterpolationModes.BILINEAR): super().__init__() if not isinstance(size, (int, Sequence)): raise TypeError("Size should be int or sequence. Got {}".format(type(size))) if isinstance(size, Sequence) and len(size) not in (1, 2): raise ValueError("If size is a sequence, it should have 1 or 2 values") self.size = size + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation def forward(self, img): @@ -267,7 +271,7 @@ def forward(self, img): return F.resize(img, self.size, self.interpolation) def __repr__(self): - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) @@ -659,18 +663,28 @@ class RandomPerspective(torch.nn.Module): distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. Default is 0.5. p (float): probability of the image being transformed. Default is 0.5. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. Default is 0. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - """ - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationModes.BILINEAR, fill=0): super().__init__() self.p = p + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.distortion_scale = distortion_scale self.fill = fill @@ -744,12 +758,15 @@ class RandomResizedCrop(torch.nn.Module): made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). scale (tuple of float): scale range of the cropped image before resizing, relatively to the origin image. ratio (tuple of float): aspect ratio range of the cropped image before resizing. - interpolation (int): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and + ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationModes.BILINEAR): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -760,6 +777,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.scale = scale self.ratio = ratio @@ -824,7 +849,7 @@ def forward(self, img): return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) def __repr__(self): - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) @@ -1122,9 +1147,10 @@ class RandomRotation(torch.nn.Module): degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). - resample (int, optional): An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1136,13 +1162,31 @@ class RandomRotation(torch.nn.Module): Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): + def __init__( + self, degrees, interpolation=InterpolationModes.NEAREST, expand=False, center=None, fill=None, resample=None + ): super().__init__() + if resample is not None: + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if center is not None: @@ -1150,7 +1194,7 @@ def __init__(self, degrees, resample=False, expand=False, center=None, fill=None self.center = center - self.resample = resample + self.resample = self.interpolation = interpolation self.expand = expand self.fill = fill @@ -1173,11 +1217,12 @@ def forward(self, img): PIL Image or Tensor: Rotated image. """ angle = self.get_params(self.degrees) - return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) + return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill) def __repr__(self): + interpolate_str = self.interpolation.value format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) - format_string += ', resample={0}'.format(self.resample) + format_string += ', interpolation={0}'.format(interpolate_str) format_string += ', expand={0}'.format(self.expand) if self.center is not None: format_string += ', center={0}'.format(self.center) @@ -1208,19 +1253,47 @@ class RandomAffine(torch.nn.Module): range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. Will not apply shear by default. - resample (int, optional): An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:fill: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0): + def __init__( + self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationModes.NEAREST, fill=0, + fillcolor=None, resample=None + ): super().__init__() + if resample is not None: + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + + if fillcolor is not None: + warnings.warn( + "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" + ) + fill = fillcolor + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if translate is not None: @@ -1242,8 +1315,8 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, else: self.shear = shear - self.resample = resample - self.fillcolor = fillcolor + self.resample = self.interpolation = interpolation + self.fillcolor = self.fill = fill @staticmethod def get_params( @@ -1294,7 +1367,7 @@ def forward(self, img): img_size = F._get_image_size(img) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) - return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) + return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill) def __repr__(self): s = '{name}(degrees={degrees}' @@ -1304,13 +1377,13 @@ def __repr__(self): s += ', scale={scale}' if self.shear is not None: s += ', shear={shear}' - if self.resample > 0: - s += ', resample={resample}' - if self.fillcolor != 0: - s += ', fillcolor={fillcolor}' + if self.interpolation != InterpolationModes.NEAREST: + s += ', interpolation={interpolation}' + if self.fill != 0: + s += ', fill={fill}' s += ')' d = dict(self.__dict__) - d['resample'] = _pil_interpolation_to_str[d['resample']] + d['interpolation'] = self.interpolation.value return s.format(name=self.__class__.__name__, **d)