diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index c8baf719eea..1b31fd5ee0d 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -62,6 +62,7 @@ You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
+   OxfordIIITPet
diff --git a/torchvision/datasets/oxford_iiit_pet.py b/torchvision/datasets/oxford_iiit_pet.py
new file mode 100644
--- /dev/null
+++ b/torchvision/datasets/oxford_iiit_pet.py
+import os
+import os.path
+import pathlib
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+    """`Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+        target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
+            ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+            - ``category`` (int): Label for one of the 37 pet categories.
+            - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+            If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If the dataset is already downloaded, it is not downloaded again.
+    """
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+    _VALID_TARGET_TYPES = ("category", "segmentation")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        target_types: Union[Sequence[str], str] = "category",
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        self._split = verify_str_arg(split, "split", ("trainval", "test"))
+        if isinstance(target_types, str):
+            target_types = [target_types]
+        self._target_types = [
+            verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+        ]
+
+        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+        self._images_folder = self._base_folder / "images"
+        self._anns_folder = self._base_folder / "annotations"
+        self._segs_folder = self._anns_folder / "trimaps"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        image_ids = []
+        self._labels = []
+        # Each line of {split}.txt reads "<image_id> <label> <species> <breed_id>".
+        with open(self._anns_folder / f"{self._split}.txt") as file:
+            for line in file:
+                image_id, label, *_ = line.strip().split()
+                image_ids.append(image_id)
+                self._labels.append(int(label) - 1)
+
+        self.classes = [
+            " ".join(part.title() for part in raw_cls.split("_"))
+            for raw_cls, _ in sorted(
+                {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                key=lambda image_id_and_label: image_id_and_label[1],
+            )
+        ]
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+        self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image = Image.open(self._images[idx]).convert("RGB")
+
+        target: Any = []
+        for target_type in self._target_types:
+            if target_type == "category":
+                target.append(self._labels[idx])
+            else:  # target_type == "segmentation"
+                target.append(Image.open(self._segs[idx]))
+
+        if not target:
+            target = None
+        elif len(target) == 1:
+            target = target[0]
+        else:
+            target = tuple(target)
+
+        if self.transforms:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        for folder in (self._images_folder, self._anns_folder):
+            if not (os.path.exists(folder) and os.path.isdir(folder)):
+                return False
+        else:
+            return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for url, md5 in self._RESOURCES:
+            download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
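A minimal usage sketch for the class above, assuming ``data`` is an arbitrary local root; with ``download=True`` both archives are fetched and extracted into ``data/oxford-iiit-pet``:

    from torchvision.datasets import OxfordIIITPet

    # Requesting both target types makes __getitem__ return a (label, trimap) tuple.
    dataset = OxfordIIITPet("data", split="trainval", target_types=["category", "segmentation"], download=True)
    image, (label, trimap) = dataset[0]  # PIL image, (int label, PIL trimap)
    print(len(dataset), dataset.classes[label])  # trainval.txt lists 3680 images
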
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 009573d44ad..012bfc3ca18 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -7,6 +7,7 @@ from .fer2013 import FER2013
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
+from .oxford_iiit_pet import OxfordIIITPet
 from .sbd import SBD
 from .semeion import SEMEION
 from .voc import VOC
diff --git a/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories b/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories
new file mode 100644
index 00000000000..36d29465b04
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories
@@ -0,0 +1,37 @@
+Abyssinian
+American Bulldog
+American Pit Bull Terrier
+Basset Hound
+Beagle
+Bengal
+Birman
+Bombay
+Boxer
+British Shorthair
+Chihuahua
+Egyptian Mau
+English Cocker Spaniel
+English Setter
+German Shorthaired
+Great Pyrenees
+Havanese
+Japanese Chin
+Keeshond
+Leonberger
+Maine Coon
+Miniature Pinscher
+Newfoundland
+Persian
+Pomeranian
+Pug
+Ragdoll
+Russian Blue
+Saint Bernard
+Samoyed
+Scottish Terrier
+Shiba Inu
+Siamese
+Sphynx
+Staffordshire Bull Terrier
+Wheaten Terrier
+Yorkshire Terrier
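Once registered via the __init__.py hook above, the prototype dataset is consumed as an IterDataPipe of sample dicts rather than an indexable dataset. A rough sketch, assuming the prototype ``datasets.load`` entry point of this API series:

    from torchvision.prototype import datasets

    dp = datasets.load("oxford-iiit-pet", split="trainval")
    sample = next(iter(dp))
    # Keys assembled in _collate_and_decode_sample below:
    print(sorted(sample.keys()))
    # ['image', 'image_path', 'label', 'segmentation', 'segmentation_path', 'species']
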
diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
new file mode 100644
index 00000000000..4e43613715e
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
@@ -0,0 +1,150 @@
+import functools
+import io
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    INFINITE_BUFFER_SIZE,
+    hint_sharding,
+    hint_shuffling,
+    getitem,
+    path_accessor,
+    path_comparator,
+)
+from torchvision.prototype.features import Label
+
+
+class OxfordIIITPet(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "oxford-iiit-pet",
+            type=DatasetType.IMAGE,
+            homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
+            valid_options=dict(
+                split=("trainval", "test"),
+            ),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        images = HttpResource(
+            "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
+            sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
+            decompress=True,
+        )
+        anns = HttpResource(
+            "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
+            sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
+            decompress=True,
+        )
+        return [images, anns]
+
+    def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
+        # Routes the split/classification .txt files to branch 0 and the
+        # trimap .png files to branch 1; everything else is dropped.
+        return {
+            "annotations": 0,
+            "trimaps": 1,
+        }.get(pathlib.Path(data[0]).parent.name)
+
+    def _filter_images(self, data: Tuple[str, Any]) -> bool:
+        return pathlib.Path(data[0]).suffix == ".jpg"
+
+    def _filter_segmentations(self, data: Tuple[str, Any]) -> bool:
+        return not pathlib.Path(data[0]).name.startswith(".")
+
+    def _decode_classification_data(self, data: Dict[str, str]) -> Dict[str, Any]:
+        label_idx = int(data["label"]) - 1
+        return dict(
+            label=Label(label_idx, category=self.info.categories[label_idx]),
+            species="cat" if data["species"] == "1" else "dog",
+        )
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[Tuple[Dict[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        ann_data, image_data = data
+        classification_data, segmentation_data = ann_data
+        segmentation_path, segmentation_buffer = segmentation_data
+        image_path, image_buffer = image_data
+
+        return dict(
+            self._decode_classification_data(classification_data),
+            segmentation_path=segmentation_path,
+            segmentation=decoder(segmentation_buffer) if decoder else segmentation_buffer,
+            image_path=image_path,
+            image=decoder(image_buffer) if decoder else image_buffer,
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        images_dp, anns_dp = resource_dps
+
+        images_dp = Filter(images_dp, self._filter_images)
+
+        split_and_classification_dp, segmentations_dp = Demultiplexer(
+            anns_dp,
+            2,
+            self._classify_anns,
+            drop_none=True,
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+
+        split_and_classification_dp = Filter(
+            split_and_classification_dp, path_comparator("name", f"{config.split}.txt")
+        )
+        split_and_classification_dp = CSVDictParser(
+            split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
+        )
+        split_and_classification_dp = hint_sharding(split_and_classification_dp)
+        split_and_classification_dp = hint_shuffling(split_and_classification_dp)
+
+        segmentations_dp = Filter(segmentations_dp, self._filter_segmentations)
+
+        anns_dp = IterKeyZipper(
+            split_and_classification_dp,
+            segmentations_dp,
+            key_fn=getitem("image_id"),
+            ref_key_fn=path_accessor("stem"),
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+
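+        # `anns_dp` now yields (classification row, (trimap path, trimap buffer))
+        # tuples keyed by image id; the second IterKeyZipper below attaches the
+        # JPEG itself by matching "image_id" against each image file's stem.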
+        dp = IterKeyZipper(
+            anns_dp,
+            images_dp,
+            key_fn=getitem(0, "image_id"),
+            ref_key_fn=path_accessor("stem"),
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+
+    def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
+        return self._classify_anns(data) == 0
+
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
+        config = self.default_config
+        dp = self.resources(config)[1].load(pathlib.Path(root) / self.name)
+        dp = Filter(dp, self._filter_split_and_classification_anns)
+        dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
+        dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
+        raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
+        raw_categories, _ = zip(
+            *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
+        )
+        return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]
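The oxford-iiit-pet.categories file earlier in the diff is the sorted output of _generate_categories. A rough regeneration sketch, assuming prototype datasets instantiate without arguments and that the annotations archive is already available under ``root/oxford-iiit-pet`` (otherwise HttpResource.load would need to fetch it first); ``root`` itself is an arbitrary local path:

    import pathlib
    from torchvision.prototype.datasets._builtin import OxfordIIITPet

    root = pathlib.Path("datasets")  # holds oxford-iiit-pet/annotations.tar.gz
    categories = OxfordIIITPet()._generate_categories(root)
    print(len(categories))  # 37, one line per breed in oxford-iiit-pet.categories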