diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index ab0425aee4e..306ed1dac6a 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -75,6 +75,7 @@ You can also create your own datasets using the provided :ref:`base classes
+    StanfordCars
diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py
new file mode 100644
--- /dev/null
+++ b/torchvision/datasets/stanford_cars.py
@@ -0,0 +1,123 @@
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly in a 50-50 split.
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is not
+            downloaded again."""
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        """Returns pil_image and class_id for the given index"""
+        image_path, target = self._samples[idx]
+        pil_image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            pil_image = self.transform(pil_image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return pil_image, target
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        # The devkit archive provides the annotation .mat files and the class-name metadata.
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=str(self._base_folder),
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=str(self._base_folder),
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=str(self._base_folder),
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
+            # The test-set labels ship in a separate annotation file.
+            download_url(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=str(self._base_folder),
+                md5="b0a2b23655a3edd16d84508592a98d10",
+            )
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
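For context, here is a minimal usage sketch of the new class (not part of the patch). It assumes the class is exported from `torchvision.datasets` once this change is installed; the `./data` root, the 224x224 resize, and the batch size are placeholder choices for illustration only.

```python
import torch
from torchvision import transforms
from torchvision.datasets import StanfordCars

# Basic preprocessing; the resize value is an arbitrary choice for illustration.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Download (if needed) into ./data/stanford_cars and load the training split.
train_set = StanfordCars(root="./data", split="train", transform=transform, download=True)

print(len(train_set))            # 8144 training images
image, target = train_set[0]     # image is a 3x224x224 tensor, target is an int in [0, 195]
print(train_set.classes[target]) # human-readable class name from cars_meta.mat

# The dataset plugs into a standard DataLoader.
loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
```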