From f5df5fc2a6a6c72589fc3beab47bcf0ce84f8da9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Nov 2021 11:27:34 +0000 Subject: [PATCH 1/7] Add Kitti and Sintel --- torchvision/datasets/__init__.py | 3 + torchvision/datasets/_optical_flow.py | 149 ++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 torchvision/datasets/_optical_flow.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 72a73d1d51b..5edcd1bc584 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,3 +1,4 @@ +from ._optical_flow import KittiFlow, Sintel from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -71,4 +72,6 @@ "INaturalist", "LFWPeople", "LFWPairs", + "KittiFlow", + "Sintel", ) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py new file mode 100644 index 00000000000..b979790ccaa --- /dev/null +++ b/torchvision/datasets/_optical_flow.py @@ -0,0 +1,149 @@ +import os +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from ..io.image import _read_png_16 +from .vision import VisionDataset + + +__all__ = ( + "KittiFlow", + "Sintel", +) + + +class FlowDataset(ABC, VisionDataset): + def __init__(self, root, transforms=None): + + super().__init__(root=root) + self.transforms = transforms + + self._flow_list = [] + self._image_list = [] + + def _read_img(self, file_name): + return Image.open(file_name) + + @abstractmethod + def _read_flow(self, file_name): + # Return the flow or a tuple (flow, valid) for datasets where the valid mask is built-in + pass + + def __getitem__(self, index): + # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid + # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), + # and it's up to whatever consumes the dataset to decide what `valid` should be. + + img1 = self._read_img(self._image_list[index][0]) + img2 = self._read_img(self._image_list[index][1]) + flow = self._read_flow(self._flow_list[index]) if self._flow_list else None + + if isinstance(flow, tuple): + flow, valid = flow + else: + valid = None + + if self.transforms is not None: + img1, img2, flow, valid = self.transforms(img1, img2, flow, valid) + + if valid is None: + return img1, img2, flow + else: + return img1, img2, flow, valid + + def __len__(self): + return len(self._image_list) + + +class Sintel(FlowDataset): + def __init__( + self, + root, + split="train", + dstype="clean", + transforms=None, + ): + + super().__init__(root=root, transforms=transforms) + + if split not in ("train", "test"): + raise ValueError("split must be either 'train' or 'test'") + + if dstype not in ("clean", "final"): + raise ValueError("dstype must be either 'clean' or 'final'") + + split_dir = "training" if split == "train" else split + flow_root = Path(root) / split_dir / "flow" + image_root = Path(root) / split_dir / dstype + + for scene in os.listdir(image_root): + image_list = sorted(glob(str(image_root / scene / "*.png"))) + for i in range(len(image_list) - 1): + self._image_list += [[image_list[i], image_list[i + 1]]] + + if split == "train": + self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) + + def _read_flow(self, file_name): + return _read_flo(file_name) + + +class KittiFlow(FlowDataset): + def __init__( + self, + root, + split="train", + transforms=None, + ): + super().__init__(root=root, transforms=transforms) + + if split not in ("train", "test"): + raise ValueError("split must be either 'train' or 'test'") + + root = Path(root) / ("training" if split == "train" else split) + images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) + images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) + + for img1, img2 in zip(images1, images2): + self._image_list += [[img1, img2]] + + if split == "train": + self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) + + def _read_flow(self, file_name): + return _read_16bits_png_with_flow_and_valid_mask(file_name) + + +def _read_flo(file_name): + """Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(file_name, "rb") as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + raise ValueError("Magic number incorrect. Invalid .flo file") + + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def _read_16bits_png_with_flow_and_valid_mask(file_name): + + flow_and_valid = _read_png_16(file_name).to(torch.float32) + flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] + flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive + + return flow, valid From c3dd41b7318db1e9f8b4f2633258ff6474eac110 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Nov 2021 18:40:47 +0000 Subject: [PATCH 2/7] Add tests --- test/datasets_utils.py | 8 +- test/test_datasets.py | 127 ++++++++++++++++++++++++++ torchvision/datasets/_optical_flow.py | 53 +++++++---- 3 files changed, 168 insertions(+), 20 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 3fb89a6d3da..6c3124ae9e7 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -198,6 +198,7 @@ class DatasetTestCase(unittest.TestCase): ``transforms``, or ``download``. - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not available, the tests are skipped. + - EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on. The fake data should resemble the original data as close as necessary, while containing only few examples. During @@ -249,6 +250,8 @@ def test_baz(self): ADDITIONAL_CONFIGS = None REQUIRED_PACKAGES = None + EXTRA_PATCHES = None + # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS. _TRANSFORM_KWARGS = { "transform", @@ -374,6 +377,9 @@ def create_dataset( if patch_checks: patchers.update(self._patch_checks()) + if self.EXTRA_PATCHES is not None: + patchers.update(self.EXTRA_PATCHES) + with get_tmp_dir() as tmpdir: args = self.dataset_args(tmpdir, complete_config) info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None @@ -381,7 +387,7 @@ def create_dataset( with self._maybe_apply_patches(patchers), disable_console_output(): dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs) - yield dataset, info + yield dataset, info @classmethod def setUpClass(cls): diff --git a/test/test_datasets.py b/test/test_datasets.py index 575e5ccb811..02de9d4e0d8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1871,5 +1871,132 @@ def _inject_pairs(self, root, num_pairs, same): datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250) +class SintelTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Sintel + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final")) + # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files, + # which is something we want to # avoid. + _FAKE_FLOW = "Fake Flow" + EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)} + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "Sintel" + + num_images_per_scene = 3 if config["split"] == "train" else 4 + num_scenes = 2 + + for split_dir in ("training", "test"): + for pass_name in ("clean", "final"): + image_root = root / split_dir / pass_name + + for scene_id in range(num_scenes): + scene_dir = image_root / f"scene_{scene_id}" + datasets_utils.create_image_folder( + image_root, + name=str(scene_dir), + file_name_fn=lambda image_idx: f"frame_000{image_idx}.png", + num_examples=num_images_per_scene, + ) + + # For the ground truth flow value we just create empty files so that they're properly discovered, + # see comment above about EXTRA_PATCHES + flow_root = root / "training" / "flow" + for scene_id in range(num_scenes): + scene_dir = flow_root / f"scene_{scene_id}" + os.makedirs(scene_dir) + for i in range(num_images_per_scene - 1): + open(str(scene_dir / f"frame_000{i}.flo"), "a").close() + + # with e.g. num_images_per_scene = 3, for a single scene with have 3 images + # which are frame_0000, frame_0001 and frame_0002 + # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002), + # that is 3 - 1 = 2 examples. Hence the formula below + num_examples = (num_images_per_scene - 1) * num_scenes + return num_examples + + def test_flow(self): + # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images + with self.create_dataset(split="train") as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow in dataset: + assert flow == self._FAKE_FLOW + + # Make sure flow is always None for test split + with self.create_dataset(split="test") as (dataset, _): + assert dataset._image_list and not dataset._flow_list + for _, _, flow in dataset: + assert flow is None + + def test_bad_input(self): + with pytest.raises(ValueError, match="split must be either"): + with self.create_dataset(split="bad"): + pass + + with pytest.raises(ValueError, match="pass_name must be either"): + with self.create_dataset(pass_name="bad"): + pass + + +class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.KittiFlow + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "Kitti" + + num_examples = 2 if config["split"] == "train" else 3 + for split_dir in ("training", "test"): + + datasets_utils.create_image_folder( + root / split_dir, + name="image_2", + file_name_fn=lambda image_idx: f"{image_idx}_10.png", + num_examples=num_examples, + ) + datasets_utils.create_image_folder( + root / split_dir, + name="image_2", + file_name_fn=lambda image_idx: f"{image_idx}_11.png", + num_examples=num_examples, + ) + + # For kitti the ground truth flows are encoded as 16-bits pngs. + # create_image_folder() will actually create 8-bits pngs, but it doesn't + # matter much: the flow reader will still be able to read the files, it + # will just be garbage flow value - but we don't care about that here. + datasets_utils.create_image_folder( + root / "training", + name="flow_occ", + file_name_fn=lambda image_idx: f"{image_idx}_10.png", + num_examples=num_examples, + ) + + return num_examples + + def test_flow_and_valid(self): + # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images + # Also assert flow and valid are of the expected shape + with self.create_dataset(split="train") as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow, valid in dataset: + two, h, w = flow.shape + assert two == 2 + assert valid.shape == (h, w) + + # Make sure flow and valid are always None for test split + with self.create_dataset(split="test") as (dataset, _): + assert dataset._image_list and not dataset._flow_list + for _, _, flow, valid in dataset: + assert flow is None + assert valid is None + + def test_bad_input(self): + with pytest.raises(ValueError, match="split must be either"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index b979790ccaa..4d26ee0d68c 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -18,6 +18,11 @@ class FlowDataset(ABC, VisionDataset): + # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid + # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), + # and it's up to whatever consumes the dataset to decide what `valid` should be. + _has_builtin_flow_mask = False + def __init__(self, root, transforms=None): super().__init__(root=root) @@ -31,30 +36,30 @@ def _read_img(self, file_name): @abstractmethod def _read_flow(self, file_name): - # Return the flow or a tuple (flow, valid) for datasets where the valid mask is built-in + # Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True pass def __getitem__(self, index): - # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid - # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), - # and it's up to whatever consumes the dataset to decide what `valid` should be. img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) - flow = self._read_flow(self._flow_list[index]) if self._flow_list else None - if isinstance(flow, tuple): - flow, valid = flow + if self._flow_list: # it will be empty for some dataset when split="test" + flow = self._read_flow(self._flow_list[index]) + if self._has_builtin_flow_mask: + flow, valid = flow + else: + valid = None else: - valid = None + flow = valid = None if self.transforms is not None: img1, img2, flow, valid = self.transforms(img1, img2, flow, valid) - if valid is None: - return img1, img2, flow - else: + if self._has_builtin_flow_mask: return img1, img2, flow, valid + else: + return img1, img2, flow def __len__(self): return len(self._image_list) @@ -65,7 +70,7 @@ def __init__( self, root, split="train", - dstype="clean", + pass_name="clean", transforms=None, ): @@ -74,12 +79,14 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - if dstype not in ("clean", "final"): - raise ValueError("dstype must be either 'clean' or 'final'") + if pass_name not in ("clean", "final"): + raise ValueError("pass_name must be either 'clean' or 'final'") + + root = Path(root) / "Sintel" split_dir = "training" if split == "train" else split - flow_root = Path(root) / split_dir / "flow" - image_root = Path(root) / split_dir / dstype + image_root = root / split_dir / pass_name + flow_root = root / "training" / "flow" for scene in os.listdir(image_root): image_list = sorted(glob(str(image_root / scene / "*.png"))) @@ -94,6 +101,8 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): + _has_builtin_flow_mask = True + def __init__( self, root, @@ -105,10 +114,15 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - root = Path(root) / ("training" if split == "train" else split) + root = Path(root) / "Kitti" / ("training" if split == "train" else split) images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) + if not images1 or not images2: + raise FileNotFoundError( + "Could not find the Kitti flow images. Please make sure the directory structure is correct." + ) + for img1, img2 in zip(images1, images2): self._image_list += [[img1, img2]] @@ -137,7 +151,7 @@ def _read_flo(file_name): data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) # Reshape data into 3D array (columns, rows, bands) # The reshape here is for visualization, the original code is (w,h,2) - return np.resize(data, (int(h), int(w), 2)) + return np.resize(data, (2, int(h), int(w))) def _read_16bits_png_with_flow_and_valid_mask(file_name): @@ -146,4 +160,5 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive - return flow, valid + # For consistency with other datasets, we convert to numpy + return flow.numpy(), valid.numpy() From 8c766021600353ec055be2b1197067664435e85c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:40:59 +0000 Subject: [PATCH 3/7] Add some docs --- torchvision/datasets/_optical_flow.py | 51 ++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 4d26ee0d68c..2c27f0c1af3 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -101,6 +101,29 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): + """Kitti Dataset for optical flow (2015) + + The dataset can be downloaded `from here + `_. + + The dataset is expected to have the following structure: :: + + root + Kitti + testing + image_2 + training + image_2 + flow_occ + + + Args: + root (string): Root directory of the KittiFlow Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + """ + _has_builtin_flow_mask = True def __init__( @@ -129,6 +152,21 @@ def __init__( if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, + valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + indicating which flow values are valid. The flow is a numpy array of + shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None, None)`` is returned. + """ + return super().__getitem__(index) + def _read_flow(self, file_name): return _read_16bits_png_with_flow_and_valid_mask(file_name) @@ -137,21 +175,16 @@ def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy - # WARNING: this will work on little-endian architectures (eg Intel x86) only! - # print 'fn = %s'%(fn) with open(file_name, "rb") as f: magic = np.fromfile(f, np.float32, count=1) if 202021.25 != magic: raise ValueError("Magic number incorrect. Invalid .flo file") - w = np.fromfile(f, np.int32, count=1) - h = np.fromfile(f, np.int32, count=1) - # print 'Reading %d x %d flo file\n' % (w, h) - data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) - # Reshape data into 3D array (columns, rows, bands) - # The reshape here is for visualization, the original code is (w,h,2) - return np.resize(data, (2, int(h), int(w))) + w = int(np.fromfile(f, np.int32, count=1)) + h = int(np.fromfile(f, np.int32, count=1)) + data = np.fromfile(f, np.float32, count=2 * w * h) + return data.reshape(2, h, w) def _read_16bits_png_with_flow_and_valid_mask(file_name): From e6ecc4ef7f70428650557643b42790ba68d615f9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:47:20 +0000 Subject: [PATCH 4/7] More docs --- docs/source/datasets.rst | 2 + torchvision/datasets/_optical_flow.py | 53 ++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index fdf01eb8ffa..89dfe7e08d8 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes `_. + + The dataset is expected to have the following structure: :: + + root + Sintel + testing + clean + scene_1 + scene_2 + ... + final + scene_1 + scene_2 + ... + training + clean + scene_1 + scene_2 + ... + final + scene_1 + scene_2 + ... + flow + scene_1 + scene_2 + ... + + Args: + root (string): Root directory of the Sintel Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + ``valid`` is expected for consistency with other datasets which + return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. + """ def __init__( self, root, @@ -96,6 +135,19 @@ def __init__( if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow). + The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None)`` is returned. + """ + return super().__getitem__(index) + def _read_flow(self, file_name): return _read_flo(file_name) @@ -116,7 +168,6 @@ class KittiFlow(FlowDataset): image_2 flow_occ - Args: root (string): Root directory of the KittiFlow Dataset. split (string, optional): The dataset split, either "train" (default) or "test" From 721b94b686c00c6fd4d3e4829d07849f7274abfc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:52:05 +0000 Subject: [PATCH 5/7] more docs --- torchvision/datasets/_optical_flow.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index ad15d591617..51d522fbddd 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -66,9 +66,7 @@ def __len__(self): class Sintel(FlowDataset): - """Sintel Dataset for optical flow. - - The dataset can be downloaded `from here `_. + """`Sintel `_ Dataset for optical flow. The dataset is expected to have the following structure: :: @@ -100,11 +98,14 @@ class Sintel(FlowDataset): Args: root (string): Root directory of the Sintel Dataset. split (string, optional): The dataset split, either "train" (default) or "test" + pass_name (string, optional): The pass to use, either "clean" (default) or "final". See link above for + details on the different passes. transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid`` and returns a transformed version. ``valid`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ + def __init__( self, root, @@ -144,7 +145,7 @@ def __getitem__(self, index): Returns: tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow). The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a - 4-tuple with ``(img1, img2, None)`` is returned. + 3-tuple with ``(img1, img2, None)`` is returned. """ return super().__getitem__(index) @@ -153,10 +154,7 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): - """Kitti Dataset for optical flow (2015) - - The dataset can be downloaded `from here - `_. + """`Kitti `_ dataset for optical flow (2015). The dataset is expected to have the following structure: :: From 6f95da0bc5cd3dc2503bd5cf7793076d4b043e03 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 10:38:58 +0000 Subject: [PATCH 6/7] test -> testing for Kitti --- test/test_datasets.py | 2 +- torchvision/datasets/_optical_flow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 02de9d4e0d8..57c2a80181a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1947,7 +1947,7 @@ def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "Kitti" num_examples = 2 if config["split"] == "train" else 3 - for split_dir in ("training", "test"): + for split_dir in ("training", "testing"): datasets_utils.create_image_folder( root / split_dir, diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index aef41852c73..dd699d80fe2 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -186,7 +186,7 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - root = Path(root) / "Kitti" / ("training" if split == "train" else split) + root = Path(root) / "Kitti" / (split + "ing") images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) From 3b8ba30c73df4ae5fcc55d39cbbc1434acff66ce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 10:45:50 +0000 Subject: [PATCH 7/7] less vert space --- torchvision/datasets/_optical_flow.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index dd699d80fe2..7cb19e8d8c4 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -106,14 +106,7 @@ class Sintel(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__( - self, - root, - split="train", - pass_name="clean", - transforms=None, - ): - + def __init__(self, root, split="train", pass_name="clean", transforms=None): super().__init__(root=root, transforms=transforms) if split not in ("train", "test"): @@ -175,12 +168,7 @@ class KittiFlow(FlowDataset): _has_builtin_flow_mask = True - def __init__( - self, - root, - split="train", - transforms=None, - ): + def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) if split not in ("train", "test"):