Skip to content

Move AdjustSharpness from ColorJitter #3128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ def test_adjusts_L_mode(self):
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

def test_color_jitter(self):
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down Expand Up @@ -1840,6 +1840,14 @@ def test_random_solarize(self):
[{"threshold": 192}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
Expand Down
14 changes: 7 additions & 7 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def test_random_solarize(self):
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')

Expand Down Expand Up @@ -138,14 +144,8 @@ def test_color_jitter(self):
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
)

for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"sharpness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
)
Expand Down
77 changes: 54 additions & 23 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAutocontrast", "RandomEqualize"]
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]


class Compose:
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def __repr__(self):


class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation, hue and sharpness of an image.
"""Randomly change the brightness, contrast, saturation and hue of an image.

Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
Expand All @@ -1054,19 +1054,15 @@ class ColorJitter(torch.nn.Module):
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
or the given [min, max]. Should be non negative numbers.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
self.sharpness = self._check_input(sharpness, 'sharpness')

@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
Expand All @@ -1082,7 +1078,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
else:
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))

# if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
Expand All @@ -1092,10 +1088,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]],
sharpness: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float],
Optional[float]]:
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.

Args:
Expand All @@ -1107,22 +1101,19 @@ def get_params(brightness: Optional[List[float]],
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
uniformly. Pass None to turn off the transformation.

Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
fn_idx = torch.randperm(5)
fn_idx = torch.randperm(4)

b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1]))

return fn_idx, b, c, s, h, sp
return fn_idx, b, c, s, h

def forward(self, img):
"""
Expand All @@ -1132,8 +1123,8 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)

for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
Expand All @@ -1144,8 +1135,6 @@ def forward(self, img):
img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
elif fn_id == 4 and sharpness_factor is not None:
img = F.adjust_sharpness(img, sharpness_factor)

return img

Expand All @@ -1154,8 +1143,7 @@ def __repr__(self):
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0}'.format(self.hue)
format_string += ', sharpness={0})'.format(self.sharpness)
format_string += ', hue={0})'.format(self.hue)
return format_string


Expand Down Expand Up @@ -1838,6 +1826,49 @@ def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)


class RandomAdjustSharpness(torch.nn.Module):
"""Adjust the sharpness of the image randomly with a given probability. The image
can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W]
shape, where ... means an arbitrary number of leading dimensions.

Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being color inverted. Default value is 0.5
"""

def __init__(self, sharpness_factor, p=0.5):
super().__init__()
self.sharpness_factor = sharpness_factor
self.p = p

@staticmethod
def get_params() -> float:
"""Choose a value for the random transformation.

Returns:
float: Random value which is used to determine whether the random transformation
should occur.
"""
return torch.rand(1).item()

def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be sharpened.

Returns:
PIL Image or Tensor: Randomly sharpened image.
"""
if self.get_params() < self.p:
return F.adjust_sharpness(img, self.sharpness_factor)
return img

def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)


class RandomAutocontrast(torch.nn.Module):
"""Autocontrast the pixels of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
Expand Down