diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index ceb517ced8f..cb02f2bcaa3 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -149,6 +149,13 @@ Kinetics-400 :members: __getitem__ :special-members: +KITTI +~~~~~~~~~ + +.. autoclass:: Kitti + :members: __getitem__ + :special-members: + KMNIST ~~~~~~~~~~~~~ diff --git a/test/test_datasets.py b/test/test_datasets.py index db80b55a90f..1c45af1163c 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1702,5 +1702,41 @@ def test_classes(self, config): self.assertSequenceEqual(dataset.classes, info["classes"]) +class KittiTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti + FEATURE_TYPES = (PIL.Image.Image, (list, type(None))) # test split returns None as target + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = os.path.join(tmpdir, "Kitti", "raw") + os.makedirs(kitti_dir) + + split_to_num_examples = { + True: 1, + False: 2, + } + + # We need to create all folders(training and testing). + for is_training in (True, False): + num_examples = split_to_num_examples[is_training] + + datasets_utils.create_image_folder( + root=kitti_dir, + name=os.path.join("training" if is_training else "testing", "image_2"), + file_name_fn=lambda image_idx: f"{image_idx:06d}.png", + num_examples=num_examples, + ) + if is_training: + for image_idx in range(num_examples): + target_file_dir = os.path.join(kitti_dir, "training", "label_2") + os.makedirs(target_file_dir) + target_file_name = os.path.join(target_file_dir, f"{image_idx:06d}.txt") + target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa + with open(target_file_name, "w") as target_file: + target_file.write(target_contents) + + return split_to_num_examples[config["train"]] + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 0ce0fd6bd60..b60fc7c7964 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -24,6 +24,7 @@ from .hmdb51 import HMDB51 from .ucf101 import UCF101 from .places365 import Places365 +from .kitti import Kitti __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', @@ -34,4 +35,5 @@ 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', 'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset', 'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101', - 'Places365') + 'Places365', 'Kitti', + ) diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py new file mode 100644 index 00000000000..8db2e45b715 --- /dev/null +++ b/torchvision/datasets/kitti.py @@ -0,0 +1,161 @@ +import csv +import os +from typing import Any, Callable, List, Optional, Tuple + +from PIL import Image + +from .utils import download_and_extract_archive +from .vision import VisionDataset + + +class Kitti(VisionDataset): + """`KITTI `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + Expects the following folder structure if download=False: + + .. code:: + + + └── Kitti + └─ raw + ├── training + | ├── image_2 + | └── label_2 + └── testing + └── image_2 + train (bool, optional): Use ``train`` split if true, else ``test`` split. + Defaults to ``train``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample + and its target as entry and returns a transformed version. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + + data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/" + resources = [ + "data_object_image_2.zip", + "data_object_label_2.zip", + ] + image_dir_name = "image_2" + labels_dir_name = "label_2" + + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, + download: bool = False, + ): + super().__init__( + root, + transform=transform, + target_transform=target_transform, + transforms=transforms, + ) + self.images = [] + self.targets = [] + self.root = root + self.train = train + self._location = "training" if self.train else "testing" + + if download: + self.download() + if not self._check_exists(): + raise RuntimeError( + "Dataset not found. You may use download=True to download it." + ) + + image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name) + if self.train: + labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name) + for img_file in os.listdir(image_dir): + self.images.append(os.path.join(image_dir, img_file)) + if self.train: + self.targets.append( + os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt") + ) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Get item at a given index. + + Args: + index (int): Index + Returns: + tuple: (image, target), where + target is a list of dictionaries with the following keys: + + - type: str + - truncated: float + - occluded: int + - alpha: float + - bbox: float[4] + - dimensions: float[3] + - locations: float[3] + - rotation_y: float + + """ + image = Image.open(self.images[index]) + target = self._parse_target(index) if self.train else None + if self.transforms: + image, target = self.transforms(image, target) + return image, target + + def _parse_target(self, index: int) -> List: + target = [] + with open(self.targets[index]) as inp: + content = csv.reader(inp, delimiter=" ") + for line in content: + target.append({ + "type": line[0], + "truncated": float(line[1]), + "occluded": int(line[2]), + "alpha": float(line[3]), + "bbox": [float(x) for x in line[4:8]], + "dimensions": [float(x) for x in line[8:11]], + "location": [float(x) for x in line[11:14]], + "rotation_y": float(line[14]), + }) + return target + + def __len__(self) -> int: + return len(self.images) + + @property + def _raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "raw") + + def _check_exists(self) -> bool: + """Check if the data directory exists.""" + folders = [self.image_dir_name] + if self.train: + folders.append(self.labels_dir_name) + return all( + os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) + for fname in folders + ) + + def download(self) -> None: + """Download the KITTI data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self._raw_folder, exist_ok=True) + + # download files + for fname in self.resources: + download_and_extract_archive( + url=f"{self.data_url}{fname}", + download_root=self._raw_folder, + filename=fname, + )