diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 08c841399c2..036a2ac9a12 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -46,6 +46,7 @@ You can also create your own datasets using the provided :ref:`base classes `_ Dataset. + + .. warning:: + + This class needs `scipy `_ to load target files from `.mat` format. + + Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The + flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of + between 40 and 258 images. + + The images have large scale, pose and light variations. In addition, there are categories that + have large variations within the category, and several very similar categories. + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image and returns a + transformed version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" + _file_dict = { # filename, md5 + "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"), + "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"), + "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"), + } + _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"} + + def __init__( + self, + root: str, + split: str = "train", + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + self._base_folder = Path(self.root) / "flowers-102" + self._images_folder = self._base_folder / "jpg" + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + from scipy.io import loadmat + + set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True) + image_ids = set_ids[self._splits_map[self._split]].tolist() + + labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True) + image_id_to_label = dict(enumerate(labels["labels"].tolist(), 1)) + + self._labels = [] + self._image_files = [] + for image_id in image_ids: + self._labels.append(image_id_to_label[image_id]) + self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg") + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def extra_repr(self) -> str: + return f"split={self._split}" + + def _check_integrity(self): + if not (self._images_folder.exists() and self._images_folder.is_dir()): + return False + + for id in ["label", "setid"]: + filename, md5 = self._file_dict[id] + if not check_integrity(str(self._base_folder / filename), md5): + return False + return True + + def download(self): + if self._check_integrity(): + return + download_and_extract_archive( + f"{self._download_url_prefix}{self._file_dict['image'][0]}", + str(self._base_folder), + md5=self._file_dict["image"][1], + ) + for id in ["label", "setid"]: + filename, md5 = self._file_dict[id] + download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)