From 31fadbee7d1a65cd73ae43dfd4ac6e97e7ca7b01 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 29 Oct 2021 10:32:46 +0100 Subject: [PATCH 1/7] Adding multiweight support for shufflenetv2 prototype models --- torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/shufflenetv2.py | 121 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 torchvision/prototype/models/shufflenetv2.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index a187af7f090..399280eaff7 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -5,6 +5,7 @@ from .efficientnet import * from .mobilenetv3 import * from .mnasnet import * +from .shufflenetv2 import * from . import detection from . import quantization from . import segmentation diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py new file mode 100644 index 00000000000..d6d02873051 --- /dev/null +++ b/torchvision/prototype/models/shufflenetv2.py @@ -0,0 +1,121 @@ +import warnings +from functools import partial +from typing import Any, Optional + +from torchvision.transforms.functional import InterpolationMode + +from ...models.shufflenetv2 import ShuffleNetV2 +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = [ + "ShuffleNetV2", + "ShuffleNetV2_x0_5Weights", + "ShuffleNetV2_x1_0Weights", + "ShuffleNetV2_x1_5Weights", + "ShuffleNetV2_x2_0Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + + +def _shufflenetv2( + weights: Optional[Weights], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = ShuffleNetV2(*args, **kwargs) + + if weights is not None: + 
model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} + + +class ShuffleNetV2_x0_5Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 69.362, + "acc@5": 88.316, + }, + ) + + +class ShuffleNetV2_x1_0Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 60.552, + "acc@5": 81.746, + }, + ) + + +class ShuffleNetV2_x1_5Weights(Weights): + pass + + +class ShuffleNetV2_x2_0Weights(Weights): + pass + + +def shufflenet_v2_x0_5( + weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x0_5Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + + +def shufflenet_v2_x1_0( + weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x1_0Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + + +def shufflenet_v2_x1_5( + weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: 
Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x1_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x1_5Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + + +def shufflenet_v2_x2_0( + weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x2_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x2_0Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) From 1e578b7fe05ff5a18201df8a36b552a18fabcd08 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 29 Oct 2021 10:42:31 +0100 Subject: [PATCH 2/7] Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 31fadbee7d1a65cd73ae43dfd4ac6e97e7ca7b01. --- torchvision/prototype/models/__init__.py | 1 - torchvision/prototype/models/shufflenetv2.py | 121 ------------------- 2 files changed, 122 deletions(-) delete mode 100644 torchvision/prototype/models/shufflenetv2.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 399280eaff7..a187af7f090 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -5,7 +5,6 @@ from .efficientnet import * from .mobilenetv3 import * from .mnasnet import * -from .shufflenetv2 import * from . import detection from . import quantization from . 
import segmentation diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py deleted file mode 100644 index d6d02873051..00000000000 --- a/torchvision/prototype/models/shufflenetv2.py +++ /dev/null @@ -1,121 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional - -from torchvision.transforms.functional import InterpolationMode - -from ...models.shufflenetv2 import ShuffleNetV2 -from ..transforms.presets import ImageNetEval -from ._api import Weights, WeightEntry -from ._meta import _IMAGENET_CATEGORIES - - -__all__ = [ - "ShuffleNetV2", - "ShuffleNetV2_x0_5Weights", - "ShuffleNetV2_x1_0Weights", - "ShuffleNetV2_x1_5Weights", - "ShuffleNetV2_x2_0Weights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", - "shufflenet_v2_x1_5", - "shufflenet_v2_x2_0", -] - - -def _shufflenetv2( - weights: Optional[Weights], - progress: bool, - *args: Any, - **kwargs: Any, -) -> ShuffleNetV2: - if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) - - model = ShuffleNetV2(*args, **kwargs) - - if weights is not None: - model.load_state_dict(weights.state_dict(progress=progress)) - - return model - - -_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} - - -class ShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( - url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageNetEval, crop_size=224), - meta={ - **_common_meta, - "recipe": "", - "acc@1": 69.362, - "acc@5": 88.316, - }, - ) - - -class ShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( - url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageNetEval, crop_size=224), - meta={ - **_common_meta, - "recipe": "", - "acc@1": 60.552, - "acc@5": 81.746, - }, - ) - - -class ShuffleNetV2_x1_5Weights(Weights): - pass - - -class 
ShuffleNetV2_x2_0Weights(Weights): - pass - - -def shufflenet_v2_x0_5( - weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x0_5Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) - - -def shufflenet_v2_x1_0( - weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x1_0Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) - - -def shufflenet_v2_x1_5( - weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x1_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x1_5Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) - - -def shufflenet_v2_x2_0( - weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x2_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x2_0Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], 
**kwargs) From 4e3d900f796c1e3e667312087e77956ca4a4c017 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 29 Oct 2021 10:59:31 +0100 Subject: [PATCH 3/7] Adding multiweight support for shufflenetv2 prototype models --- torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/shufflenetv2.py | 121 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 torchvision/prototype/models/shufflenetv2.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 264d787d40e..69fe4310606 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -7,6 +7,7 @@ from .mobilenetv2 import * from .mnasnet import * from .regnet import * +from .shufflenetv2 import * from . import detection from . import quantization from . import segmentation diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py new file mode 100644 index 00000000000..d6d02873051 --- /dev/null +++ b/torchvision/prototype/models/shufflenetv2.py @@ -0,0 +1,121 @@ +import warnings +from functools import partial +from typing import Any, Optional + +from torchvision.transforms.functional import InterpolationMode + +from ...models.shufflenetv2 import ShuffleNetV2 +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = [ + "ShuffleNetV2", + "ShuffleNetV2_x0_5Weights", + "ShuffleNetV2_x1_0Weights", + "ShuffleNetV2_x1_5Weights", + "ShuffleNetV2_x2_0Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + + +def _shufflenetv2( + weights: Optional[Weights], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = ShuffleNetV2(*args, **kwargs) + + if weights is not None: + 
model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} + + +class ShuffleNetV2_x0_5Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 69.362, + "acc@5": 88.316, + }, + ) + + +class ShuffleNetV2_x1_0Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "", + "acc@1": 60.552, + "acc@5": 81.746, + }, + ) + + +class ShuffleNetV2_x1_5Weights(Weights): + pass + + +class ShuffleNetV2_x2_0Weights(Weights): + pass + + +def shufflenet_v2_x0_5( + weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x0_5Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + + +def shufflenet_v2_x1_0( + weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x1_0Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + + +def shufflenet_v2_x1_5( + weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: 
Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x1_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x1_5Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + + +def shufflenet_v2_x2_0( + weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ShuffleNetV2_x2_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ShuffleNetV2_x2_0Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) From 615b612933c1dea2da471ac5678bd4ec97e5255f Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 29 Oct 2021 11:14:45 +0100 Subject: [PATCH 4/7] Revert "Adding multiweight support for shufflenetv2 prototype models" This reverts commit 4e3d900f796c1e3e667312087e77956ca4a4c017. --- torchvision/prototype/models/__init__.py | 1 - torchvision/prototype/models/shufflenetv2.py | 121 ------------------- 2 files changed, 122 deletions(-) delete mode 100644 torchvision/prototype/models/shufflenetv2.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 69fe4310606..264d787d40e 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -7,7 +7,6 @@ from .mobilenetv2 import * from .mnasnet import * from .regnet import * -from .shufflenetv2 import * from . import detection from . import quantization from . 
import segmentation diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py deleted file mode 100644 index d6d02873051..00000000000 --- a/torchvision/prototype/models/shufflenetv2.py +++ /dev/null @@ -1,121 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional - -from torchvision.transforms.functional import InterpolationMode - -from ...models.shufflenetv2 import ShuffleNetV2 -from ..transforms.presets import ImageNetEval -from ._api import Weights, WeightEntry -from ._meta import _IMAGENET_CATEGORIES - - -__all__ = [ - "ShuffleNetV2", - "ShuffleNetV2_x0_5Weights", - "ShuffleNetV2_x1_0Weights", - "ShuffleNetV2_x1_5Weights", - "ShuffleNetV2_x2_0Weights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", - "shufflenet_v2_x1_5", - "shufflenet_v2_x2_0", -] - - -def _shufflenetv2( - weights: Optional[Weights], - progress: bool, - *args: Any, - **kwargs: Any, -) -> ShuffleNetV2: - if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) - - model = ShuffleNetV2(*args, **kwargs) - - if weights is not None: - model.load_state_dict(weights.state_dict(progress=progress)) - - return model - - -_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} - - -class ShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( - url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageNetEval, crop_size=224), - meta={ - **_common_meta, - "recipe": "", - "acc@1": 69.362, - "acc@5": 88.316, - }, - ) - - -class ShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( - url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageNetEval, crop_size=224), - meta={ - **_common_meta, - "recipe": "", - "acc@1": 60.552, - "acc@5": 81.746, - }, - ) - - -class ShuffleNetV2_x1_5Weights(Weights): - pass - - -class 
ShuffleNetV2_x2_0Weights(Weights): - pass - - -def shufflenet_v2_x0_5( - weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x0_5Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) - - -def shufflenet_v2_x1_0( - weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x1_0Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) - - -def shufflenet_v2_x1_5( - weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x1_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x1_5Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) - - -def shufflenet_v2_x2_0( - weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - if "pretrained" in kwargs: - warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = ShuffleNetV2_x2_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None - weights = ShuffleNetV2_x2_0Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], 
**kwargs) From 069bba4880e360615b527e1b6f5492d47020ea57 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 19 Jan 2022 14:22:55 +0000 Subject: [PATCH 5/7] Add RenderedSST2 dataset --- docs/source/datasets.rst | 1 + test/test_datasets.py | 22 +++++++ torchvision/datasets/__init__.py | 2 + torchvision/datasets/rendered_sst2.py | 91 +++++++++++++++++++++++++++ 4 files changed, 116 insertions(+) create mode 100644 torchvision/datasets/rendered_sst2.py diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 7f09ff245ca..a4027d36761 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -60,6 +60,7 @@ You can also create your own datasets using the provided :ref:`base classes `_. + + Rendered SST2 is a image classification dataset used to evaluate the models capability on optical + character recognition. This dataset was generated bu rendering sentences in the Standford Sentiment + Treebank v2 dataset. + + This dataset contains two classes (positive and negative) and is divided in three splits: a train + split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images + (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative). + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), `"valid"` and ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
+ """ + + _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz" + _MD5 = "2384d08e9dcfa4bd55b324e610496ee5" + + def __init__( + self, + root: str, + split: str = "train", + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "valid", "test")) + self._base_folder = Path(self.root) / "rendered-sst2" + self.classes = ["negative", "positive"] + self.class_to_idx = {"negative": 0, "positive": 1} + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self._labels = [] + self._image_files = [] + + for p in (self._base_folder / self._split).glob("**/*.png"): + self._labels.append(self.class_to_idx[p.parent.name]) + self._image_files.append(p) + print(self._labels) + print(self._image_files) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def extra_repr(self) -> str: + return f"split={self._split}" + + def _check_exists(self) -> bool: + for class_label in set(self.classes): + if not ( + (self._base_folder / self._split / class_label).exists() + and (self._base_folder / self._split / class_label).is_dir() + ): + return False + return True + + def _download(self) -> None: + if self._check_exists(): + return + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) From 78f5e4593e86f335d4b7c42d679fceafe0ec8268 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 19 Jan 2022 17:14:05 +0000 Subject: [PATCH 
6/7] Address PR comments --- test/test_datasets.py | 12 ++++++------ torchvision/datasets/rendered_sst2.py | 23 +++++++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index fadffc8393b..4d5ac3f3c43 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2667,24 +2667,24 @@ def inject_fake_data(self, tmpdir: str, config): class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.RenderedSST2 - FEATURE_TYPES = (PIL.Image.Image, int) - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test")) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"} def inject_fake_data(self, tmpdir: str, config): root_folder = pathlib.Path(tmpdir) / "rendered-sst2" - image_folder = root_folder / config["split"] + image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]] - num_images_per_class = 5 + num_images_per_class = {"train": 5, "test": 6, "val": 7} sampled_classes = ["positive", "negative"] for cls in sampled_classes: datasets_utils.create_image_folder( image_folder, cls, file_name_fn=lambda idx: f"{idx}.png", - num_examples=num_images_per_class, + num_examples=num_images_per_class[config["split"]], ) - return len(sampled_classes) * num_images_per_class + return len(sampled_classes) * num_images_per_class[config["split"]] if __name__ == "__main__": diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 122c8c398f5..72adbcbfb93 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -10,8 +10,8 @@ class RenderedSST2(VisionDataset): """`The Rendered SST2 Dataset `_. - Rendered SST2 is a image classification dataset used to evaluate the models capability on optical - character recognition. 
This dataset was generated bu rendering sentences in the Standford Sentiment + Rendered SST2 is an image classification dataset used to evaluate the model's capability on optical + character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment Treebank v2 dataset. This dataset contains two classes (positive and negative) and is divided in three splits: a train @@ -20,7 +20,10 @@ class RenderedSST2(VisionDataset): Args: root (string): Root directory of the dataset. - split (string, optional): The dataset split, supports ``"train"`` (default), `"valid"` and ``"test"``. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
@@ -33,12 +36,13 @@ def __init__( self, root: str, split: str = "train", - download: bool = True, + download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) - self._split = verify_str_arg(split, "split", ("train", "valid", "test")) + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + self._split_to_folder = {"train": "train", "val": "valid", "test": "test"} self._base_folder = Path(self.root) / "rendered-sst2" self.classes = ["negative", "positive"] self.class_to_idx = {"negative": 0, "positive": 1} @@ -52,11 +56,9 @@ def __init__( self._labels = [] self._image_files = [] - for p in (self._base_folder / self._split).glob("**/*.png"): + for p in (self._base_folder / self._split_to_folder[self._split]).glob("**/*.png"): self._labels.append(self.class_to_idx[p.parent.name]) self._image_files.append(p) - print(self._labels) - print(self._image_files) def __len__(self) -> int: return len(self._image_files) @@ -78,10 +80,7 @@ def extra_repr(self) -> str: def _check_exists(self) -> bool: for class_label in set(self.classes): - if not ( - (self._base_folder / self._split / class_label).exists() - and (self._base_folder / self._split / class_label).is_dir() - ): + if not (self._base_folder / self._split / class_label).is_dir(): return False return True From e6c95ada620f5e56c16238edd91dd27da52ac358 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Wed, 19 Jan 2022 17:19:04 +0000 Subject: [PATCH 7/7] Fix bug in dataset verification --- torchvision/datasets/rendered_sst2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 72adbcbfb93..0001c6f6472 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -80,7 +80,7 @@ def extra_repr(self) -> str: def _check_exists(self) -> bool: 
for class_label in set(self.classes): - if not (self._base_folder / self._split / class_label).is_dir(): + if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir(): return False return True