Skip to content

Implement autocontrast transform #3117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,16 @@ def test_adjust_sharpness(self):
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)

def test_autocontrast(self):
self._test_adjust_fn(
F.autocontrast,
F_pil.autocontrast,
F_t.autocontrast,
[{}],
tol=1.0,
agg_method="max"
)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
8 changes: 8 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,14 @@ def test_random_solarize(self):
[{"threshold": 192}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
F.autocontrast,
transforms.RandomAutocontrast,
[{}]
)


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def test_random_solarize(self):
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')

def test_color_jitter(self):

tol = 1.0 + 1e-10
Expand Down
26 changes: 23 additions & 3 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,7 @@ def invert(img: Tensor) -> Tensor:
dimensions.

Returns:
PIL Image: Color inverted image.
PIL Image or Tensor: Color inverted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.invert(img)
Expand All @@ -1208,7 +1208,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
it can have an arbitrary number of trailing dimensions.
bits (int): The number of bits to keep for each channel (0-8).
Returns:
PIL Image: Posterized image.
PIL Image or Tensor: Posterized image.
"""
if not (0 <= bits <= 8):
raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits))
Expand All @@ -1229,7 +1229,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
dimensions.
threshold (float): All pixels equal or above this value are inverted.
Returns:
PIL Image: Solarized image.
PIL Image or Tensor: Solarized image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.solarize(img, threshold)
Expand All @@ -1253,3 +1253,23 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
return F_pil.adjust_sharpness(img, sharpness_factor)

return F_t.adjust_sharpness(img, sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
"""Maximize contrast of a PIL Image or torch Tensor by remapping its
pixels per channel so that the lowest becomes black and the lightest
becomes white.

Args:
img (PIL Image or Tensor): Image on which autocontrast is applied.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.

Returns:
PIL Image or Tensor: An image that was autocontrasted.
"""
if not isinstance(img, torch.Tensor):
return F_pil.autocontrast(img)

return F_t.autocontrast(img)
7 changes: 7 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,10 @@ def adjust_sharpness(img, sharpness_factor):
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img


@torch.jit.unused
def autocontrast(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)
22 changes: 22 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,3 +1262,25 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
return img

return _blend(img, _blur_image(img), sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))

_assert_channels(img, [1, 3])

bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32

minimum = img.amin(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype)
maximum = img.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype)
eq_idxs = torch.where(minimum == maximum)[0]
minimum[eq_idxs] = 0
maximum[eq_idxs] = bound
scale = bound / (maximum - minimum)

return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype)
42 changes: 41 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize"]
"RandomSolarize", "RandomAutocontrast"]


class Compose:
Expand Down Expand Up @@ -1836,3 +1836,43 @@ def forward(self, img):

def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)


class RandomAutocontrast(torch.nn.Module):
"""Autocontrast the pixels of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions.

Args:
p (float): probability of the image being autocontrasted. Default value is 0.5
"""

def __init__(self, p=0.5):
super().__init__()
self.p = p

@staticmethod
def get_params() -> float:
"""Choose a value for the random transformation.

Returns:
float: Random value which is used to determine whether the random transformation
should occur.
"""
return torch.rand(1).item()

def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be autocontrasted.

Returns:
PIL Image or Tensor: Randomly autocontrasted image.
"""
if self.get_params() < self.p:
return F.autocontrast(img)
return img

def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)