Skip to content

Commit 4b800b9

Browse files
authored
Create posterize transformation and refactor common methods to assist reuse. (#3108)
1 parent cd03c18 commit 4b800b9

File tree

7 files changed

+118
-9
lines changed

7 files changed

+118
-9
lines changed

test/test_functional_tensor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,13 +289,14 @@ def test_pad(self):
289289

290290
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
291291

292-
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
292+
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
293+
dts=(None, torch.float32, torch.float64)):
293294
script_fn = torch.jit.script(fn)
294295
torch.manual_seed(15)
295296
tensor, pil_img = self._create_data(26, 34, device=self.device)
296297
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
297298

298-
for dt in [None, torch.float32, torch.float64]:
299+
for dt in dts:
299300

300301
if dt is not None:
301302
tensor = F.convert_image_dtype(tensor, dt)
@@ -872,6 +873,17 @@ def test_invert(self):
872873
agg_method="max"
873874
)
874875

876+
def test_posterize(self):
    """Compare the dispatching, PIL and tensor posterize functions for bits 0 through 7.

    Only the ``None`` dtype entry is passed via ``dts``, so the shared helper
    does not convert the test tensors to floating point.
    """
    bit_configs = [{"bits": num_bits} for num_bits in range(8)]
    self._test_adjust_fn(
        F.posterize,
        F_pil.posterize,
        F_t.posterize,
        bit_configs,
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )
886+
875887

876888
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
877889
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,17 +1761,16 @@ def _test_randomness(self, fn, trans, configs):
17611761
num_samples = 250
17621762
counts = 0
17631763
for _ in range(num_samples):
1764-
out = trans(p=p, **config)(img)
1764+
tranformation = trans(p=p, **config)
1765+
tranformation.__repr__()
1766+
out = tranformation(img)
17651767
if out == inv_img:
17661768
counts += 1
17671769

17681770
p_value = stats.binom_test(counts, num_samples, p=p)
17691771
random.setstate(random_state)
17701772
self.assertGreater(p_value, 0.0001)
17711773

1772-
# Checking if it can be printed as string
1773-
trans().__repr__()
1774-
17751774
@unittest.skipIf(stats is None, 'scipy.stats not available')
17761775
def test_random_invert(self):
17771776
self._test_randomness(
@@ -1780,6 +1779,14 @@ def test_random_invert(self):
17801779
[{}]
17811780
)
17821781

1782+
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
    """Run the shared randomness check for RandomPosterize with 4 bits kept."""
    posterize_configs = [{"bits": 4}]
    self._test_randomness(F.posterize, transforms.RandomPosterize, posterize_configs)
1789+
17831790

17841791
if __name__ == '__main__':
17851792
unittest.main()

test/test_transforms_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def test_random_vertical_flip(self):
9292
def test_random_invert(self):
9393
self._test_op('invert', 'RandomInvert')
9494

95+
def test_random_posterize(self):
    """Run the shared op test for 'posterize' / 'RandomPosterize' with bits=4."""
    # The same dict object is intentionally handed to both keyword arguments.
    shared_kwargs = {"bits": 4}
    self._test_op(
        'posterize',
        'RandomPosterize',
        fn_kwargs=shared_kwargs,
        meth_kwargs=shared_kwargs,
    )
100+
95101
def test_color_jitter(self):
96102

97103
tol = 1.0 + 1e-10

torchvision/transforms/functional.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,3 +1196,24 @@ def invert(img: Tensor) -> Tensor:
11961196
return F_pil.invert(img)
11971197

11981198
return F_t.invert(img)
1199+
1200+
1201+
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize a PIL Image or torch Tensor by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to be posterized.
            If img is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.
        bits (int): The number of bits to keep for each channel (0-8).

    Returns:
        PIL Image or Tensor: Posterized image.

    Raises:
        ValueError: If ``bits`` is not between 0 and 8 (inclusive).
    """
    # Validate once here so both backends see the same accepted range.
    if not (0 <= bits <= 8):
        raise ValueError('The number of bits should be between 0 and 8. Got {}'.format(bits))

    # Anything that is not a tensor is delegated to the PIL backend, which
    # performs its own type check.
    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)

torchvision/transforms/functional_pil.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,10 @@ def invert(img):
613613
if not _is_pil_image(img):
614614
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
615615
return ImageOps.invert(img)
616+
617+
618+
@torch.jit.unused
def posterize(img, bits):
    """Posterize a PIL image by delegating to ``ImageOps.posterize``.

    Args:
        img: Image to posterize; must satisfy ``_is_pil_image``.
        bits: Passed through unchanged to ``ImageOps.posterize``.

    Returns:
        The result of ``ImageOps.posterize(img, bits)``.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return ImageOps.posterize(img, bits)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

torchvision/transforms/functional_tensor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,3 +1193,17 @@ def invert(img: Tensor) -> Tensor:
11931193
bound = 1.0 if img.is_floating_point() else 255.0
11941194
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
11951195
return (bound - img.to(dtype)).to(img.dtype)
1196+
1197+
1198+
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize a uint8 image tensor by clearing the low-order bits of every value.

    Args:
        img (Tensor): Image tensor with at least 3 dimensions and dtype
            ``torch.uint8``.
        bits (int): Number of high-order bits to keep in each value.

    Returns:
        Tensor: ``img`` bitwise-ANDed with the bit mask.

    Raises:
        TypeError: If ``img`` is not a torch image, has fewer than 3
            dimensions, or is not of dtype ``torch.uint8``.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    num_dims = img.ndim
    if num_dims < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(num_dims))

    img_dtype = img.dtype
    if img_dtype != torch.uint8:
        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img_dtype))

    # Channel-count validation is delegated to the shared helper (1 or 3 allowed).
    _assert_channels(img, [1, 3])

    # JIT-friendly spelling of ~(2 ** (8 - bits) - 1): the negated power of two
    # has the low (8 - bits) bits cleared.
    dropped_bits = 8 - bits
    keep_mask = -int(2 ** dropped_bits)
    return img & keep_mask

torchvision/transforms/transforms.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
2222
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
2323
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
24-
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"]
24+
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"]
2525

2626

2727
class Compose:
@@ -1717,10 +1717,10 @@ def __init__(self, p=0.5):
17171717

17181718
@staticmethod
17191719
def get_params() -> float:
1720-
"""Choose value for random color inversion.
1720+
"""Choose a value for the random transformation.
17211721
17221722
Returns:
1723-
float: Random value which is used to determine whether the random color inversion
1723+
float: Random value which is used to determine whether the random transformation
17241724
should occur.
17251725
"""
17261726
return torch.rand(1).item()
@@ -1739,3 +1739,45 @@ def forward(self, img):
17391739

17401740
def __repr__(self):
17411741
return self.__class__.__name__ + '(p={})'.format(self.p)
1742+
1743+
1744+
class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. The image can be a PIL Image or a torch
    Tensor, in which case it is expected to have [..., H, W] shape, where ... means
    an arbitrary number of leading dimensions.

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        self.bits = bits
        self.p = p

    @staticmethod
    def get_params() -> float:
        """Choose a value for the random transformation.

        Returns:
            float: Random value which is used to determine whether the random transformation
            should occur.
        """
        return torch.rand(1).item()

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        # Posterize only when the drawn value falls below p; otherwise the
        # input is returned untouched.
        if self.get_params() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)

0 commit comments

Comments
 (0)