diff --git a/test/test_datasets.py b/test/test_datasets.py index 55681c8b378..79b23488ded 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2325,5 +2325,37 @@ def inject_fake_data(self, tmpdir: str, config): return total_number_of_examples +class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CLEVRClassification + FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + + def inject_fake_data(self, tmpdir, config): + data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0" + + images_folder = data_folder / "images" + image_files = datasets_utils.create_image_folder( + images_folder, config["split"], lambda idx: f"CLEVR_{config['split']}_{idx:06d}.png", num_examples=5 + ) + + scenes_folder = data_folder / "scenes" + scenes_folder.mkdir() + if config["split"] != "test": + with open(scenes_folder / f"CLEVR_{config['split']}_scenes.json", "w") as file: + json.dump( + dict( + info=dict(), + scenes=[ + dict(image_filename=image_file.name, objects=[dict()] * int(torch.randint(10, ()))) + for image_file in image_files + ], + ), + file, + ) + + return len(image_files) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 0be209250c9..c26d4ce2928 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -3,6 +3,7 @@ from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 from .cityscapes import Cityscapes +from .clevr import CLEVRClassification from .coco import CocoCaptions, CocoDetection from .dtd import DTD from .fakedata import FakeData @@ -85,4 +86,5 @@ "DTD", "FER2013", "GTSRB", + "CLEVRClassification", ) diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py new file mode 100644 index 00000000000..7ba5ca6cc47 --- /dev/null +++ b/torchvision/datasets/clevr.py @@ -0,0 +1,88 @@ +import json +import pathlib +from typing import Any, Callable, Optional, Tuple, List +from urllib.parse import urlparse + +from PIL import Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class CLEVRClassification(VisionDataset): + """`CLEVR `_ classification dataset. + + The number of objects in a scene are used as label. + + Args: + root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is + set to True. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in them target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If + dataset is already downloaded, it is not downloaded again. + """ + + _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip" + _MD5 = "b11922020e72d0cd9154779b2d3d07d2" + + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, + ) -> None: + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + super().__init__(root, transform=transform, target_transform=target_transform) + self._base_folder = pathlib.Path(self.root) / "clevr" + self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*")) + + self._labels: List[Optional[int]] + if self._split != "test": + with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file: + content = json.load(file) + num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]} + self._labels = [num_objects[image_file.name] for image_file in self._image_files] + else: + self._labels = [None] * len(self._image_files) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_file = self._image_files[idx] + label = self._labels[idx] + + image = Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _check_exists(self) -> bool: + return self._data_folder.exists() and self._data_folder.is_dir() + + def _download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5) + + def extra_repr(self) -> str: + return f"split={self._split}" diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 465f8ef79ee..009573d44ad 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -1,6 +1,7 @@ from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import Cifar10, Cifar100 +from .clevr import CLEVR from .coco import Coco from .dtd import DTD from .fer2013 import FER2013 diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py new file mode 100644 index 00000000000..447c1b5190d --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -0,0 +1,110 @@ +import functools +import io +import pathlib +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, + DatasetType, +) +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + hint_sharding, + hint_shuffling, + path_comparator, + path_accessor, + getitem, +) +from torchvision.prototype.features import Label + + +class CLEVR(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "clevr", + type=DatasetType.IMAGE, + homepage="https://cs.stanford.edu/people/jcjohns/clevr/", + valid_options=dict(split=("train", "val", "test")), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + archive = HttpResource( + "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip", + sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1", + ) + return [archive] + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: + path = pathlib.Path(data[0]) + if path.parents[1].name == "images": + return 0 + elif path.parent.name == "scenes": + return 1 + else: + return None + + def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool: + key, _ = data + return key == "scenes" + + def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]: + return data, None + + def _collate_and_decode_sample( + self, + data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> Dict[str, Any]: + image_data, scenes_data = data + path, buffer = image_data + + return dict( + path=path, + image=decoder(buffer) if decoder else buffer, + label=Label(len(scenes_data["objects"])) if scenes_data else None, + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + images_dp, scenes_dp = Demultiplexer( + archive_dp, + 2, + self._classify_archive, + drop_none=True, + buffer_size=INFINITE_BUFFER_SIZE, + ) + + images_dp = Filter(images_dp, path_comparator("parent.name", config.split)) + images_dp = hint_sharding(images_dp) + images_dp = hint_shuffling(images_dp) + + if config.split != "test": + scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json")) + scenes_dp = JsonParser(scenes_dp) + scenes_dp = Mapper(scenes_dp, getitem(1, "scenes")) + scenes_dp = UnBatcher(scenes_dp) + + dp = IterKeyZipper( + images_dp, + scenes_dp, + key_fn=path_accessor("name"), + ref_key_fn=getitem("image_filename"), + buffer_size=INFINITE_BUFFER_SIZE, + ) + else: + dp = Mapper(images_dp, self._add_empty_anns) + + return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index d0071206dd8..e21e8ffd25f 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -108,7 +108,7 @@ def __iter__(self) -> Iterator[Tuple[int, D]]: yield from enumerate(self.datapipe, self.start) -def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any: +def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any: for item in items: obj = obj[item] return obj @@ -118,8 +118,14 @@ def getitem(*items: Any) -> Callable[[Any], Any]: return functools.partial(_getitem_closure, items=items) +def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any: + for attr in attrs: + obj = getattr(obj, attr) + return obj + + def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D: - return cast(D, getattr(path, name)) + return cast(D, _getattr_closure(path, attrs=name.split("."))) def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: