From c56148cade4fa3c5085b1a0f400ede208b111a70 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Thu, 29 Jul 2021 15:06:20 +0200 Subject: [PATCH 01/19] Initial Proposal --- test/test_transforms.py | 11 ++ test/test_transforms_tensor.py | 15 +++ torchvision/transforms/autoaugment.py | 185 ++++++++++++++++++-------- 3 files changed, 154 insertions(+), 57 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 74757bcb4e6..821a149837d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1483,6 +1483,17 @@ def test_autoaugment(policy, fill): img = transform(img) transform.__repr__() +@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) +@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize('num_magnitude_bins', [10,13,30]) +def test_autoaugment(augmentation_space, fill, num_magnitude_bins): + random.seed(42) + img = Image.open(GRACE_HOPPER) + transform = transforms.TrivialAugment(augmentation_space=augmentation_space, fill=fill, num_magnitude_bins=num_magnitude_bins) + for _ in range(100): + img = transform(img) + transform.__repr__() + def test_random_crop(): height = random.randint(10, 32) * 2 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0bf5d77716f..df288628373 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -541,6 +541,21 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) +@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) +@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +def test_trivialaugment(device, augmentation_space, fill): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) + + s_transform = None + transform = T.TrivialAugment(augmentation_space=augmentation_space, fill=fill) + s_transform = torch.jit.script(transform) + for _ in range(25): + _test_transform_vs_scripted(transform, s_transform, tensor) + _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + def test_autoaugment_save(): transform = T.AutoAugment() s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 97522945d2e..350b90cf776 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment"] +__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugment"] class AutoAugmentPolicy(Enum): @@ -106,25 +106,130 @@ def _get_transforms(policy: AutoAugmentPolicy): ] -def _get_magnitudes(): - _BINS = 10 - return { +def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: int=10): + if augmentation_space == 'aa': + shear_max = 0.3 + translate_max_x = 150.0 / 331.0 * image_size[0] + translate_max_y = 150.0 / 331.0 * image_size[1] + rotate_max = 30.0 + enhancer_max = 0.9 + posterize_min_bits = 4 + + elif augmentation_space == 'ta_wide': + shear_max = 0.99 + translate_max_x = 32.0 # this is an absolute + translate_max_y = 32.0 # this is an absolute + rotate_max = 135.0 + enhancer_max = 0.99 + posterize_min_bits = 2 + else: + raise ValueError(f"Provided augmentation_space arguments {augmentation_space} not available.") + + + magnitudes = { # name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), - "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), - "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), - "Color": (torch.linspace(0.0, 0.9, _BINS), True), - "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), - "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), - "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), - "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), - "AutoContrast": (None, None), - "Equalize": (None, None), - "Invert": (None, None), + "ShearX": (torch.linspace(0.0, shear_max, num_bins), True), + "ShearY": (torch.linspace(0.0, shear_max, num_bins), True), + "TranslateX": (torch.linspace(0.0, translate_max_x, num_bins), True), + "TranslateY": (torch.linspace(0.0, translate_max_y, num_bins), True), + "Rotate": (torch.linspace(0.0, rotate_max, num_bins), True), + "Brightness": (torch.linspace(0.0, enhancer_max, num_bins), True), + "Color": (torch.linspace(0.0, enhancer_max, num_bins), True), + "Contrast": (torch.linspace(0.0, enhancer_max, num_bins), True), + "Sharpness": (torch.linspace(0.0, enhancer_max, num_bins), True), + "Posterize": (8-(torch.arange(num_bins)/((num_bins-1)/(8-posterize_min_bits))).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(float('nan')), False), + "Equalize": (torch.tensor(float('nan')), False), + "Invert": (torch.tensor(float('nan')), False), } + return magnitudes + +def apply_aug(img: Tensor, op_name: str, magnitude: float, + interpolation: InterpolationMode, fill: Optional[List[float]], num_magnitude_bins: int=10,): + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + return img + + +class TrivialAugment(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + augmentation_space (str): A string defining which augmentation space to use. + The augmentation space can either set to be the one used for AutoAugment (`aa`) + or to the strongest augmentation space from the TrivialAugment paper (`ta_wide`). + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + def __init__(self, augmentation_space: str = 'ta_wide', num_magnitude_bins: int=30, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None): + super().__init__() + self.augmentation_space = augmentation_space + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def forward(self, img: Tensor): + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F._get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins) + op_index = torch.randint(len(op_meta), (1,)) + op_name = list(op_meta.keys())[op_index.item()] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,))].item()) \ + if magnitudes.isnan().all() else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill, + num_magnitude_bins=self.num_magnitude_bins) class AutoAugment(torch.nn.Module): @@ -154,7 +259,6 @@ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, self.transforms = _get_transforms(policy) if self.transforms is None: raise ValueError("The provided policy {} is not recognized.".format(policy)) - self._op_meta = _get_magnitudes() @staticmethod def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: @@ -190,46 +294,13 @@ def forward(self, img: Tensor): for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: - magnitudes, signed = self._get_op_meta(op_name) + op_meta = _get_magnitudes('aa', F._get_image_size(img)) + magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) \ - if magnitudes is not None and magnitude_id is not None else 0.0 - if signed is not None and signed and signs[i] == 0: + if magnitudes.isnan().all() and magnitude_id is not None else 0.0 + if signed and signs[i]==0: magnitude *= -1.0 - - if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=self.interpolation, fill=fill) - elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=self.interpolation, fill=fill) - elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "Rotate": - img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill) - elif op_name == "Brightness": - img = F.adjust_brightness(img, 1.0 + magnitude) - elif op_name == "Color": - img = F.adjust_saturation(img, 1.0 + magnitude) - elif op_name == "Contrast": - img = F.adjust_contrast(img, 1.0 + magnitude) - elif op_name == "Sharpness": - img = F.adjust_sharpness(img, 1.0 + magnitude) - elif op_name == "Posterize": - img = F.posterize(img, int(magnitude)) - elif op_name == "Solarize": - img = F.solarize(img, magnitude) - elif op_name == "AutoContrast": - img = F.autocontrast(img) - elif op_name == "Equalize": - img = F.equalize(img) - elif op_name == "Invert": - img = F.invert(img) - else: - raise ValueError("The provided operator {} is not recognized.".format(op_name)) + img = apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) return img From f4552ed3f9fa2783b28ce8a3fd9e15e4f92fd34b Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Thu, 29 Jul 2021 15:12:16 +0200 Subject: [PATCH 02/19] Tensor Save Test + Test Name Fix --- test/test_transforms.py | 2 +- test/test_transforms_tensor.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 821a149837d..d77e028621d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1486,7 +1486,7 @@ def test_autoaugment(policy, fill): @pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize('num_magnitude_bins', [10,13,30]) -def test_autoaugment(augmentation_space, fill, num_magnitude_bins): +def test_trivialaugment(augmentation_space, fill, num_magnitude_bins): random.seed(42) img = Image.open(GRACE_HOPPER) transform = transforms.TrivialAugment(augmentation_space=augmentation_space, fill=fill, num_magnitude_bins=num_magnitude_bins) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index df288628373..df72bb8c1b8 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -562,6 +562,12 @@ def test_autoaugment_save(): with get_tmp_dir() as tmp_dir: s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) +def test_trivialaugment_save(): + transform = T.TrivialAugment() + s_transform = torch.jit.script(transform) + with get_tmp_dir() as tmp_dir: + s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + @pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize( From 25968e6302d3592a9cfa711f92e4d1de92220545 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Thu, 29 Jul 2021 15:31:11 +0200 Subject: [PATCH 03/19] Formatting + removing unused argument --- test/test_transforms.py | 6 ++++-- test/test_transforms_tensor.py | 1 + torchvision/transforms/autoaugment.py | 17 +++++++++-------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index d77e028621d..cfd67a3efcb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1483,13 +1483,15 @@ def test_autoaugment(policy, fill): img = transform(img) transform.__repr__() + @pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) -@pytest.mark.parametrize('num_magnitude_bins', [10,13,30]) +@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) def test_trivialaugment(augmentation_space, fill, num_magnitude_bins): random.seed(42) img = Image.open(GRACE_HOPPER) - transform = transforms.TrivialAugment(augmentation_space=augmentation_space, fill=fill, num_magnitude_bins=num_magnitude_bins) + transform = transforms.TrivialAugment(augmentation_space=augmentation_space, + fill=fill, num_magnitude_bins=num_magnitude_bins) for _ in range(100): img = transform(img) transform.__repr__() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index df72bb8c1b8..1d85e297a8a 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -562,6 +562,7 @@ def test_autoaugment_save(): with get_tmp_dir() as tmp_dir: s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + def test_trivialaugment_save(): transform = T.TrivialAugment() s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 350b90cf776..bcaf206f53c 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -106,7 +106,7 @@ def _get_transforms(policy: AutoAugmentPolicy): ] -def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: int=10): +def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: int = 10): if augmentation_space == 'aa': shear_max = 0.3 translate_max_x = 150.0 / 331.0 * image_size[0] @@ -117,15 +117,14 @@ def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: in elif augmentation_space == 'ta_wide': shear_max = 0.99 - translate_max_x = 32.0 # this is an absolute - translate_max_y = 32.0 # this is an absolute + translate_max_x = 32.0 # this is an absolute + translate_max_y = 32.0 # this is an absolute rotate_max = 135.0 enhancer_max = 0.99 posterize_min_bits = 2 else: raise ValueError(f"Provided augmentation_space arguments {augmentation_space} not available.") - magnitudes = { # name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, shear_max, num_bins), True), @@ -137,7 +136,7 @@ def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: in "Color": (torch.linspace(0.0, enhancer_max, num_bins), True), "Contrast": (torch.linspace(0.0, enhancer_max, num_bins), True), "Sharpness": (torch.linspace(0.0, enhancer_max, num_bins), True), - "Posterize": (8-(torch.arange(num_bins)/((num_bins-1)/(8-posterize_min_bits))).round().int(), False), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / (8 - posterize_min_bits))).round().int(), False), "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), "AutoContrast": (torch.tensor(float('nan')), False), "Equalize": (torch.tensor(float('nan')), False), @@ -145,8 +144,9 @@ def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: in } return magnitudes + def apply_aug(img: Tensor, op_name: str, magnitude: float, - interpolation: InterpolationMode, fill: Optional[List[float]], num_magnitude_bins: int=10,): + interpolation: InterpolationMode, fill: Optional[List[float]]): if op_name == "ShearX": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], interpolation=interpolation, fill=fill) @@ -202,7 +202,8 @@ class TrivialAugment(torch.nn.Module): fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. """ - def __init__(self, augmentation_space: str = 'ta_wide', num_magnitude_bins: int=30, + + def __init__(self, augmentation_space: str = 'ta_wide', num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): super().__init__() @@ -298,7 +299,7 @@ def forward(self, img: Tensor): magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) \ if magnitudes.isnan().all() and magnitude_id is not None else 0.0 - if signed and signs[i]==0: + if signed and signs[i] == 0: magnitude *= -1.0 img = apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) From 2feff4f7193d33c03a5d6ea4c4c12d383fbf022d Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Thu, 29 Jul 2021 16:03:01 +0200 Subject: [PATCH 04/19] fix old argument --- torchvision/transforms/autoaugment.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index bcaf206f53c..5b0461409a8 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -229,8 +229,7 @@ def forward(self, img: Tensor): if signed and torch.randint(2, (1,)): magnitude *= -1.0 - return apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill, - num_magnitude_bins=self.num_magnitude_bins) + return apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) class AutoAugment(torch.nn.Module): From 58c7ba8180327b654cca8f7f36e2673e824c9662 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Thu, 29 Jul 2021 16:37:39 +0200 Subject: [PATCH 05/19] fix isnan check error + indexing error with jit --- torchvision/transforms/autoaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 5b0461409a8..2863b15e6ad 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -224,8 +224,8 @@ def forward(self, img: Tensor): op_index = torch.randint(len(op_meta), (1,)) op_name = list(op_meta.keys())[op_index.item()] magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,))].item()) \ - if magnitudes.isnan().all() else 0.0 + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ + if not magnitudes.isnan().all() else 0.0 if signed and torch.randint(2, (1,)): magnitude *= -1.0 @@ -297,7 +297,7 @@ def forward(self, img: Tensor): op_meta = _get_magnitudes('aa', F._get_image_size(img)) magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) \ - if magnitudes.isnan().all() and magnitude_id is not None else 0.0 + if not magnitudes.isnan().all() and magnitude_id is not None else 0.0 if signed and signs[i] == 0: magnitude *= -1.0 img = apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) From 406848a536c3dee00c30dbb586dbc15ed59e0c77 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Tue, 17 Aug 2021 16:20:37 +0200 Subject: [PATCH 06/19] Fix Flake8 error. --- torchvision/transforms/autoaugment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index cbad3f7f9d1..ba1ed855dae 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -108,7 +108,9 @@ def _get_transforms( # type: ignore[return] ] -def _get_magnitudes(augmentation_space: str, image_size: List[int], num_bins: int = 10) -> Dict[str, Tuple[Tensor, bool]]: +def _get_magnitudes( + augmentation_space: str, image_size: List[int], num_bins: int = 10 +) -> Dict[str, Tuple[Tensor, bool]]: if augmentation_space == 'aa': shear_max = 0.3 translate_max_x = 150.0 / 331.0 * image_size[0] From f7434816b9543d14b3593c8fd01460894660599d Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Tue, 17 Aug 2021 16:35:17 +0200 Subject: [PATCH 07/19] Fix MyPy error. --- torchvision/transforms/autoaugment.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index ba1ed855dae..fa7e31a8e91 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -225,8 +225,8 @@ def forward(self, img: Tensor): fill = [float(f) for f in fill] op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins) - op_index = torch.randint(len(op_meta), (1,)) - op_name = list(op_meta.keys())[op_index.item()] + op_index: int = torch.randint(len(op_meta), (1,)).item() # type: ignore[assignment] + op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ if not magnitudes.isnan().all() else 0.0 @@ -281,9 +281,6 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: return policy_id, probs, signs - def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: - return self._op_meta[name] - def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. From 19d86960f5c04796d345e902893dd088f9115c36 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Tue, 17 Aug 2021 16:43:23 +0200 Subject: [PATCH 08/19] Fix Flake8 error. --- torchvision/transforms/autoaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index fa7e31a8e91..4b3605a22c7 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -225,7 +225,7 @@ def forward(self, img: Tensor): fill = [float(f) for f in fill] op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins) - op_index: int = torch.randint(len(op_meta), (1,)).item() # type: ignore[assignment] + op_index: int = torch.randint(len(op_meta), (1,)).item() # type: ignore[assignment] op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ From 1ed1568a95bd7b12a64c074c0e3bff7ee9732406 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Tue, 17 Aug 2021 18:23:13 +0200 Subject: [PATCH 09/19] Fix PyTorch JIT error in UnitTests due to type annotation. --- torchvision/transforms/autoaugment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 4b3605a22c7..56e01aef5f6 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -225,8 +225,8 @@ def forward(self, img: Tensor): fill = [float(f) for f in fill] op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins) - op_index: int = torch.randint(len(op_meta), (1,)).item() # type: ignore[assignment] - op_name = list(op_meta.keys())[op_index] + op_index = torch.randint(len(op_meta), (1,)) + op_name = list(op_meta.keys())[op_index.item()] # type: ignore[index] magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ if not magnitudes.isnan().all() else 0.0 From 2fc8633cf863266f95c9d7f722ca44af6456f6f8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 26 Aug 2021 11:33:20 +0100 Subject: [PATCH 10/19] Fixing tests. --- test/test_transforms_tensor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 8be34cd6a71..0eb45f08b19 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -533,6 +533,12 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) +def test_autoaugment_save(tmpdir): + transform = T.AutoAugment() + s_transform = torch.jit.script(transform) + s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) + + @pytest.mark.parametrize('device', cpu_and_gpu()) @pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) @pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) @@ -547,17 +553,11 @@ def test_trivialaugment(device, augmentation_space, fill): _test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) - -def test_autoaugment_save(tmpdir): - transform = T.AutoAugment() - s_transform = torch.jit.script(transform) - s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) - -def test_trivialaugment_save(tmp_dir): +def test_trivialaugment_save(tmpdir): transform = T.TrivialAugment() s_transform = torch.jit.script(transform) - s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) @pytest.mark.parametrize('device', cpu_and_gpu()) From 729c0dbd2e63853a2225527d4d01a1d05100d526 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 26 Aug 2021 11:54:07 +0100 Subject: [PATCH 11/19] Removing type ignore. --- torchvision/transforms/autoaugment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 56e01aef5f6..1a65308a2c9 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -225,8 +225,8 @@ def forward(self, img: Tensor): fill = [float(f) for f in fill] op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins) - op_index = torch.randint(len(op_meta), (1,)) - op_name = list(op_meta.keys())[op_index.item()] # type: ignore[index] + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ if not magnitudes.isnan().all() else 0.0 From 83552c6952f35d09e27122ba894cb35cad436297 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 26 Aug 2021 15:00:08 +0100 Subject: [PATCH 12/19] Adding support of ta_wide in references. --- references/classification/presets.py | 7 +++++-- test/test_transforms.py | 2 +- test/test_transforms_tensor.py | 2 +- torchvision/transforms/autoaugment.py | 27 +++++++++++++++++---------- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index ce5a6fe414f..b7ee64c5d4a 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -9,8 +9,11 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2 if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy)) + if auto_augment_policy == autoaugment.AugmentationSpace.TA_WIDE.value: + trans.append(autoaugment.TrivialAugment()) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.extend([ transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), diff --git a/test/test_transforms.py b/test/test_transforms.py index 7e1164db0b4..08a4f3c53cd 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1490,7 +1490,7 @@ def test_autoaugment(policy, fill): transform.__repr__() -@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) +@pytest.mark.parametrize('augmentation_space', [space for space in transforms.AugmentationSpace]) @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) def test_trivialaugment(augmentation_space, fill, num_magnitude_bins): diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0eb45f08b19..2bcc3e22485 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -540,7 +540,7 @@ def test_autoaugment_save(tmpdir): @pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide']) +@pytest.mark.parametrize('augmentation_space', [space for space in T.AugmentationSpace]) @pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) def test_trivialaugment(device, augmentation_space, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 2d5530f9dfb..cc9abab3576 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugment"] +__all__ = ["AutoAugmentPolicy", "AugmentationSpace", "AutoAugment", "TrivialAugment"] class AutoAugmentPolicy(Enum): @@ -19,6 +19,14 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" +class AugmentationSpace(Enum): + """The augmentation space to use. + Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment. + """ + AA = "aa" + TA_WIDE = "ta_wide" + + def _get_transforms( # type: ignore[return] policy: AutoAugmentPolicy ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: @@ -109,9 +117,9 @@ def _get_transforms( # type: ignore[return] def _get_magnitudes( - augmentation_space: str, image_size: List[int], num_bins: int = 10 + augmentation_space: AugmentationSpace, image_size: List[int], num_bins: int = 10 ) -> Dict[str, Tuple[Tensor, bool]]: - if augmentation_space == 'aa': + if augmentation_space == AugmentationSpace.AA: shear_max = 0.3 translate_max_x = 150.0 / 331.0 * image_size[0] translate_max_y = 150.0 / 331.0 * image_size[1] @@ -119,7 +127,7 @@ def _get_magnitudes( enhancer_max = 0.9 posterize_min_bits = 4 - elif augmentation_space == 'ta_wide': + elif augmentation_space == AugmentationSpace.TA_WIDE: shear_max = 0.99 translate_max_x = 32.0 # this is an absolute translate_max_y = 32.0 # this is an absolute @@ -196,9 +204,8 @@ class TrivialAugment(torch.nn.Module): If img is PIL Image, it is expected to be in mode "L" or "RGB". Args: - augmentation_space (str): A string defining which augmentation space to use. - The augmentation space can either set to be the one used for AutoAugment (`aa`) - or to the strongest augmentation space from the TrivialAugment paper (`ta_wide`). + augmentation_space (AugmentationSpace): Desired augmentation space enum defined by + :class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``. num_magnitude_bins (int): The number of different magnitude values. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. @@ -207,9 +214,9 @@ class TrivialAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, augmentation_space: str = 'ta_wide', num_magnitude_bins: int = 30, + def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_WIDE, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None): + fill: Optional[List[float]] = None) -> None: super().__init__() self.augmentation_space = augmentation_space self.num_magnitude_bins = num_magnitude_bins @@ -299,7 +306,7 @@ def forward(self, img: Tensor) -> Tensor: for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: - op_meta = _get_magnitudes('aa', F.get_image_size(img)) + op_meta = _get_magnitudes(AugmentationSpace.AA, F.get_image_size(img)) magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) \ if not magnitudes.isnan().all() and magnitude_id is not None else 0.0 From 1fe25fb9b83b66121fbc6672d752e3ec7d6a059d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 27 Aug 2021 14:21:56 +0100 Subject: [PATCH 13/19] Move methods in classes. --- torchvision/transforms/autoaugment.py | 283 +++++++++++++------------- 1 file changed, 146 insertions(+), 137 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index cc9abab3576..a7642e1a619 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -27,137 +27,7 @@ class AugmentationSpace(Enum): TA_WIDE = "ta_wide" -def _get_transforms( # type: ignore[return] - policy: AutoAugmentPolicy -) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: - if policy == AutoAugmentPolicy.IMAGENET: - return [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ] - elif policy == AutoAugmentPolicy.CIFAR10: - return [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ] - elif policy == AutoAugmentPolicy.SVHN: - return [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ] - - -def _get_magnitudes( - augmentation_space: AugmentationSpace, image_size: List[int], num_bins: int = 10 -) -> Dict[str, Tuple[Tensor, bool]]: - if augmentation_space == AugmentationSpace.AA: - shear_max = 0.3 - translate_max_x = 150.0 / 331.0 * image_size[0] - translate_max_y = 150.0 / 331.0 * image_size[1] - rotate_max = 30.0 - enhancer_max = 0.9 - posterize_min_bits = 4 - - elif augmentation_space == AugmentationSpace.TA_WIDE: - shear_max = 0.99 - translate_max_x = 32.0 # this is an absolute - translate_max_y = 32.0 # this is an absolute - rotate_max = 135.0 - enhancer_max = 0.99 - posterize_min_bits = 2 - else: - raise ValueError(f"Provided augmentation_space arguments {augmentation_space} not available.") - - magnitudes = { - # name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, shear_max, num_bins), True), - "ShearY": (torch.linspace(0.0, shear_max, num_bins), True), - "TranslateX": (torch.linspace(0.0, translate_max_x, num_bins), True), - "TranslateY": (torch.linspace(0.0, translate_max_y, num_bins), True), - "Rotate": (torch.linspace(0.0, rotate_max, num_bins), True), - "Brightness": (torch.linspace(0.0, enhancer_max, num_bins), True), - "Color": (torch.linspace(0.0, enhancer_max, num_bins), True), - "Contrast": (torch.linspace(0.0, enhancer_max, num_bins), True), - "Sharpness": (torch.linspace(0.0, enhancer_max, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / (8 - posterize_min_bits))).round().int(), False), - "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(float('nan')), False), - "Equalize": (torch.tensor(float('nan')), False), - "Invert": (torch.tensor(float('nan')), False), - } - return magnitudes - - -def apply_aug(img: Tensor, op_name: str, magnitude: float, +def _apply_op(img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]): if op_name == "ShearX": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], @@ -223,7 +93,33 @@ def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_ self.interpolation = interpolation self.fill = fill + @staticmethod + def _get_magnitudes(num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(float('nan')), False), + "Equalize": (torch.tensor(float('nan')), False), + "Invert": (torch.tensor(float('nan')), False), + } + def forward(self, img: Tensor): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: TrivialAugmented image. + """ fill = self.fill if isinstance(img, Tensor): if isinstance(fill, (int, float)): @@ -231,7 +127,12 @@ def forward(self, img: Tensor): elif fill is not None: fill = [float(f) for f in fill] - op_meta = _get_magnitudes(self.augmentation_space, F.get_image_size(img), num_bins=self.num_magnitude_bins) + if self.augmentation_space == AugmentationSpace.AA: + op_meta = AutoAugment._get_magnitudes(self.num_magnitude_bins, F.get_image_size(img)) + elif self.augmentation_space == AugmentationSpace.TA_WIDE: + op_meta = self._get_magnitudes(self.num_magnitude_bins) + else: + raise ValueError(f"Provided augmentation_space arguments {self.augmentation_space} not available.") op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] @@ -240,7 +141,7 @@ def forward(self, img: Tensor): if signed and torch.randint(2, (1,)): magnitude *= -1.0 - return apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) class AutoAugment(torch.nn.Module): @@ -270,11 +171,119 @@ def __init__( self.policy = policy self.interpolation = interpolation self.fill = fill + self.transforms = self._get_transforms(policy) - self.transforms = _get_transforms(policy) - if self.transforms is None: + def _get_transforms( + self, + policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: raise ValueError("The provided policy {} is not recognized.".format(policy)) + @staticmethod + def _get_magnitudes(num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(float('nan')), False), + "Equalize": (torch.tensor(float('nan')), False), + "Invert": (torch.tensor(float('nan')), False), + } + @staticmethod def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: """Get parameters for autoaugment transformation @@ -306,13 +315,13 @@ def forward(self, img: Tensor) -> Tensor: for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: - op_meta = _get_magnitudes(AugmentationSpace.AA, F.get_image_size(img)) + op_meta = self._get_magnitudes(10, F.get_image_size(img)) magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) \ if not magnitudes.isnan().all() and magnitude_id is not None else 0.0 if signed and signs[i] == 0: magnitude *= -1.0 - img = apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) return img From 226998cf78f3c14209bda2c88cedcda2f70086b7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 27 Aug 2021 15:10:59 +0100 Subject: [PATCH 14/19] Moving new classes to the bottom. --- torchvision/transforms/autoaugment.py | 197 ++++++++++++++------------ 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index a7642e1a619..4cc0f8d1108 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,24 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AugmentationSpace", "AutoAugment", "TrivialAugment"] - - -class AutoAugmentPolicy(Enum): - """AutoAugment policies learned on different datasets. - Available policies are IMAGENET, CIFAR10 and SVHN. - """ - IMAGENET = "imagenet" - CIFAR10 = "cifar10" - SVHN = "svhn" - - -class AugmentationSpace(Enum): - """The augmentation space to use. - Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment. - """ - AA = "aa" - TA_WIDE = "ta_wide" +__all__ = ["AutoAugmentPolicy", "AutoAugment", "AugmentationSpace", "TrivialAugment"] def _apply_op(img: Tensor, op_name: str, magnitude: float, @@ -66,82 +49,13 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float, return img -class TrivialAugment(torch.nn.Module): - r"""Dataset-independent data-augmentation with TrivialAugment, as described in - `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - augmentation_space (AugmentationSpace): Desired augmentation space enum defined by - :class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``. - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_WIDE, num_magnitude_bins: int = 30, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: - super().__init__() - self.augmentation_space = augmentation_space - self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill - - @staticmethod - def _get_magnitudes(num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: - return { - # name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), - "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), - "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), - "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), - "Color": (torch.linspace(0.0, 0.99, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), - "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(float('nan')), False), - "Equalize": (torch.tensor(float('nan')), False), - "Invert": (torch.tensor(float('nan')), False), - } - - def forward(self, img: Tensor): - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: TrivialAugmented image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - if self.augmentation_space == AugmentationSpace.AA: - op_meta = AutoAugment._get_magnitudes(self.num_magnitude_bins, F.get_image_size(img)) - elif self.augmentation_space == AugmentationSpace.TA_WIDE: - op_meta = self._get_magnitudes(self.num_magnitude_bins) - else: - raise ValueError(f"Provided augmentation_space arguments {self.augmentation_space} not available.") - op_index = int(torch.randint(len(op_meta), (1,)).item()) - op_name = list(op_meta.keys())[op_index] - magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ - if not magnitudes.isnan().all() else 0.0 - if signed and torch.randint(2, (1,)): - magnitude *= -1.0 - - return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + Available policies are IMAGENET, CIFAR10 and SVHN. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" class AutoAugment(torch.nn.Module): @@ -327,3 +241,98 @@ def forward(self, img: Tensor) -> Tensor: def __repr__(self) -> str: return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + + +class AugmentationSpace(Enum): + """The augmentation space to use. + Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment. + """ + AA = "aa" + TA_WIDE = "ta_wide" + + +class TrivialAugment(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + augmentation_space (AugmentationSpace): Desired augmentation space enum defined by + :class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``. + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_WIDE, num_magnitude_bins: int = 30, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.augmentation_space = augmentation_space + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + @staticmethod + def _get_magnitudes(num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(float('nan')), False), + "Equalize": (torch.tensor(float('nan')), False), + "Invert": (torch.tensor(float('nan')), False), + } + + def forward(self, img: Tensor): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: TrivialAugmented image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + if self.augmentation_space == AugmentationSpace.AA: + op_meta = AutoAugment._get_magnitudes(self.num_magnitude_bins, F.get_image_size(img)) + elif self.augmentation_space == AugmentationSpace.TA_WIDE: + op_meta = self._get_magnitudes(self.num_magnitude_bins) + else: + raise ValueError(f"Provided augmentation_space arguments {self.augmentation_space} not available.") + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ + if not magnitudes.isnan().all() else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'augmentation_space={augmentation_space}' + s += ', num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) From 425c52da0474f0d83cd135473b72408701a51bab Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 27 Aug 2021 16:25:08 +0100 Subject: [PATCH 15/19] Specialize to TA to TAwide --- references/classification/presets.py | 4 +-- test/test_transforms.py | 6 ++--- test/test_transforms_tensor.py | 10 +++---- torchvision/transforms/autoaugment.py | 38 +++++++-------------------- 4 files changed, 17 insertions(+), 41 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index b7ee64c5d4a..0ccf835e7c3 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -9,8 +9,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2 if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: - if auto_augment_policy == autoaugment.AugmentationSpace.TA_WIDE.value: - trans.append(autoaugment.TrivialAugment()) + if auto_augment_policy == "ta_wide": + trans.append(autoaugment.TrivialAugmentWide()) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 08a4f3c53cd..2b15c6afdd0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1490,14 +1490,12 @@ def test_autoaugment(policy, fill): transform.__repr__() -@pytest.mark.parametrize('augmentation_space', [space for space in transforms.AugmentationSpace]) @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) -def test_trivialaugment(augmentation_space, fill, num_magnitude_bins): +def test_trivialaugmentwide(fill, num_magnitude_bins): random.seed(42) img = Image.open(GRACE_HOPPER) - transform = transforms.TrivialAugment(augmentation_space=augmentation_space, - fill=fill, num_magnitude_bins=num_magnitude_bins) + transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins) for _ in range(100): img = transform(img) transform.__repr__() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 2bcc3e22485..a057e193d8a 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -540,22 +540,20 @@ def test_autoaugment_save(tmpdir): @pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('augmentation_space', [space for space in T.AugmentationSpace]) @pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) -def test_trivialaugment(device, augmentation_space, fill): +def test_trivialaugmentwide(device, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - s_transform = None - transform = T.TrivialAugment(augmentation_space=augmentation_space, fill=fill) + transform = T.TrivialAugmentWide(fill=fill) s_transform = torch.jit.script(transform) for _ in range(25): _test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -def test_trivialaugment_save(tmpdir): - transform = T.TrivialAugment() +def test_trivialaugmentwide_save(tmpdir): + transform = T.TrivialAugmentWide() s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 4cc0f8d1108..4db0f9fea35 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment", "AugmentationSpace", "TrivialAugment"] +__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugmentWide"] def _apply_op(img: Tensor, op_name: str, magnitude: float, @@ -178,8 +178,7 @@ def _get_transforms( else: raise ValueError("The provided policy {} is not recognized.".format(policy)) - @staticmethod - def _get_magnitudes(num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: return { # name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), @@ -243,24 +242,14 @@ def __repr__(self) -> str: return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) -class AugmentationSpace(Enum): - """The augmentation space to use. - Available spaces are `AA` for AutoAugment and `TA_WIDE` for the TrivialAugment. - """ - AA = "aa" - TA_WIDE = "ta_wide" - - -class TrivialAugment(torch.nn.Module): - r"""Dataset-independent data-augmentation with TrivialAugment, as described in +class TrivialAugmentWide(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. If the image is torch Tensor, it should be of type torch.uint8, and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". Args: - augmentation_space (AugmentationSpace): Desired augmentation space enum defined by - :class:`torchvision.transforms.autoaugment.AugmentationSpace`. Default is ``AugmentationSpace.TA_WIDE``. num_magnitude_bins (int): The number of different magnitude values. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. @@ -269,17 +258,14 @@ class TrivialAugment(torch.nn.Module): image. If given a number, the value is used for all bands respectively. """ - def __init__(self, augmentation_space: AugmentationSpace = AugmentationSpace.TA_WIDE, num_magnitude_bins: int = 30, - interpolation: InterpolationMode = InterpolationMode.NEAREST, + def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None) -> None: super().__init__() - self.augmentation_space = augmentation_space self.num_magnitude_bins = num_magnitude_bins self.interpolation = interpolation self.fill = fill - @staticmethod - def _get_magnitudes(num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: return { # name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), @@ -303,7 +289,7 @@ def forward(self, img: Tensor): img (PIL Image or Tensor): Image to be transformed. Returns: - PIL Image or Tensor: TrivialAugmented image. + PIL Image or Tensor: Transformed image. """ fill = self.fill if isinstance(img, Tensor): @@ -312,12 +298,7 @@ def forward(self, img: Tensor): elif fill is not None: fill = [float(f) for f in fill] - if self.augmentation_space == AugmentationSpace.AA: - op_meta = AutoAugment._get_magnitudes(self.num_magnitude_bins, F.get_image_size(img)) - elif self.augmentation_space == AugmentationSpace.TA_WIDE: - op_meta = self._get_magnitudes(self.num_magnitude_bins) - else: - raise ValueError(f"Provided augmentation_space arguments {self.augmentation_space} not available.") + op_meta = self._get_magnitudes(self.num_magnitude_bins) op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] @@ -330,8 +311,7 @@ def forward(self, img: Tensor): def __repr__(self) -> str: s = self.__class__.__name__ + '(' - s += 'augmentation_space={augmentation_space}' - s += ', num_magnitude_bins={num_magnitude_bins}' + s += 'num_magnitude_bins={num_magnitude_bins}' s += ', interpolation={interpolation}' s += ', fill={fill}' s += ')' From bd2dc1704355f65bc7de7de7c6420929bcbf280f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 2 Sep 2021 16:30:14 +0100 Subject: [PATCH 16/19] Add missing type --- torchvision/transforms/autoaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index ad3d0afcfcb..117030d3a50 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -369,7 +369,7 @@ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: "Invert": (torch.tensor(0.0), False), } - def forward(self, img: Tensor): + def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. From 5770a0393e781ce31a97ad3991d7b9680d0cb17d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 2 Sep 2021 16:40:26 +0100 Subject: [PATCH 17/19] Fixing lint --- test/test_transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 7f4cab6ba21..675e79ac3ba 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1501,6 +1501,7 @@ def test_randaugment(num_ops, magnitude, fill): img = transform(img) transform.__repr__() + @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) @pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) def test_trivialaugmentwide(fill, num_magnitude_bins): From 46f886c8c0264446971d787c10cf722e8bed3d20 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 2 Sep 2021 17:13:00 +0100 Subject: [PATCH 18/19] Fix doc --- gallery/plot_transforms.py | 2 +- torchvision/transforms/autoaugment.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index 68ffae16a0f..fe5864ebad5 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -255,7 +255,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): #################################### # TrivialAugmentWide -# ~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data. augmenter = T.TrivialAugmentWide() imgs = [augmenter(orig_img) for _ in range(4)] diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 117030d3a50..4f82bc6acd5 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -330,17 +330,17 @@ def __repr__(self) -> str: class TrivialAugmentWide(torch.nn.Module): r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. """ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, From 30bbae9a6ff1fac1b71025389f46686ad9cd0f42 Mon Sep 17 00:00:00 2001 From: Samuel Mueller Date: Mon, 6 Sep 2021 10:16:32 +0200 Subject: [PATCH 19/19] Fix search space of TrivialAugment. --- torchvision/transforms/autoaugment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 4f82bc6acd5..44c7990482b 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float, img = F.equalize(img) elif op_name == "Invert": img = F.invert(img) + elif op_name == "Identity": + pass else: raise ValueError("The provided operator {} is not recognized.".format(op_name)) return img @@ -353,6 +355,7 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), @@ -366,7 +369,6 @@ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), "AutoContrast": (torch.tensor(0.0), False), "Equalize": (torch.tensor(0.0), False), - "Invert": (torch.tensor(0.0), False), } def forward(self, img: Tensor) -> Tensor: