From 10c3efaf31d6d286f06819daaa3250dbcdd35fca Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 3 Dec 2020 13:47:53 +0000 Subject: [PATCH 01/18] Invert Transform (#3104) * Adding invert operator. * Make use of the _assert_channels(). * Update upper bound value. --- test/test_functional_tensor.py | 15 ++++++++ test/test_transforms.py | 32 ++++++++++++++++ test/test_transforms_tensor.py | 3 ++ torchvision/transforms/functional.py | 18 +++++++++ torchvision/transforms/functional_pil.py | 20 ++++++++++ torchvision/transforms/functional_tensor.py | 27 +++++++++++++ torchvision/transforms/transforms.py | 42 ++++++++++++++++++++- 7 files changed, 156 insertions(+), 1 deletion(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index c070c5c1d61..0c4997f1499 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -862,6 +862,21 @@ def test_gaussian_blur(self): msg="{}, {}".format(ksize, sigma) ) + def test_invert(self): + script_invert = torch.jit.script(F.invert) + + img_tensor, pil_img = self._create_data(16, 18, device=self.device) + inverted_img = F.invert(img_tensor) + inverted_pil_img = F.invert(pil_img) + self.compareTensorToPIL(inverted_img, inverted_pil_img) + + # scriptable function test + inverted_img_script = script_invert(img_tensor) + self.assertTrue(inverted_img.equal(inverted_img_script)) + + batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) + self._test_fn_on_batch(batch_tensors, F.invert) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 30749772d6a..d6b8f48959c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1749,6 +1749,38 @@ def test_gaussian_blur_asserts(self): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): transforms.GaussianBlur(3, "sigma_string") + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_invert(self): + random_state = random.getstate() + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 10, 10)) + inv_img = F.invert(img) + + num_samples = 250 + num_inverts = 0 + for _ in range(num_samples): + out = transforms.RandomInvert()(img) + if out == inv_img: + num_inverts += 1 + + p_value = stats.binom_test(num_inverts, num_samples, p=0.5) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + num_samples = 250 + num_inverts = 0 + for _ in range(num_samples): + out = transforms.RandomInvert(p=0.7)(img) + if out == inv_img: + num_inverts += 1 + + p_value = stats.binom_test(num_inverts, num_samples, p=0.7) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + # Checking if RandomInvert can be printed as string + transforms.RandomInvert().__repr__() + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 326d9dbb651..1b9eadaaff0 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -89,6 +89,9 @@ def test_random_horizontal_flip(self): def test_random_vertical_flip(self): self._test_op('vflip', 'RandomVerticalFlip') + def test_random_invert(self): + self._test_op('invert', 'RandomInvert') + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 72baf021f9d..b64d00138dd 100644 --- 
a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1178,3 +1178,21 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     if not isinstance(img, torch.Tensor):
         output = to_pil_image(output)
     return output
+
+
+def invert(img: Tensor) -> Tensor:
+    """Invert the colors of a PIL Image or torch Tensor.
+
+    Args:
+        img (PIL Image or Tensor): Image to have its colors inverted.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+
+    Returns:
+        PIL Image: Color inverted image.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.invert(img)
+
+    return F_t.invert(img)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 51d83f0fd63..17c67355535 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -606,3 +606,23 @@ def to_grayscale(img, num_output_channels):
         raise ValueError('num_output_channels should be either 1 or 3')
 
     return img
+
+
+@torch.jit.unused
+def invert(img):
+    """PRIVATE METHOD. Invert the colors of an image.
+
+    .. warning::
+
+        Module ``transforms.functional_pil`` is private and should not be used in user application.
+        Please, consider instead using methods from `transforms.functional` module.
+
+    Args:
+        img (PIL Image): Image to have its colors inverted.
+
+    Returns:
+        PIL Image: Color inverted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.invert(img)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 0c72a745bba..ce899efbabf 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1179,3 +1179,30 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
     img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
 
     return img
+
+
+def invert(img: Tensor) -> Tensor:
+    """PRIVATE METHOD. Invert the colors of a grayscale or RGB image.
+
+    .. warning::
+
+        Module ``transforms.functional_tensor`` is private and should not be used in user application.
+        Please, consider instead using methods from `transforms.functional` module.
+
+    Args:
+        img (Tensor): Image to have its colors inverted in the form [C, H, W].
+
+    Returns:
+        Tensor: Color inverted image.
+ """ + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + return (bound - img.to(dtype)).to(img.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 3b159fd3f22..9295004e4a6 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -21,7 +21,7 @@ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"] + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"] class Compose: @@ -1699,3 +1699,43 @@ def _setup_angle(x, name, req_sizes=(2, )): _check_sequence_input(x, name, req_sizes) return [float(d) for d in x] + + +class RandomInvert(torch.nn.Module): + """Inverts the colors of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + @staticmethod + def get_params() -> float: + """Choose value for random color inversion. + + Returns: + float: Random value which is used to determine whether the random color inversion + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be inverted. + + Returns: + PIL Image or Tensor: Randomly color inverted image. + """ + if self.get_params() < self.p: + return F.invert(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) From cd03c18950516e0f83af7dbec32fb6148266f426 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 3 Dec 2020 15:47:19 +0000 Subject: [PATCH 02/18] Remove private doc from invert, create or reuse generic testing methods to avoid duplication of code in the tests. 
(#3106) --- test/test_functional_tensor.py | 21 ++++------ test/test_transforms.py | 45 ++++++++++----------- torchvision/transforms/functional_pil.py | 13 ------ torchvision/transforms/functional_tensor.py | 13 ------ 4 files changed, 30 insertions(+), 62 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 0c4997f1499..9a90d2a4ddf 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -863,19 +863,14 @@ def test_gaussian_blur(self): ) def test_invert(self): - script_invert = torch.jit.script(F.invert) - - img_tensor, pil_img = self._create_data(16, 18, device=self.device) - inverted_img = F.invert(img_tensor) - inverted_pil_img = F.invert(pil_img) - self.compareTensorToPIL(inverted_img, inverted_pil_img) - - # scriptable function test - inverted_img_script = script_invert(img_tensor) - self.assertTrue(inverted_img.equal(inverted_img_script)) - - batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) - self._test_fn_on_batch(batch_tensors, F.invert) + self._test_adjust_fn( + F.invert, + F_pil.invert, + F_t.invert, + [{}], + tol=1.0, + agg_method="max" + ) @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") diff --git a/test/test_transforms.py b/test/test_transforms.py index d6b8f48959c..7ef7bfdb88d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1749,37 +1749,36 @@ def test_gaussian_blur_asserts(self): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): transforms.GaussianBlur(3, "sigma_string") - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_invert(self): + def _test_randomness(self, fn, trans, configs): random_state = random.getstate() random.seed(42) img = transforms.ToPILImage()(torch.rand(3, 10, 10)) - inv_img = F.invert(img) - num_samples = 250 - num_inverts = 0 - for _ in range(num_samples): - out = transforms.RandomInvert()(img) - if out == inv_img: - num_inverts += 1 + for p in [0.5, 0.7]: + for config in configs: + inv_img = fn(img, **config) - p_value = stats.binom_test(num_inverts, num_samples, p=0.5) - random.setstate(random_state) - self.assertGreater(p_value, 0.0001) + num_samples = 250 + counts = 0 + for _ in range(num_samples): + out = trans(p=p, **config)(img) + if out == inv_img: + counts += 1 - num_samples = 250 - num_inverts = 0 - for _ in range(num_samples): - out = transforms.RandomInvert(p=0.7)(img) - if out == inv_img: - num_inverts += 1 + p_value = stats.binom_test(counts, num_samples, p=p) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) - p_value = stats.binom_test(num_inverts, num_samples, p=0.7) - random.setstate(random_state) - self.assertGreater(p_value, 0.0001) + # Checking if it can be printed as string + trans().__repr__() - # Checking if RandomInvert can be printed as string - transforms.RandomInvert().__repr__() + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_invert(self): + self._test_randomness( + F.invert, + transforms.RandomInvert, + [{}] + ) if __name__ == '__main__': diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 17c67355535..1dd40191cfa 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -610,19 +610,6 @@ def to_grayscale(img, num_output_channels): @torch.jit.unused def invert(img): - """PRIVATE METHOD. Invert the colors of an image. - - .. 
warning::
-
-        Module ``transforms.functional_pil`` is private and should not be used in user application.
-        Please, consider instead using methods from `transforms.functional` module.
-
-    Args:
-        img (PIL Image): Image to have its colors inverted.
-
-    Returns:
-        PIL Image: Color inverted image.
-    """
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.invert(img)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index ce899efbabf..e40851aa4f1 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1182,19 +1182,6 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te


 def invert(img: Tensor) -> Tensor:
-    """PRIVATE METHOD. Invert the colors of a grayscale or RGB image.
-
-    .. warning::
-
-        Module ``transforms.functional_tensor`` is private and should not be used in user application.
-        Please, consider instead using methods from `transforms.functional` module.
-
-    Args:
-        img (Tensor): Image to have its colors inverted in the form [C, H, W].
-
-    Returns:
-        Tensor: Color inverted image.
-    """
     if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

From 4b800b922433cfaf85f76f8c521afb55a58503cf Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Dec 2020 17:47:31 +0000
Subject: [PATCH 03/18] Create posterize transformation and refactor common methods to assist reuse. (#3108)

---
 test/test_functional_tensor.py              | 16 ++++++-
 test/test_transforms.py                     | 15 +++++--
 test/test_transforms_tensor.py              |  6 +++
 torchvision/transforms/functional.py        | 21 +++++++++
 torchvision/transforms/functional_pil.py    |  7 +++
 torchvision/transforms/functional_tensor.py | 14 ++++++
 torchvision/transforms/transforms.py        | 48 +++++++++++++++++++--
 7 files changed, 118 insertions(+), 9 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 9a90d2a4ddf..4df930d4517 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -289,13 +289,14 @@ def test_pad(self):

             self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

-    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
+    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
+                        dts=(None, torch.float32, torch.float64)):
         script_fn = torch.jit.script(fn)
         torch.manual_seed(15)
         tensor, pil_img = self._create_data(26, 34, device=self.device)
         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

-        for dt in [None, torch.float32, torch.float64]:
+        for dt in dts:

             if dt is not None:
                 tensor = F.convert_image_dtype(tensor, dt)
@@ -872,6 +873,17 @@ def test_invert(self):
             agg_method="max"
         )

+    def test_posterize(self):
+        self._test_adjust_fn(
+            F.posterize,
+            F_pil.posterize,
+            F_t.posterize,
+            [{"bits": bits} for bits in range(0, 8)],
+            tol=1.0,
+            agg_method="max",
+            dts=(None,)
+        )
+

 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 7ef7bfdb88d..81757510302 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1761,7 +1761,9 @@ def _test_randomness(self, fn, trans, configs):
                 num_samples = 250
                 counts = 0
                 for _ in range(num_samples):
-                    out = trans(p=p, **config)(img)
+                    transformation = trans(p=p, **config)
+                    transformation.__repr__()
+                    out = transformation(img)
                     if out == inv_img:
                         counts += 1

@@ -1769,9 +1771,6 @@ def _test_randomness(self, fn, trans, configs):
                 random.setstate(random_state)
                 self.assertGreater(p_value, 0.0001)

-        # Checking if it can be printed as string
-        trans().__repr__()
-
     @unittest.skipIf(stats is None, 'scipy.stats not available')
     def test_random_invert(self):
         self._test_randomness(
@@ -1780,6 +1779,14 @@ def test_random_invert(self):
             [{}]
         )

+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_posterize(self):
+        self._test_randomness(
+            F.posterize,
+            transforms.RandomPosterize,
+            [{"bits": 4}]
+        )
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 1b9eadaaff0..eba782a75cb 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -92,6 +92,12 @@ def test_random_vertical_flip(self):
     def test_random_invert(self):
         self._test_op('invert', 'RandomInvert')

+    def test_random_posterize(self):
+        fn_kwargs = meth_kwargs = {"bits": 4}
+        self._test_op(
+            'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+
     def test_color_jitter(self):
         tol = 1.0 + 1e-10

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index b64d00138dd..dc0c2a2f2bd 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1196,3 +1196,24 @@ def invert(img: Tensor) -> Tensor:
         return F_pil.invert(img)

     return F_t.invert(img)
+
+
+def posterize(img: Tensor, bits: int) -> Tensor:
+    """Posterize a PIL Image or torch Tensor by reducing the number of bits for each color channel.
+
+    Args:
+        img (PIL Image or Tensor): Image to be posterized.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+        bits (int): The number of bits to keep for each channel (0-8).
+    Returns:
+        PIL Image: Posterized image.
+    """
+    if not (0 <= bits <= 8):
+        raise ValueError('The number of bits should be between 0 and 8. Got {}'.format(bits))
+
+    if not isinstance(img, torch.Tensor):
+        return F_pil.posterize(img, bits)
+
+    return F_t.posterize(img, bits)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 1dd40191cfa..2e1b16f26b6 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -613,3 +613,10 @@ def invert(img):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.invert(img)
+
+
+@torch.jit.unused
+def posterize(img, bits):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.posterize(img, bits)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index e40851aa4f1..003f0138b0c 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1193,3 +1193,17 @@ def invert(img: Tensor) -> Tensor:
     bound = 1.0 if img.is_floating_point() else 255.0
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
     return (bound - img.to(dtype)).to(img.dtype)
+
+
+def posterize(img: Tensor, bits: int) -> Tensor:
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if img.ndim < 3:
+        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+    if img.dtype != torch.uint8:
+        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
+
+    _assert_channels(img, [1, 3])
+    mask = -int(2**(8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
+    return img & mask
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 9295004e4a6..fbe7a23fc61 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -21,7 +21,7 @@
     "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
     "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
     "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
-    "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"]
+    "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"]


 class Compose:
@@ -1717,10 +1717,10 @@ def __init__(self, p=0.5):

     @staticmethod
     def get_params() -> float:
-        """Choose value for random color inversion.
+        """Choose a value for the random transformation.

         Returns:
-            float: Random value which is used to determine whether the random color inversion
+            float: Random value which is used to determine whether the random transformation
                 should occur.
         """
         return torch.rand(1).item()
@@ -1739,3 +1739,45 @@ def forward(self, img):

     def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomPosterize(torch.nn.Module):
+    """Posterize the image randomly with a given probability by reducing the
+    number of bits for each color channel. The image can be a PIL Image or a torch
+    Tensor, in which case it is expected to have [..., H, W] shape, where ... means
+    an arbitrary number of leading dimensions
+
+    Args:
+        bits (int): number of bits to keep for each channel (0-8)
+        p (float): probability of the image being posterized. Default value is 0.5
+    """
+
+    def __init__(self, bits, p=0.5):
+        super().__init__()
+        self.bits = bits
+        self.p = p
+
+    @staticmethod
+    def get_params() -> float:
+        """Choose a value for the random transformation.
+
+        Returns:
+            float: Random value which is used to determine whether the random transformation
+                should occur.
+        """
+        return torch.rand(1).item()
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be posterized.
+
+        Returns:
+            PIL Image or Tensor: Randomly posterized image.
+ """ + if self.get_params() < self.p: + return F.posterize(img, self.bits) + return img + + def __repr__(self): + return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) From 63b8a273a326d12362bfe877f8d99a0a8b2d0f16 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 3 Dec 2020 23:23:13 +0000 Subject: [PATCH 04/18] Implement the solarize transform. (#3112) --- test/test_functional_tensor.py | 20 +++++++++ test/test_transforms.py | 8 ++++ test/test_transforms_tensor.py | 6 +++ torchvision/transforms/functional.py | 24 ++++++++-- torchvision/transforms/functional_pil.py | 7 +++ torchvision/transforms/functional_tensor.py | 21 ++++++++- torchvision/transforms/transforms.py | 49 +++++++++++++++++++-- 7 files changed, 128 insertions(+), 7 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 4df930d4517..63e8271a858 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -884,6 +884,26 @@ def test_posterize(self): dts=(None,) ) + def test_solarize(self): + self._test_adjust_fn( + F.solarize, + F_pil.solarize, + F_t.solarize, + [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]], + tol=1.0, + agg_method="max", + dts=(None,) + ) + self._test_adjust_fn( + F.solarize, + lambda img, threshold: F_pil.solarize(img, 255 * threshold), + F_t.solarize, + [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]], + tol=1.0, + agg_method="max", + dts=(torch.float32, torch.float64) + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 81757510302..fc52fc66686 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1787,6 +1787,14 @@ def test_random_posterize(self): [{"bits": 4}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_solarize(self): + self._test_randomness( + F.solarize, + transforms.RandomSolarize, + [{"threshold": 192}] + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index eba782a75cb..331f8a2eb4f 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -98,6 +98,12 @@ def test_random_posterize(self): 'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_solarize(self): + fn_kwargs = meth_kwargs = {"threshold": 192.0} + self._test_op( + 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index dc0c2a2f2bd..e3b0a9bd98a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1203,9 +1203,9 @@ def posterize(img: Tensor, bits: int) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. - If img is a Tensor, it is expected to be in [..., H, W] format, - where ... means it can have an arbitrary number of trailing - dimensions. + If img is a Tensor, it should be of type torch.uint8 and + it is expected to be in [..., H, W] format, where ... means + it can have an arbitrary number of trailing dimensions. bits (int): The number of bits to keep for each channel (0-8). Returns: PIL Image: Posterized image. 
@@ -1217,3 +1217,21 @@ def posterize(img: Tensor, bits: int) -> Tensor:
         return F_pil.posterize(img, bits)

     return F_t.posterize(img, bits)
+
+
+def solarize(img: Tensor, threshold: float) -> Tensor:
+    """Solarize a PIL Image or torch Tensor by inverting all pixel values above a threshold.
+
+    Args:
+        img (PIL Image or Tensor): Image to be solarized.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+        threshold (float): All pixels equal or above this value are inverted.
+    Returns:
+        PIL Image: Solarized image.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.solarize(img, threshold)
+
+    return F_t.solarize(img, threshold)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 2e1b16f26b6..d60588fd138 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -620,3 +620,10 @@ def posterize(img, bits):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.posterize(img, bits)
+
+
+@torch.jit.unused
+def solarize(img, threshold):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.solarize(img, threshold)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 003f0138b0c..5eb70988f90 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1192,7 +1192,7 @@ def invert(img: Tensor) -> Tensor:

     bound = 1.0 if img.is_floating_point() else 255.0
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    return (bound - img.to(dtype)).to(img.dtype)
+    return (bound - img.to(dtype)).clamp(0, bound).to(img.dtype)


 def posterize(img: Tensor, bits: int) -> Tensor:
@@ -1207,3 +1207,22 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     _assert_channels(img, [1, 3])
     mask = -int(2**(8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
     return img & mask
+
+
+def solarize(img: Tensor, threshold: float) -> Tensor:
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if img.ndim < 3:
+        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+
+    _assert_channels(img, [1, 3])
+
+    bound = 1.0 if img.is_floating_point() else 255.0
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+
+    result = img.clone().view(-1)
+    invert_idx = torch.where(result >= threshold)[0]
+    result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype)
+
+    return result.view(img.shape)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index fbe7a23fc61..66ccb42e525 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -21,7 +21,8 @@
     "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
     "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
     "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
-    "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"]
+    "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
+    "RandomSolarize"]
 class Compose:
@@ -1705,7 +1706,7 @@ class RandomInvert(torch.nn.Module):
     """Inverts the colors of the given image randomly with a given probability.
     The image can be a PIL Image or a torch Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading
-    dimensions
+    dimensions.

     Args:
         p (float): probability of the image being color inverted. Default value is 0.5
@@ -1745,7 +1746,7 @@ class RandomPosterize(torch.nn.Module):
     """Posterize the image randomly with a given probability by reducing the
     number of bits for each color channel. The image can be a PIL Image or a torch
     Tensor, in which case it is expected to have [..., H, W] shape, where ... means
-    an arbitrary number of leading dimensions
+    an arbitrary number of leading dimensions.

     Args:
         bits (int): number of bits to keep for each channel (0-8)
         p (float): probability of the image being posterized. Default value is 0.5
@@ -1782,3 +1783,45 @@ def forward(self, img):

     def __repr__(self):
         return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)
+
+
+class RandomSolarize(torch.nn.Module):
+    """Solarize the image randomly with a given probability by inverting all pixel
+    values above a threshold. The image can be a PIL Image or a torch Tensor, in
+    which case it is expected to have [..., H, W] shape, where ... means an arbitrary
+    number of leading dimensions.
+
+    Args:
+        threshold (float): all pixels equal or above this value are inverted.
+        p (float): probability of the image being solarized. Default value is 0.5
+    """
+
+    def __init__(self, threshold, p=0.5):
+        super().__init__()
+        self.threshold = threshold
+        self.p = p
+
+    @staticmethod
+    def get_params() -> float:
+        """Choose a value for the random transformation.
+
+        Returns:
+            float: Random value which is used to determine whether the random transformation
+                should occur.
+        """
+        return torch.rand(1).item()
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be solarized.
+
+        Returns:
+            PIL Image or Tensor: Randomly solarized image.
+        """
+        if self.get_params() < self.p:
+            return F.solarize(img, self.threshold)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)
From b4e9a2fedc4f939ff71f4359ea33855b3fe319ab Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 4 Dec 2020 02:32:11 +0000
Subject: [PATCH 05/18] Implement the adjust_sharpness transform (#3114)

* Adding functional operator for sharpness.

* Adding transforms for sharpness.

* Handling tiny images and adding a test.
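A minimal usage sketch of the new operator (the uint8 input below is illustrative, not part of the patch):

    import torch
    from torchvision.transforms import functional as F

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # hypothetical input
    blurred = F.adjust_sharpness(img, 0.0)    # factor 0 yields the smoothed copy
    unchanged = F.adjust_sharpness(img, 1.0)  # factor 1 returns the original image
    sharper = F.adjust_sharpness(img, 2.0)    # factor 2 doubles the sharpness

The tensor path blends the input with a center-weighted 3x3 smoothed copy of itself, so factors between 0 and 1 interpolate towards the blur and factors above 1 extrapolate away from it; images whose height or width is 2 pixels or fewer are returned unchanged.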
--- test/test_functional_tensor.py | 8 ++++ test/test_transforms.py | 45 ++++++++++++++++++++- test/test_transforms_tensor.py | 8 +++- torchvision/transforms/functional.py | 18 +++++++++ torchvision/transforms/functional_pil.py | 10 +++++ torchvision/transforms/functional_tensor.py | 36 +++++++++++++++++ torchvision/transforms/transforms.py | 32 ++++++++++----- 7 files changed, 145 insertions(+), 12 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 63e8271a858..715450b5cfd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -904,6 +904,14 @@ def test_solarize(self): dts=(torch.float32, torch.float64) ) + def test_adjust_sharpness(self): + self._test_adjust_fn( + F.adjust_sharpness, + F_pil.adjust_sharpness, + F_t.adjust_sharpness, + [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]] + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index fc52fc66686..58ffa93f6e2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1232,6 +1232,48 @@ def test_adjust_hue(self): y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) + def test_adjust_sharpness(self): + x_shape = [4, 4, 3] + x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, + 111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = F.adjust_sharpness(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = F.adjust_sharpness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, + 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, + 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, + 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 3 + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_th = torch.tensor(x_np.transpose(2, 0, 1)) + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil).transpose(2, 0, 1) + y_th = F.adjust_sharpness(x_th, 2) + self.assertTrue(np.allclose(y_np, y_th.numpy())) + def test_adjust_gamma(self): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1268,10 +1310,11 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L') self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L') self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L') + self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L') self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): - color_jitter = 
transforms.ColorJitter(2, 2, 2, 0.1)
+        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)

         x_shape = [2, 2, 3]
         x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 331f8a2eb4f..30c5b885bb8 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -131,8 +131,14 @@ def test_color_jitter(self):
             "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
         )

+        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
+            meth_kwargs = {"sharpness": f}
+            self._test_class_op(
+                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+            )
+
         # All 4 parameters together
-        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
+        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
         self._test_class_op(
             "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
         )
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index e3b0a9bd98a..8383b08364b 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1235,3 +1235,21 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
         return F_pil.solarize(img, threshold)

     return F_t.solarize(img, threshold)
+
+
+def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
+    """Adjust the sharpness of an Image.
+
+    Args:
+        img (PIL Image or Tensor): Image to be adjusted.
+        sharpness_factor (float): How much to adjust the sharpness. Can be
+            any non-negative number. 0 gives a blurred image, 1 gives the
+            original image, while 2 increases the sharpness by a factor of 2.
+
+    Returns:
+        PIL Image or Tensor: Sharpness adjusted image.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_sharpness(img, sharpness_factor)
+
+    return F_t.adjust_sharpness(img, sharpness_factor)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index d60588fd138..72eafe37a2c 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -627,3 +627,13 @@ def solarize(img, threshold):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.solarize(img, threshold)
+
+
+@torch.jit.unused
+def adjust_sharpness(img, sharpness_factor):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    enhancer = ImageEnhance.Sharpness(img)
+    img = enhancer.enhance(sharpness_factor)
+    return img
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 5eb70988f90..70e1cdc7833 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1226,3 +1226,39 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype)

     return result.view(img.shape)
+
+
+def _blur_image(img: Tensor) -> Tensor:
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+
+    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
+    kernel[1, 1] = 5.0
+    kernel /= kernel.sum()
+    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
+
+    result, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
+    result = conv2d(result, kernel, groups=result.shape[-3])
+    result = torch_pad(result, [1, 1, 1, 1])
+    result = _cast_squeeze_out(result, need_cast, need_squeeze, out_dtype)
+
+    result[..., 0, :] = img[..., 0, :]
+    result[..., -1, :] = img[..., -1, :]
+    result[..., :, 0] = img[..., :, 0]
+    result[..., :, -1] = img[..., :, -1]
+
+    return result
+
+
+def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
+    if sharpness_factor < 0:
+        raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    _assert_channels(img, [1, 3])
+
+    if img.size(-1) <= 2 or img.size(-2) <= 2:
+        return img
+
+    return _blend(img, _blur_image(img), sharpness_factor)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 66ccb42e525..2a854933ebd 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -1039,7 +1039,7 @@ def __repr__(self):


 class ColorJitter(torch.nn.Module):
-    """Randomly change the brightness, contrast and saturation of an image.
+    """Randomly change the brightness, contrast, saturation, hue and sharpness of an image.

     Args:
         brightness (float or tuple of float (min, max)): How much to jitter brightness.
@@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module):
         hue (float or tuple of float (min, max)): How much to jitter hue.
             hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+        sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
+            sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
+            or the given [min, max]. Should be non-negative numbers.
""" - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0): super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.sharpness = self._check_input(sharpness, 'sharpness') @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs else: raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) - # if value is 0 or (1., 1.) for brightness/contrast/saturation + # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None @@ -1088,8 +1092,10 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs def get_params(brightness: Optional[List[float]], contrast: Optional[List[float]], saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + hue: Optional[List[float]], + sharpness: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float], + Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]], uniformly. Pass None to turn off the transformation. hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation. + sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen + uniformly. Pass None to turn off the transformation. Returns: tuple: The parameters used to apply the randomized transform along with their random order. """ - fn_idx = torch.randperm(4) + fn_idx = torch.randperm(5) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1])) - return fn_idx, b, c, s, h + return fn_idx, b, c, s, h, sp def forward(self, img): """ @@ -1123,8 +1132,8 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. 
""" - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1135,6 +1144,8 @@ def forward(self, img): img = F.adjust_saturation(img, saturation_factor) elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) + elif fn_id == 4 and sharpness_factor is not None: + img = F.adjust_sharpness(img, sharpness_factor) return img @@ -1143,7 +1154,8 @@ def __repr__(self): format_string += 'brightness={0}'.format(self.brightness) format_string += ', contrast={0}'.format(self.contrast) format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) + format_string += ', hue={0}'.format(self.hue) + format_string += ', sharpness={0})'.format(self.sharpness) return format_string From 94fc57326c1d2fafba0dffc222e9d28604624b67 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Dec 2020 13:47:40 +0000 Subject: [PATCH 06/18] Implement the autocontrast transform. (#3117) --- test/test_functional_tensor.py | 10 +++++ test/test_transforms.py | 8 ++++ test/test_transforms_tensor.py | 3 ++ torchvision/transforms/functional.py | 26 +++++++++++-- torchvision/transforms/functional_pil.py | 7 ++++ torchvision/transforms/functional_tensor.py | 22 +++++++++++ torchvision/transforms/transforms.py | 42 ++++++++++++++++++++- 7 files changed, 114 insertions(+), 4 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 715450b5cfd..fa6297fa5ef 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -912,6 +912,16 @@ def test_adjust_sharpness(self): [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]] ) + def test_autocontrast(self): + self._test_adjust_fn( + F.autocontrast, + F_pil.autocontrast, + F_t.autocontrast, + [{}], + tol=1.0, + agg_method="max" + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 58ffa93f6e2..81104c10f21 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1838,6 +1838,14 @@ def test_random_solarize(self): [{"threshold": 192}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_autocontrast(self): + self._test_randomness( + F.autocontrast, + transforms.RandomAutocontrast, + [{}] + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 30c5b885bb8..2c36664a517 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -104,6 +104,9 @@ def test_random_solarize(self): 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_autocontrast(self): + self._test_op('autocontrast', 'RandomAutocontrast') + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8383b08364b..d401aa4cc90 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1190,7 +1190,7 @@ def invert(img: Tensor) -> Tensor: dimensions. Returns: - PIL Image: Color inverted image. 
+        PIL Image or Tensor: Color inverted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.invert(img)
@@ -1208,7 +1208,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
             it can have an arbitrary number of leading dimensions.
         bits (int): The number of bits to keep for each channel (0-8).
     Returns:
-        PIL Image: Posterized image.
+        PIL Image or Tensor: Posterized image.
@@ -1229,7 +1229,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
             dimensions.
         threshold (float): All pixels equal or above this value are inverted.
     Returns:
-        PIL Image: Solarized image.
+        PIL Image or Tensor: Solarized image.
@@ -1253,3 +1253,23 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
         return F_pil.adjust_sharpness(img, sharpness_factor)

     return F_t.adjust_sharpness(img, sharpness_factor)
+
+
+def autocontrast(img: Tensor) -> Tensor:
+    """Maximize contrast of a PIL Image or torch Tensor by remapping its
+    pixels per channel so that the lowest becomes black and the lightest
+    becomes white.
+
+    Args:
+        img (PIL Image or Tensor): Image on which autocontrast is applied.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+
+    Returns:
+        PIL Image or Tensor: An image that was autocontrasted.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.autocontrast(img)
+
+    return F_t.autocontrast(img)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 72eafe37a2c..14f91713aca 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -637,3 +637,10 @@ def adjust_sharpness(img, sharpness_factor):
     enhancer = ImageEnhance.Sharpness(img)
     img = enhancer.enhance(sharpness_factor)
     return img
+
+
+@torch.jit.unused
+def autocontrast(img):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + return ImageOps.autocontrast(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 70e1cdc7833..8a0432fa456 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1262,3 +1262,25 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: return img return _blend(img, _blur_image(img), sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + minimum = img.amin(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) + maximum = img.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) + eq_idxs = torch.where(minimum == maximum)[0] + minimum[eq_idxs] = 0 + maximum[eq_idxs] = bound + scale = bound / (maximum - minimum) + + return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2a854933ebd..963806b9962 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize"] + "RandomSolarize", "RandomAutocontrast"] class Compose: @@ -1836,3 +1836,43 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + + +class RandomAutocontrast(torch.nn.Module): + """Autocontrast the pixels of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be autocontrasted. + + Returns: + PIL Image or Tensor: Randomly autocontrasted image. + """ + if self.get_params() < self.p: + return F.autocontrast(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) From 64a3e1bad646bcd8c5cb807486b1bf87a8c1f81f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Dec 2020 20:32:43 +0000 Subject: [PATCH 07/18] Implement the equalize transform (#3119) * Implement the equalize transform. * Turn off deterministic for histogram. 
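As a quick sanity check, a minimal sketch of the new operator (the random uint8 input is illustrative, not part of the patch):

    import torch
    from torchvision.transforms import functional as F

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # tensor path accepts torch.uint8 only
    out = F.equalize(img)  # spreads each channel's histogram towards a uniform distribution

The tensor implementation builds a 256-bin histogram per channel and remaps pixels through the resulting lookup table, mirroring the PIL path (ImageOps.equalize); the histogram op lacks a deterministic implementation, which is why the tests call torch.set_deterministic(False).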
---
 test/test_functional_tensor.py              | 10 +++++
 test/test_transforms.py                     | 10 ++++-
 test/test_transforms_tensor.py              |  4 ++
 torchvision/transforms/functional.py        | 20 ++++++++
 torchvision/transforms/functional_pil.py    |  7 ++++
 torchvision/transforms/functional_tensor.py | 38 +++++++++++++++++++
 torchvision/transforms/transforms.py        | 42 ++++++++++++++++++++-
 7 files changed, 129 insertions(+), 2 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index fa6297fa5ef..0e039be6041 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -922,6 +922,16 @@ def test_autocontrast(self):
             agg_method="max"
         )

+    def test_equalize(self):
+        torch.set_deterministic(False)
+        self._test_adjust_fn(
+            F.equalize,
+            F_pil.equalize,
+            F_t.equalize,
+            [{}],
+            dts=(None,)
+        )
+

 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 81104c10f21..5defce28588 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1795,7 +1795,7 @@ def test_gaussian_blur_asserts(self):
     def _test_randomness(self, fn, trans, configs):
         random_state = random.getstate()
         random.seed(42)
-        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
+        img = transforms.ToPILImage()(torch.rand(3, 16, 18))

         for p in [0.5, 0.7]:
             for config in configs:
@@ -1846,6 +1846,14 @@ def test_random_autocontrast(self):
             [{}]
         )

+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_equalize(self):
+        self._test_randomness(
+            F.equalize,
+            transforms.RandomEqualize,
+            [{}]
+        )
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 2c36664a517..7af1f1d4c46 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -107,6 +107,10 @@ def test_random_solarize(self):
     def test_random_autocontrast(self):
         self._test_op('autocontrast', 'RandomAutocontrast')

+    def test_random_equalize(self):
+        torch.set_deterministic(False)
+        self._test_op('equalize', 'RandomEqualize')
+
     def test_color_jitter(self):
         tol = 1.0 + 1e-10

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index d401aa4cc90..948638f3dd8 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1273,3 +1273,23 @@ def autocontrast(img: Tensor) -> Tensor:
         return F_pil.autocontrast(img)

     return F_t.autocontrast(img)
+
+
+def equalize(img: Tensor) -> Tensor:
+    """Equalize the histogram of a PIL Image or torch Tensor by applying
+    a non-linear mapping to the input in order to create a uniform
+    distribution of grayscale values in the output.
+
+    Args:
+        img (PIL Image or Tensor): Image on which equalize is applied.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+
+    Returns:
+        PIL Image or Tensor: An image that was equalized.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.equalize(img)
+
+    return F_t.equalize(img)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 14f91713aca..26f3b504d99 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -644,3 +644,10 @@ def autocontrast(img):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) return ImageOps.autocontrast(img) + + +@torch.jit.unused +def equalize(img): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.equalize(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 8a0432fa456..47437385828 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1284,3 +1284,41 @@ def autocontrast(img: Tensor) -> Tensor: scale = bound / (maximum - minimum) return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype) + + +def _scale_channel(img_chan): + hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) + + nonzero_hist = hist[hist != 0] + if nonzero_hist.numel() > 0: + step = (nonzero_hist.sum() - nonzero_hist[-1]) // 255 + else: + step = torch.tensor(0, device=img_chan.device) + if step == 0: + return img_chan + + lut = (torch.cumsum(hist, 0) + (step // 2)) // step + lut = torch.cat([torch.zeros(1, device=img_chan.device), lut[:-1]]).clamp(0, 255) + + return lut[img_chan.to(torch.int64)].to(torch.uint8) + + +def _equalize_single_image(img: Tensor) -> Tensor: + return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))]) + + +def equalize(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if not (3 <= img.ndim <= 4): + raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim)) + if img.dtype != torch.uint8: + raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) + + _assert_channels(img, [1, 3]) + + if img.ndim == 3: + return _equalize_single_image(img) + + return torch.stack([_equalize_single_image(x) for x in img]) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 963806b9962..f4416b36acd 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAutocontrast"] + "RandomSolarize", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1876,3 +1876,43 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomEqualize(torch.nn.Module): + """Equalize the histogram of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being equalized. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be equalized. + + Returns: + PIL Image or Tensor: Randomly equalized image. 
+ """ + if self.get_params() < self.p: + return F.equalize(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) From 05cf5674ac43541334a7274b9cf4540fd90be20f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 5 Dec 2020 21:07:40 +0000 Subject: [PATCH 08/18] Fixing test. (#3126) --- test/test_functional_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 0e039be6041..a651c0e9f38 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -929,6 +929,8 @@ def test_equalize(self): F_pil.equalize, F_t.equalize, [{}], + tol=1.0, + agg_method="max", dts=(None,) ) From c7337b9a1ac10030001c5d2365cb1f77cb2058aa Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 6 Dec 2020 18:09:32 +0000 Subject: [PATCH 09/18] Force ratio to be float to avoid numeric overflows on blend. (#3127) --- torchvision/transforms/functional_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index deceb7314db..a123b4f5694 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -570,6 +570,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: + ratio = float(ratio) bound = 1.0 if img1.is_floating_point() else 255.0 return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) From ff4bfbbcc3cb34ea3f59a4cdf7f4df4d218cbc2e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 6 Dec 2020 19:06:47 +0000 Subject: [PATCH 10/18] Separate the tests of Adjust Sharpness from ColorJitter. 
(#3128) --- test/test_transforms.py | 10 +++- test/test_transforms_tensor.py | 14 ++--- torchvision/transforms/transforms.py | 77 +++++++++++++++++++--------- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fbfab98be45..f2ca0dc9d1e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1316,7 +1316,7 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): - color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2) + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1840,6 +1840,14 @@ def test_random_solarize(self): [{"threshold": 192}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_adjust_sharpness(self): + self._test_randomness( + F.adjust_sharpness, + transforms.RandomAdjustSharpness, + [{"sharpness_factor": 2.0}] + ) + @unittest.skipIf(stats is None, 'scipy.stats not available') def test_random_autocontrast(self): self._test_randomness( diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7af1f1d4c46..ea3f818ad45 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -104,6 +104,12 @@ def test_random_solarize(self): 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_adjust_sharpness(self): + fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0} + self._test_op( + 'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + def test_random_autocontrast(self): self._test_op('autocontrast', 'RandomAutocontrast') @@ -138,14 +144,8 @@ def test_color_jitter(self): "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max" ) - for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]: - meth_kwargs = {"sharpness": f} - self._test_class_op( - "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" - ) - # All 4 parameters together - meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2} + meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} self._test_class_op( "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max" ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f4416b36acd..c6322ef71d4 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAutocontrast", "RandomEqualize"] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1039,7 +1039,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast, saturation, hue and sharpness of an image. + """Randomly change the brightness, contrast, saturation and hue of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. 
@@ -1054,19 +1054,15 @@ class ColorJitter(torch.nn.Module): hue (float or tuple of float (min, max)): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. - sharpness (float or tuple of float (min, max)): How much to jitter sharpness. - sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness] - or the given [min, max]. Should be non negative numbers. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - self.sharpness = self._check_input(sharpness, 'sharpness') @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1082,7 +1078,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs else: raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) - # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness + # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None @@ -1092,10 +1088,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs def get_params(brightness: Optional[List[float]], contrast: Optional[List[float]], saturation: Optional[List[float]], - hue: Optional[List[float]], - sharpness: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float], - Optional[float]]: + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1107,22 +1101,19 @@ def get_params(brightness: Optional[List[float]], uniformly. Pass None to turn off the transformation. hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation. - sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen - uniformly. Pass None to turn off the transformation. Returns: tuple: The parameters used to apply the randomized transform along with their random order. """ - fn_idx = torch.randperm(5) + fn_idx = torch.randperm(4) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1])) - return fn_idx, b, c, s, h, sp + return fn_idx, b, c, s, h def forward(self, img): """ @@ -1132,8 +1123,8 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image.
""" - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1144,8 +1135,6 @@ def forward(self, img): img = F.adjust_saturation(img, saturation_factor) elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) - elif fn_id == 4 and sharpness_factor is not None: - img = F.adjust_sharpness(img, sharpness_factor) return img @@ -1154,8 +1143,7 @@ def __repr__(self): format_string += 'brightness={0}'.format(self.brightness) format_string += ', contrast={0}'.format(self.contrast) format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0}'.format(self.hue) - format_string += ', sharpness={0})'.format(self.sharpness) + format_string += ', hue={0})'.format(self.hue) return format_string @@ -1838,6 +1826,49 @@ def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. The image + can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] + shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + self.sharpness_factor = sharpness_factor + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if self.get_params() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor) + return img + + def __repr__(self): + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + + class RandomAutocontrast(torch.nn.Module): """Autocontrast the pixels of the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected From 70f204289b364f61951f266d59e7320839b89d01 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 16:33:26 +0000 Subject: [PATCH 11/18] Add AutoAugment Policies and main Transform (#3142) * Separate the tests of Adjust Sharpness from ColorJitter. * Initial implementation, not-jitable. * AutoAugment passing JIT. * Adding tests/docs, changing formatting. * Update test. * Fix formats * Fix documentation and imports. 
--- test/test_transforms.py | 10 ++ test/test_transforms_tensor.py | 16 ++ torchvision/transforms/__init__.py | 1 + torchvision/transforms/autoaugment.py | 232 ++++++++++++++++++++++++++ 4 files changed, 259 insertions(+) create mode 100644 torchvision/transforms/autoaugment.py diff --git a/test/test_transforms.py b/test/test_transforms.py index f2ca0dc9d1e..b3c82334d14 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1864,6 +1864,16 @@ def test_random_equalize(self): [{}] ) + def test_autoaugment(self): + for policy in transforms.AutoAugmentPolicy: + for fill in [None, 85, (128, 128, 128)]: + random.seed(42) + img = Image.open(GRACE_HOPPER) + transform = transforms.AutoAugment(policy=policy, fill=fill) + for _ in range(100): + img = transform(img) + transform.__repr__() + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index d1602ce878f..22a7c065122 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -626,6 +626,22 @@ def test_convert_image_dtype(self): with get_tmp_dir() as tmp_dir: scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) + def test_autoaugment(self): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) + + for policy in T.AutoAugmentPolicy: + for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]: + for _ in range(100): + transform = T.AutoAugment(policy=policy, fill=fill) + s_transform = torch.jit.script(transform) + + self._test_transform_vs_scripted(transform, s_transform, tensor) + self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + with get_tmp_dir() as tmp_dir: + s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 7986cdd6429..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1 +1,2 @@ from .transforms import * +from .autoaugment import * diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py new file mode 100644 index 00000000000..4cdf219c22c --- /dev/null +++ b/torchvision/transforms/autoaugment.py @@ -0,0 +1,232 @@ +import math +import torch + +from enum import Enum +from torch import Tensor +from torch.jit.annotations import List, Tuple +from typing import Optional + +from . import functional as F + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +class AutoAugment(torch.nn.Module): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" `_. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed + image. If int or float, the value is used for all bands respectively. 
+ This option is supported for PIL image and Tensor inputs. + If input is PIL Image, this option is only available for ``Pillow>=5.0.0``. + + Example: + >>> t = transforms.AutoAugment() + >>> transformed = t(image) + + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> transforms.AutoAugment(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None): + super().__init__() + self.policy = policy + self.fill = fill + if policy == AutoAugmentPolicy.IMAGENET: + self.policies = [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + self.policies = [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + self.policies = [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
+ (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError("The provided policy {} is not recognized.".format(policy)) + + _BINS = 10 + self._op_meta = { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), + "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), + "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), + "Color": (torch.linspace(0.0, 0.9, _BINS), True), + "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), + "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), + "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), + "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), + "AutoContrast": (None, None), + "Equalize": (None, None), + "Invert": (None, None), + } + + @staticmethod + def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ + policy_id = torch.randint(policy_num, (1,)).item() + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: + return self._op_meta[name] + + def forward(self, img: Tensor): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. 
+ """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F._get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + policy_id, probs, signs = self.get_params(len(self.policies)) + + for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]): + if probs[i] <= p: + magnitudes, signed = self._get_op_meta(op_name) + magnitude = float(magnitudes[magnitude_id].item()) \ + if magnitudes is not None and magnitude_id is not None else 0.0 + if signed is not None and signed and signs[i] == 0: + magnitude *= -1.0 + + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, + shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, + shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + + return img + + def __repr__(self): + return self.__class__.__name__ + '(policy={},fill={})'.format(self.policy, self.fill) From 1b3d645a2e11a6bcf123d1197f491afd2b6b4a1c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Dec 2020 11:22:42 +0000 Subject: [PATCH 12/18] Apply changes from code review: - Move the transformations outside of AutoAugment on a separate method. - Renamed degenerate method for sharpness for better clarity. 
--- torchvision/transforms/autoaugment.py | 182 ++++++++++---------- torchvision/transforms/functional_tensor.py | 4 +- 2 files changed, 95 insertions(+), 91 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 4cdf219c22c..4c039ba2d68 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -17,6 +17,93 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" +def _get_transforms(policy: AutoAugmentPolicy): + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + 
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + + class AutoAugment(torch.nn.Module): r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. @@ -45,91 +132,8 @@ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: super().__init__() self.policy = policy self.fill = fill - if policy == AutoAugmentPolicy.IMAGENET: - self.policies = [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ] - elif policy == AutoAugmentPolicy.CIFAR10: - self.policies = [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 
0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ] - elif policy == AutoAugmentPolicy.SVHN: - self.policies = [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ] - else: + self.transforms = _get_transforms(policy) + if self.transforms is None: raise ValueError("The provided policy {} is not recognized.".format(policy)) _BINS = 10 @@ -152,13 +156,13 @@ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: } @staticmethod - def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: + def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: """Get parameters for autoaugment transformation Returns: params required by the autoaugment transformation """ - policy_id = torch.randint(policy_num, (1,)).item() + policy_id = torch.randint(transform_num, (1,)).item() probs = torch.rand((2,)) signs = torch.randint(2, (2,)) @@ -181,9 +185,9 @@ def forward(self, img: Tensor): elif fill is not None: fill = [float(f) for f in fill] - policy_id, probs, signs = self.get_params(len(self.policies)) + transform_id, probs, signs = self.get_params(len(self.transforms)) - for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]): + for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: magnitudes, signed = self._get_op_meta(op_name) magnitude = float(magnitudes[magnitude_id].item()) \ diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a123b4f5694..a3a1f592716 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1230,7 +1230,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: return result.view(img.shape) -def _blur_image(img: Tensor) -> Tensor: +def _blurred_degenerate_image(img: Tensor) -> Tensor: dtype = img.dtype if torch.is_floating_point(img) else torch.float32 kernel = torch.ones((3, 3), dtype=dtype, device=img.device) @@ -1263,7 +1263,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: if img.size(-1) <= 2 or img.size(-2) <= 2: return img - return _blend(img, 
_blur_image(img), sharpness_factor) + return _blend(img, _blurred_degenerate_image(img), sharpness_factor) def autocontrast(img: Tensor) -> Tensor: From b62ec5d3bbf949b63ef26df360bab867b8936690 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Dec 2020 13:09:42 +0000 Subject: [PATCH 13/18] Update torchvision/transforms/functional.py Co-authored-by: vfdev --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 1db4f130766..2b90e78a0ac 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1197,7 +1197,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: """Posterize a PIL Image or torch Tensor by reducing the number of bits for each color channel. Args: - img (PIL Image or Tensor): Image to have its colors inverted. + img (PIL Image or Tensor): Image to have its colors posterized. If img is a Tensor, it should be of type torch.uint8 and it is expected to be in [..., H, W] format, where ... means it can have an arbitrary number of trailing dimensions. From 061db1f31b5413975e91e87b491a8bfa78ef7a75 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Dec 2020 13:33:31 +0000 Subject: [PATCH 14/18] Apply more changes from code review: - Add InterpolationMode parameter. - Move all declarations away from AutoAugment constructor and into the private method. --- torchvision/transforms/autoaugment.py | 61 +++++++++++++++------------ 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 4c039ba2d68..c2049e45017 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -6,7 +6,7 @@ from torch.jit.annotations import List, Tuple from typing import Optional -from . import functional as F +from . import functional as F, InterpolationMode class AutoAugmentPolicy(Enum): @@ -104,6 +104,27 @@ def _get_transforms(policy: AutoAugmentPolicy): ] +def _get_magnitudes(): + _BINS = 10 + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), + "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), + "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), + "Color": (torch.linspace(0.0, 0.9, _BINS), True), + "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), + "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), + "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), + "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), + "AutoContrast": (None, None), + "Equalize": (None, None), + "Invert": (None, None), + } + + class AutoAugment(torch.nn.Module): r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. @@ -113,6 +134,9 @@ class AutoAugment(torch.nn.Module): Args: policy (AutoAugmentPolicy): Desired policy enum defined by :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. 
fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed image. If int or float, the value is used for all bands respectively. This option is supported for PIL image and Tensor inputs. If input is PIL Image, this option is only available for ``Pillow>=5.0.0``. @@ -128,32 +152,17 @@ class AutoAugment(torch.nn.Module): >>> transforms.ToTensor()]) """ - def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None): + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): super().__init__() self.policy = policy + self.interpolation = interpolation self.fill = fill + self.transforms = _get_transforms(policy) if self.transforms is None: raise ValueError("The provided policy {} is not recognized.".format(policy)) - - _BINS = 10 - self._op_meta = { - # name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), - "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), - "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), - "Color": (torch.linspace(0.0, 0.9, _BINS), True), - "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), - "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), - "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), - "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), - "AutoContrast": (None, None), - "Equalize": (None, None), - "Invert": (None, None), - } + self._op_meta = _get_magnitudes() @staticmethod def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: @@ -197,18 +206,18 @@ def forward(self, img: Tensor): if op_name == "ShearX": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - fill=fill) + interpolation=self.interpolation, fill=fill) elif op_name == "ShearY": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - fill=fill) + interpolation=self.interpolation, fill=fill) elif op_name == "TranslateX": img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, - shear=[0.0, 0.0], fill=fill) + interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) elif op_name == "TranslateY": img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, - shear=[0.0, 0.0], fill=fill) + interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) elif op_name == "Rotate": - img = F.rotate(img, magnitude, fill=fill) + img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill) elif op_name == "Brightness": img = F.adjust_brightness(img, 1.0 + magnitude) From 6439dbcd0f7b1ba89e61fd3c535d924cdb8560db Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Dec 2020 15:46:30 +0000 Subject: [PATCH 15/18] Update documentation.
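Note: a sketch of the constructor after the InterpolationMode parameter added in patch 14 above; the BILINEAR mode and grey fill below are arbitrary example values, not recommendations.

from torchvision import transforms as T

augment = T.AutoAugment(
    policy=T.AutoAugmentPolicy.IMAGENET,
    interpolation=T.InterpolationMode.BILINEAR,  # NEAREST remains the default
    fill=[128.0, 128.0, 128.0],                  # forwarded to the geometric ops
)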
--- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 2b90e78a0ac..ec7e511989a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1176,7 +1176,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa def invert(img: Tensor) -> Tensor: - """Invert the colors of a PIL Image or torch Tensor. + """Invert the colors of an RGB/grayscale PIL Image or torch Tensor. Args: img (PIL Image or Tensor): Image to have its colors inverted. From a5bc4925e79162aa6c7fcdc5e4a03bc8b55bc60e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Dec 2020 15:00:14 +0000 Subject: [PATCH 16/18] Apply suggestions from code review Co-authored-by: Francisco Massa --- torchvision/transforms/autoaugment.py | 2 +- torchvision/transforms/functional_tensor.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index c2049e45017..26847521998 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -242,4 +242,4 @@ def forward(self, img: Tensor): return img def __repr__(self): - return self.__class__.__name__ + '(policy={},fill={})'.format(self.policy, self.fill) + return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a3a1f592716..f0ef00364d0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1278,8 +1278,8 @@ def autocontrast(img: Tensor) -> Tensor: bound = 1.0 if img.is_floating_point() else 255.0 dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - minimum = img.amin(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) - maximum = img.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) + minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) eq_idxs = torch.where(minimum == maximum)[0] minimum[eq_idxs] = 0 maximum[eq_idxs] = bound @@ -1292,10 +1292,7 @@ def _scale_channel(img_chan): hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) nonzero_hist = hist[hist != 0] - if nonzero_hist.numel() > 0: - step = (nonzero_hist.sum() - nonzero_hist[-1]) // 255 - else: - step = torch.tensor(0, device=img_chan.device) + step = nonzero_hist[:-1].sum() // 255 if step == 0: return img_chan From 48cfe2213baf1048f4306bdf380291625261bb76 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Dec 2020 15:59:02 +0000 Subject: [PATCH 17/18] Apply changes from code review: - Refactor code to eliminate as many to() and clamp() calls as possible. - Reuse methods where possible. - Apply speed ups. --- torchvision/transforms/functional_tensor.py | 35 ++++------ torchvision/transforms/transforms.py | 72 ++------------------- 2 files changed, 17 insertions(+), 90 deletions(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index f0ef00364d0..cb7a8911dd7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -703,9 +703,6 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
""" - if not _is_tensor_a_torch_image(img): - raise TypeError("tensor is not a torch image.") - if not isinstance(padding, (int, tuple, list)): raise TypeError("Got inappropriate padding arg") if not isinstance(fill, (int, float)): @@ -1192,9 +1189,8 @@ def invert(img: Tensor) -> Tensor: _assert_channels(img, [1, 3]) - bound = 1.0 if img.is_floating_point() else 255.0 - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - return (bound - img.to(dtype)).clamp(0, bound).to(img.dtype) + bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device) + return bound - img def posterize(img: Tensor, bits: int) -> Tensor: @@ -1220,14 +1216,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor: _assert_channels(img, [1, 3]) - bound = 1.0 if img.is_floating_point() else 255.0 - dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - - result = img.clone().view(-1) - invert_idx = torch.where(result >= threshold)[0] - result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype) - - return result.view(img.shape) + inverted_img = invert(img) + return torch.where(img >= threshold, inverted_img, img) def _blurred_degenerate_image(img: Tensor) -> Tensor: @@ -1238,15 +1228,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: kernel /= kernel.sum() kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - result, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) - result = conv2d(result, kernel, groups=result.shape[-3]) - result = torch_pad(result, [1, 1, 1, 1]) - result = _cast_squeeze_out(result, need_cast, need_squeeze, out_dtype) + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) + result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) - result[..., 0, :] = img[..., 0, :] - result[..., -1, :] = img[..., -1, :] - result[..., :, 0] = img[..., :, 0] - result[..., :, -1] = img[..., :, -1] + result = img.clone() + result[..., 1:-1, 1:-1] = result_tmp return result @@ -1285,7 +1272,7 @@ def autocontrast(img: Tensor) -> Tensor: maximum[eq_idxs] = bound scale = bound / (maximum - minimum) - return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype) + return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) def _scale_channel(img_chan): @@ -1297,7 +1284,7 @@ def _scale_channel(img_chan): return img_chan lut = (torch.cumsum(hist, 0) + (step // 2)) // step - lut = torch.cat([torch.zeros(1, device=img_chan.device), lut[:-1]]).clamp(0, 255) + lut = pad(lut, [1, 0])[:-1].clamp(0, 255) return lut[img_chan.to(torch.int64)].to(torch.uint8) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index c6322ef71d4..117ba74b83a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1716,16 +1716,6 @@ def __init__(self, p=0.5): super().__init__() self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1734,7 +1724,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly color inverted image. 
""" - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.invert(img) return img @@ -1758,16 +1748,6 @@ def __init__(self, bits, p=0.5): self.bits = bits self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1776,7 +1756,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly posterized image. """ - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.posterize(img, self.bits) return img @@ -1800,16 +1780,6 @@ def __init__(self, threshold, p=0.5): self.threshold = threshold self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1818,7 +1788,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly solarized image. """ - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.solarize(img, self.threshold) return img @@ -1843,16 +1813,6 @@ def __init__(self, sharpness_factor, p=0.5): self.sharpness_factor = sharpness_factor self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1861,7 +1821,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly sharpened image. """ - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.adjust_sharpness(img, self.sharpness_factor) return img @@ -1883,16 +1843,6 @@ def __init__(self, p=0.5): super().__init__() self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1901,7 +1851,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly autocontrasted image. """ - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.autocontrast(img) return img @@ -1923,16 +1873,6 @@ def __init__(self, p=0.5): super().__init__() self.p = p - @staticmethod - def get_params() -> float: - """Choose a value for the random transformation. - - Returns: - float: Random value which is used to determine whether the random transformation - should occur. - """ - return torch.rand(1).item() - def forward(self, img): """ Args: @@ -1941,7 +1881,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly equalized image. """ - if self.get_params() < self.p: + if torch.rand(1).item() < self.p: return F.equalize(img) return img From a9a8537d4c7592edf92f635b4b87af86d043f11e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Dec 2020 16:23:14 +0000 Subject: [PATCH 18/18] Replacing pad. 
--- torchvision/transforms/functional_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index cb7a8911dd7..a72cc41f5cd 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -703,6 +703,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con Returns: Tensor: Padded image. """ + if not _is_tensor_a_torch_image(img): + raise TypeError("tensor is not a torch image.") + if not isinstance(padding, (int, tuple, list)): raise TypeError("Got inappropriate padding arg") if not isinstance(fill, (int, float)): @@ -1284,7 +1287,7 @@ def _scale_channel(img_chan): return img_chan lut = (torch.cumsum(hist, 0) + (step // 2)) // step - lut = pad(lut, [1, 0])[:-1].clamp(0, 255) + lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) return lut[img_chan.to(torch.int64)].to(torch.uint8)
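Note: closing the series, a short end-to-end sketch of the equalize paths assembled above; the tensor path expects uint8 input with 1 or 3 channels, and the shape below is an arbitrary example.

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)
eq = F.equalize(img)                     # functional form, per-channel histogram LUT
maybe_eq = T.RandomEqualize(p=0.5)(img)  # transform form, applied with probability p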