diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 715450b5cfd..fa6297fa5ef 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -912,6 +912,16 @@ def test_adjust_sharpness(self): [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]] ) + def test_autocontrast(self): + self._test_adjust_fn( + F.autocontrast, + F_pil.autocontrast, + F_t.autocontrast, + [{}], + tol=1.0, + agg_method="max" + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 58ffa93f6e2..81104c10f21 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1838,6 +1838,14 @@ def test_random_solarize(self): [{"threshold": 192}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_autocontrast(self): + self._test_randomness( + F.autocontrast, + transforms.RandomAutocontrast, + [{}] + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 30c5b885bb8..2c36664a517 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -104,6 +104,9 @@ def test_random_solarize(self): 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_autocontrast(self): + self._test_op('autocontrast', 'RandomAutocontrast') + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8383b08364b..d401aa4cc90 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1190,7 +1190,7 @@ def invert(img: Tensor) -> Tensor: dimensions. Returns: - PIL Image: Color inverted image. + PIL Image or Tensor: Color inverted image. """ if not isinstance(img, torch.Tensor): return F_pil.invert(img) @@ -1208,7 +1208,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: it can have an arbitrary number of trailing dimensions. bits (int): The number of bits to keep for each channel (0-8). Returns: - PIL Image: Posterized image. + PIL Image or Tensor: Posterized image. """ if not (0 <= bits <= 8): raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits)) @@ -1229,7 +1229,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: dimensions. threshold (float): All pixels equal or above this value are inverted. Returns: - PIL Image: Solarized image. + PIL Image or Tensor: Solarized image. """ if not isinstance(img, torch.Tensor): return F_pil.solarize(img, threshold) @@ -1253,3 +1253,23 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: return F_pil.adjust_sharpness(img, sharpness_factor) return F_t.adjust_sharpness(img, sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + """Maximize contrast of a PIL Image or torch Tensor by remapping its + pixels per channel so that the lowest becomes black and the lightest + becomes white. + + Args: + img (PIL Image or Tensor): Image on which autocontrast is applied. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + + Returns: + PIL Image or Tensor: An image that was autocontrasted. + """ + if not isinstance(img, torch.Tensor): + return F_pil.autocontrast(img) + + return F_t.autocontrast(img) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 72eafe37a2c..14f91713aca 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -637,3 +637,10 @@ def adjust_sharpness(img, sharpness_factor): enhancer = ImageEnhance.Sharpness(img) img = enhancer.enhance(sharpness_factor) return img + + +@torch.jit.unused +def autocontrast(img): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.autocontrast(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 70e1cdc7833..8a0432fa456 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1262,3 +1262,25 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: return img return _blend(img, _blur_image(img), sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + minimum = img.amin(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) + maximum = img.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype) + eq_idxs = torch.where(minimum == maximum)[0] + minimum[eq_idxs] = 0 + maximum[eq_idxs] = bound + scale = bound / (maximum - minimum) + + return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2a854933ebd..963806b9962 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize"] + "RandomSolarize", "RandomAutocontrast"] class Compose: @@ -1836,3 +1836,43 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + + +class RandomAutocontrast(torch.nn.Module): + """Autocontrast the pixels of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be autocontrasted. + + Returns: + PIL Image or Tensor: Randomly autocontrasted image. + """ + if self.get_params() < self.p: + return F.autocontrast(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p)